Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 0 additions & 25 deletions ot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,29 +9,4 @@
#
# License: MIT License


# All submodules and packages
from . import lp
from . import bregman
from . import optim
from . import utils
from . import datasets
from . import plot
from . import da
from . import gromov

# OT functions
from .lp import emd, emd2
from .bregman import sinkhorn, sinkhorn2, barycenter
from .da import sinkhorn_lpl1_mm
from .gromov import gromov_wasserstein, gromov_wasserstein2

# utils functions
from .utils import dist, unif, tic, toc, toq

__version__ = "0.4.0"

__all__ = ["emd", "emd2", "sinkhorn", "sinkhorn2", "utils", 'datasets',
'bregman', 'lp', 'plot', 'tic', 'toc', 'toq',
'dist', 'unif', 'barycenter', 'sinkhorn_lpl1_mm', 'da', 'optim',
'gromov_wasserstein','gromov_wasserstein2']
62 changes: 32 additions & 30 deletions test/test_bregman.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
# License: MIT License

import numpy as np
import ot
from ot.utils import unif, dist, dist0
from ot.bregman import sinkhorn, barycenter, unmix
from ot.datasets import get_1D_gauss


def test_sinkhorn():
Expand All @@ -14,11 +16,11 @@ def test_sinkhorn():
rng = np.random.RandomState(0)

x = rng.randn(n, 2)
u = ot.utils.unif(n)
u = unif(n)

M = ot.dist(x, x)
M = dist(x, x)

G = ot.sinkhorn(u, u, M, 1, stopThr=1e-10)
G = sinkhorn(u, u, M, 1, stopThr=1e-10)

# check constratints
np.testing.assert_allclose(
Expand All @@ -33,22 +35,22 @@ def test_sinkhorn_empty():
rng = np.random.RandomState(0)

x = rng.randn(n, 2)
u = ot.utils.unif(n)
u = unif(n)

M = ot.dist(x, x)
M = dist(x, x)

G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10, verbose=True, log=True)
G, log = sinkhorn([], [], M, 1, stopThr=1e-10, verbose=True, log=True)
# check constratints
np.testing.assert_allclose(u, G.sum(1), atol=1e-05)
np.testing.assert_allclose(u, G.sum(0), atol=1e-05)

G, log = ot.sinkhorn([], [], M, 1, stopThr=1e-10,
method='sinkhorn_stabilized', verbose=True, log=True)
G, log = sinkhorn([], [], M, 1, stopThr=1e-10,
method='sinkhorn_stabilized', verbose=True, log=True)
# check constratints
np.testing.assert_allclose(u, G.sum(1), atol=1e-05)
np.testing.assert_allclose(u, G.sum(0), atol=1e-05)

G, log = ot.sinkhorn(
G, log = sinkhorn(
[], [], M, 1, stopThr=1e-10, method='sinkhorn_epsilon_scaling',
verbose=True, log=True)
# check constratints
Expand All @@ -62,15 +64,15 @@ def test_sinkhorn_variants():
rng = np.random.RandomState(0)

x = rng.randn(n, 2)
u = ot.utils.unif(n)
u = unif(n)

M = ot.dist(x, x)
M = dist(x, x)

G0 = ot.sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10)
Gs = ot.sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10)
Ges = ot.sinkhorn(
G0 = sinkhorn(u, u, M, 1, method='sinkhorn', stopThr=1e-10)
Gs = sinkhorn(u, u, M, 1, method='sinkhorn_stabilized', stopThr=1e-10)
Ges = sinkhorn(
u, u, M, 1, method='sinkhorn_epsilon_scaling', stopThr=1e-10)
Gerr = ot.sinkhorn(u, u, M, 1, method='do_not_exists', stopThr=1e-10)
Gerr = sinkhorn(u, u, M, 1, method='do_not_exists', stopThr=1e-10)

# check values
np.testing.assert_allclose(G0, Gs, atol=1e-05)
Expand All @@ -83,55 +85,55 @@ def test_bary():
n_bins = 100 # nb bins

# Gaussian distributions
a1 = ot.datasets.get_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std
a2 = ot.datasets.get_1D_gauss(n_bins, m=40, s=10)
a1 = get_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std
a2 = get_1D_gauss(n_bins, m=40, s=10)

# creating matrix A containing all distributions
A = np.vstack((a1, a2)).T

# loss matrix + normalization
M = ot.utils.dist0(n_bins)
M = dist0(n_bins)
M /= M.max()

alpha = 0.5 # 0<=alpha<=1
weights = np.array([1 - alpha, alpha])

# wasserstein
reg = 1e-3
bary_wass = ot.bregman.barycenter(A, M, reg, weights)
bary_wass = barycenter(A, M, reg, weights)

np.testing.assert_allclose(1, np.sum(bary_wass))

ot.bregman.barycenter(A, M, reg, log=True, verbose=True)
barycenter(A, M, reg, log=True, verbose=True)


def test_unmix():

n_bins = 50 # nb bins

# Gaussian distributions
a1 = ot.datasets.get_1D_gauss(n_bins, m=20, s=10) # m= mean, s= std
a2 = ot.datasets.get_1D_gauss(n_bins, m=40, s=10)
a1 = get_1D_gauss(n_bins, m=20, s=10) # m= mean, s= std
a2 = get_1D_gauss(n_bins, m=40, s=10)

a = ot.datasets.get_1D_gauss(n_bins, m=30, s=10)
a = get_1D_gauss(n_bins, m=30, s=10)

# creating matrix A containing all distributions
D = np.vstack((a1, a2)).T

# loss matrix + normalization
M = ot.utils.dist0(n_bins)
M = dist0(n_bins)
M /= M.max()

M0 = ot.utils.dist0(2)
M0 = dist0(2)
M0 /= M0.max()
h0 = ot.unif(2)
h0 = unif(2)

# wasserstein
reg = 1e-3
um = ot.bregman.unmix(a, D, M, M0, h0, reg, 1, alpha=0.01,)
um = unmix(a, D, M, M0, h0, reg, 1, alpha=0.01,)

np.testing.assert_allclose(1, np.sum(um), rtol=1e-03, atol=1e-03)
np.testing.assert_allclose([0.5, 0.5], um, rtol=1e-03, atol=1e-03)

ot.bregman.unmix(a, D, M, M0, h0, reg,
1, alpha=0.01, log=True, verbose=True)
unmix(a, D, M, M0, h0, reg,
1, alpha=0.01, log=True, verbose=True)
Loading