From cb739f625921e7fc19113d6d758e27ac69eac24b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 20 Mar 2018 15:05:57 +0100 Subject: [PATCH 01/24] add linear mapping function --- examples/plot_otda_linear_mapping.py | 54 ++++++++++++++++++++++++++++ ot/da.py | 36 +++++++++++++++++++ 2 files changed, 90 insertions(+) create mode 100644 examples/plot_otda_linear_mapping.py diff --git a/examples/plot_otda_linear_mapping.py b/examples/plot_otda_linear_mapping.py new file mode 100644 index 000000000..eff264830 --- /dev/null +++ b/examples/plot_otda_linear_mapping.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Tue Mar 20 14:31:15 2018 + +@author: rflamary +""" + +import numpy as np +import pylab as pl +import ot + + + +#%% + + +n=1000 +d=2 +sigma=.1 + +angles=np.random.rand(n,1)*2*np.pi +xs=np.concatenate((np.sin(angles),np.cos(angles)),axis=1)+sigma*np.random.randn(n,2) + +xs[:n//2,1]+=2 + +anglet=np.random.rand(n,1)*2*np.pi +xt=np.concatenate((np.sin(anglet),np.cos(anglet)),axis=1)+sigma*np.random.randn(n,2) +xt[:n//2,1]+=2 + + +A=np.array([[1.5,.7],[.7,1.5]]) +b=np.array([[4,2]]) +xt=xt.dot(A)+b + +#%% + +pl.figure(1,(5,5)) +pl.plot(xs[:,0],xs[:,1],'+') +pl.plot(xt[:,0],xt[:,1],'o') + +#%% + +Ae,be=ot.da.OT_mapping_linear(xs,xt) + +xst=xs.dot(Ae)+be + +##%% + +pl.figure(1,(5,5)) +pl.clf() +pl.plot(xs[:,0],xs[:,1],'+') +pl.plot(xt[:,0],xt[:,1],'o') +pl.plot(xst[:,0],xst[:,1],'+') \ No newline at end of file diff --git a/ot/da.py b/ot/da.py index c68865418..63bee5a4b 100644 --- a/ot/da.py +++ b/ot/da.py @@ -10,6 +10,7 @@ # License: MIT License import numpy as np +import scipy.linalg as linalg from .bregman import sinkhorn from .lp import emd @@ -633,6 +634,41 @@ def df(G): return G, L +def OT_mapping_linear(xs, xt, reg=1e-6,ws=None,wt=None,log=False): + """ return OT linear operator between samples""" + + d=xs.shape[1] + + mxs=xs.mean(0,keepdims=True) + mxt=xt.mean(0,keepdims=True) + + + if ws is None: + ws=np.ones((xs.shape[0],1))/xs.shape[0] + + if wt is None: + wt=np.ones((xt.shape[0],1))/xt.shape[0] + + Cs=(xs*ws).T.dot(xs)/ws.sum()+reg*np.eye(d) + Ct=(xt*wt).T.dot(xt)/wt.sum()+reg*np.eye(d) + + + Cs12=linalg.sqrtm(Cs) + Cs_12=linalg.inv(Cs12) + + M0=linalg.sqrtm(Cs12.dot(Ct.dot(Cs12))) + + A=Cs_12.dot(M0.dot(Cs_12)).T + + b=mxt-mxs.dot(A) + + if log: + pass + else: + return A,b + + + @deprecated("The class OTDA is deprecated in 0.3.1 and will be " "removed in 0.5" "\n\tfor standard transport use class EMDTransport instead.") From 8fc9fce6c920c646ea7324ac0af54ad53e9aa1bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 20 Mar 2018 16:21:47 +0100 Subject: [PATCH 02/24] add class LinearTransport --- examples/plot_otda_linear_mapping.py | 35 +++- ot/da.py | 239 ++++++++++++++++++++++++++- 2 files changed, 265 insertions(+), 9 deletions(-) diff --git a/examples/plot_otda_linear_mapping.py b/examples/plot_otda_linear_mapping.py index eff264830..44aa9c5ef 100644 --- a/examples/plot_otda_linear_mapping.py +++ b/examples/plot_otda_linear_mapping.py @@ -9,7 +9,7 @@ import numpy as np import pylab as pl import ot - +import scipy.linalg as linalg #%% @@ -19,11 +19,13 @@ d=2 sigma=.1 +# source samples angles=np.random.rand(n,1)*2*np.pi xs=np.concatenate((np.sin(angles),np.cos(angles)),axis=1)+sigma*np.random.randn(n,2) - xs[:n//2,1]+=2 + +# target samples anglet=np.random.rand(n,1)*2*np.pi xt=np.concatenate((np.sin(anglet),np.cos(anglet)),axis=1)+sigma*np.random.randn(n,2) xt[:n//2,1]+=2 @@ -43,7 +45,33 @@ Ae,be=ot.da.OT_mapping_linear(xs,xt) +Ae1=linalg.inv(Ae) +be1=-be.dot(Ae1) + xst=xs.dot(Ae)+be +xts=xt.dot(Ae1)+be1 + +##%% + +pl.figure(1,(5,5)) +pl.clf() +pl.plot(xs[:,0],xs[:,1],'+') +pl.plot(xt[:,0],xt[:,1],'o') +pl.plot(xst[:,0],xst[:,1],'+') +pl.plot(xts[:,0],xts[:,1],'o') + +pl.show() + + +#%% Example class with on images + +mapping=ot.da.LinearTransport() + +mapping.fit(Xs=xs,Xt=xt) + + +xst=mapping.transform(Xs=xs) +xts=mapping.inverse_transform(Xt=xt) ##%% @@ -51,4 +79,5 @@ pl.clf() pl.plot(xs[:,0],xs[:,1],'+') pl.plot(xt[:,0],xt[:,1],'o') -pl.plot(xst[:,0],xst[:,1],'+') \ No newline at end of file +pl.plot(xst[:,0],xst[:,1],'+') +pl.plot(xts[:,0],xts[:,1],'o') diff --git a/ot/da.py b/ot/da.py index 63bee5a4b..ab5f86006 100644 --- a/ot/da.py +++ b/ot/da.py @@ -634,13 +634,76 @@ def df(G): return G, L -def OT_mapping_linear(xs, xt, reg=1e-6,ws=None,wt=None,log=False): - """ return OT linear operator between samples""" +def OT_mapping_linear(xs, xt, reg=1e-6,ws=None,wt=None,bias=True,log=False): + """ return OT linear operator between samples + + The function estimate the optimal linear operator that align the two + empirical distributions. This is equivalent to estimating the closed + form mapping between two Gaussian distribution :math:`N(\mu_s,\Sigma_s)` + and :math:`N(\mu_t,\Sigma_t)` as proposed in [14]. + + The linear operator from source to target :math:`M` + + .. math:: + M(x)=Ax+b + + where : + + .. math:: + A=\Sigma_s^{-1/2}(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2})^{1/2} + \Sigma_s^{-1/2} + .. math:: + b=\mu_t-A\mu_s + + Parameters + ---------- + xs : np.ndarray (ns,d) + samples in the source domain + xt : np.ndarray (nt,d) + samples in the target domain + reg : float,optional + regularization added to the daigonals of convariances (>0) + ws : np.ndarray (ns,1), optional + weights for the source samples + wt : np.ndarray (ns,1), optional + weights for the target samples + bias: boolean, optional + estimate bias b else b=0 (default:True) + log : bool, optional + record log if True + + + Returns + ------- + A : (d x d) ndarray + Linear operator + b : (1 x d) ndarray + bias + log : dict + log dictionary return only if log==True in parameters + + + References + ---------- + + .. [14] Knott, M. and Smith, C. S. "On the optimal mapping of + distributions", Journal of Optimization Theory and Applications + Vol 43, 1984 + + + """ d=xs.shape[1] - mxs=xs.mean(0,keepdims=True) - mxt=xt.mean(0,keepdims=True) + if bias: + mxs=xs.mean(0,keepdims=True) + mxt=xt.mean(0,keepdims=True) + + xs=xs-mxs + xt=xt-mxt + else: + mxs=np.zeros((1,d)) + mxt=np.zeros((1,d)) if ws is None: @@ -658,12 +721,17 @@ def OT_mapping_linear(xs, xt, reg=1e-6,ws=None,wt=None,log=False): M0=linalg.sqrtm(Cs12.dot(Ct.dot(Cs12))) - A=Cs_12.dot(M0.dot(Cs_12)).T + A=Cs_12.dot(M0.dot(Cs_12)) b=mxt-mxs.dot(A) if log: - pass + log={} + log['Cs']=Cs + log['Ct']=Ct + log['Cs12']=Cs12 + log['Cs_12']=Cs_12 + return A,b,log else: return A,b @@ -1216,6 +1284,165 @@ class label return transp_Xt +class LinearTransport(BaseTransport): + """ OT linear operator between empirical distributions + + The function estimate the optimal linear operator that align the two + empirical distributions. This is equivalent to estimating the closed + form mapping between two Gaussian distribution :math:`N(\mu_s,\Sigma_s)` + and :math:`N(\mu_t,\Sigma_t)` as proposed in [14]. + + The linear operator from source to target :math:`M` + + .. math:: + M(x)=Ax+b + + where : + + .. math:: + A=\Sigma_s^{-1/2}(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2})^{1/2} + \Sigma_s^{-1/2} + .. math:: + b=\mu_t-A\mu_s + + Parameters + ---------- + reg : float,optional + regularization added to the daigonals of convariances (>0) + bias: boolean, optional + estimate bias b else b=0 (default:True) + log : bool, optional + record log if True + + + """ + + def __init__(self, reg=1e-8,bias=True,log=False, + distribution_estimation=distribution_estimation_uniform): + + self.bias=bias + self.log=log + self.reg=reg + self.distribution_estimation=distribution_estimation + + def fit(self, Xs=None, ys=None, Xt=None, yt=None): + """Build a coupling matrix from source and target sets of samples + (Xs, ys) and (Xt, yt) + + Parameters + ---------- + Xs : array-like, shape (n_source_samples, n_features) + The training input samples. + ys : array-like, shape (n_source_samples,) + The class labels + Xt : array-like, shape (n_target_samples, n_features) + The training input samples. + yt : array-like, shape (n_target_samples,) + The class labels. If some target samples are unlabeled, fill the + yt's elements with -1. + + Warning: Note that, due to this convention -1 cannot be used as a + class label + + Returns + ------- + self : object + Returns self. + """ + + self.mu_s = self.distribution_estimation(Xs) + self.mu_t = self.distribution_estimation(Xt) + + + + # coupling estimation + returned_ = OT_mapping_linear(Xs,Xt,reg=self.reg, + ws=self.mu_s.reshape((-1,1)), + wt=self.mu_t.reshape((-1,1)), + bias=self.bias,log=self.log) + + # deal with the value of log + if self.log: + self.A_, self.B_, self.log_ = returned_ + else: + self.A_, self.B_, = returned_ + self.log_ = dict() + + # re compute inverse mapping + self.A1_=linalg.inv(self.A_) + self.B1_=-self.B_.dot(self.A1_) + + return self + + def transform(self, Xs=None, ys=None, Xt=None, yt=None, batch_size=128): + """Transports source samples Xs onto target ones Xt + + Parameters + ---------- + Xs : array-like, shape (n_source_samples, n_features) + The training input samples. + ys : array-like, shape (n_source_samples,) + The class labels + Xt : array-like, shape (n_target_samples, n_features) + The training input samples. + yt : array-like, shape (n_target_samples,) + The class labels. If some target samples are unlabeled, fill the + yt's elements with -1. + + Warning: Note that, due to this convention -1 cannot be used as a + class label + batch_size : int, optional (default=128) + The batch size for out of sample inverse transform + + Returns + ------- + transp_Xs : array-like, shape (n_source_samples, n_features) + The transport source samples. + """ + + # check the necessary inputs parameters are here + if check_params(Xs=Xs): + + transp_Xs= Xs.dot(self.A_)+self.B_ + + return transp_Xs + + def inverse_transform(self, Xs=None, ys=None, Xt=None, yt=None, + batch_size=128): + """Transports target samples Xt onto target samples Xs + + Parameters + ---------- + Xs : array-like, shape (n_source_samples, n_features) + The training input samples. + ys : array-like, shape (n_source_samples,) + The class labels + Xt : array-like, shape (n_target_samples, n_features) + The training input samples. + yt : array-like, shape (n_target_samples,) + The class labels. If some target samples are unlabeled, fill the + yt's elements with -1. + + Warning: Note that, due to this convention -1 cannot be used as a + class label + batch_size : int, optional (default=128) + The batch size for out of sample inverse transform + + Returns + ------- + transp_Xt : array-like, shape (n_source_samples, n_features) + The transported target samples. + """ + + # check the necessary inputs parameters are here + if check_params(Xt=Xt): + + transp_Xt= Xt.dot(self.A1_)+self.B1_ + + return transp_Xt + + + class SinkhornTransport(BaseTransport): """Domain Adapatation OT method based on Sinkhorn Algorithm From c1046238d826fe9cf1294f8ea60b8d44743fac78 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 20 Mar 2018 16:27:49 +0100 Subject: [PATCH 03/24] passing tests --- examples/plot_otda_linear_mapping.py | 74 +++++++------- ot/da.py | 147 +++++++++++++-------------- 2 files changed, 110 insertions(+), 111 deletions(-) diff --git a/examples/plot_otda_linear_mapping.py b/examples/plot_otda_linear_mapping.py index 44aa9c5ef..143f129e6 100644 --- a/examples/plot_otda_linear_mapping.py +++ b/examples/plot_otda_linear_mapping.py @@ -15,69 +15,71 @@ #%% -n=1000 -d=2 -sigma=.1 +n = 1000 +d = 2 +sigma = .1 # source samples -angles=np.random.rand(n,1)*2*np.pi -xs=np.concatenate((np.sin(angles),np.cos(angles)),axis=1)+sigma*np.random.randn(n,2) -xs[:n//2,1]+=2 +angles = np.random.rand(n, 1) * 2 * np.pi +xs = np.concatenate((np.sin(angles), np.cos(angles)), + axis=1) + sigma * np.random.randn(n, 2) +xs[:n // 2, 1] += 2 # target samples -anglet=np.random.rand(n,1)*2*np.pi -xt=np.concatenate((np.sin(anglet),np.cos(anglet)),axis=1)+sigma*np.random.randn(n,2) -xt[:n//2,1]+=2 +anglet = np.random.rand(n, 1) * 2 * np.pi +xt = np.concatenate((np.sin(anglet), np.cos(anglet)), + axis=1) + sigma * np.random.randn(n, 2) +xt[:n // 2, 1] += 2 -A=np.array([[1.5,.7],[.7,1.5]]) -b=np.array([[4,2]]) -xt=xt.dot(A)+b +A = np.array([[1.5, .7], [.7, 1.5]]) +b = np.array([[4, 2]]) +xt = xt.dot(A) + b #%% -pl.figure(1,(5,5)) -pl.plot(xs[:,0],xs[:,1],'+') -pl.plot(xt[:,0],xt[:,1],'o') +pl.figure(1, (5, 5)) +pl.plot(xs[:, 0], xs[:, 1], '+') +pl.plot(xt[:, 0], xt[:, 1], 'o') #%% -Ae,be=ot.da.OT_mapping_linear(xs,xt) +Ae, be = ot.da.OT_mapping_linear(xs, xt) -Ae1=linalg.inv(Ae) -be1=-be.dot(Ae1) +Ae1 = linalg.inv(Ae) +be1 = -be.dot(Ae1) -xst=xs.dot(Ae)+be -xts=xt.dot(Ae1)+be1 +xst = xs.dot(Ae) + be +xts = xt.dot(Ae1) + be1 -##%% +# %% -pl.figure(1,(5,5)) +pl.figure(1, (5, 5)) pl.clf() -pl.plot(xs[:,0],xs[:,1],'+') -pl.plot(xt[:,0],xt[:,1],'o') -pl.plot(xst[:,0],xst[:,1],'+') -pl.plot(xts[:,0],xts[:,1],'o') +pl.plot(xs[:, 0], xs[:, 1], '+') +pl.plot(xt[:, 0], xt[:, 1], 'o') +pl.plot(xst[:, 0], xst[:, 1], '+') +pl.plot(xts[:, 0], xts[:, 1], 'o') pl.show() #%% Example class with on images -mapping=ot.da.LinearTransport() +mapping = ot.da.LinearTransport() -mapping.fit(Xs=xs,Xt=xt) +mapping.fit(Xs=xs, Xt=xt) -xst=mapping.transform(Xs=xs) -xts=mapping.inverse_transform(Xt=xt) +xst = mapping.transform(Xs=xs) +xts = mapping.inverse_transform(Xt=xt) -##%% +# %% -pl.figure(1,(5,5)) +pl.figure(1, (5, 5)) pl.clf() -pl.plot(xs[:,0],xs[:,1],'+') -pl.plot(xt[:,0],xt[:,1],'o') -pl.plot(xst[:,0],xst[:,1],'+') -pl.plot(xts[:,0],xts[:,1],'o') +pl.plot(xs[:, 0], xs[:, 1], '+') +pl.plot(xt[:, 0], xt[:, 1], 'o') +pl.plot(xst[:, 0], xst[:, 1], '+') +pl.plot(xts[:, 0], xts[:, 1], 'o') diff --git a/ot/da.py b/ot/da.py index ab5f86006..f789396a6 100644 --- a/ot/da.py +++ b/ot/da.py @@ -357,7 +357,8 @@ def sel(x): def loss(L, G): """Compute full loss""" - return np.sum((xs1.dot(L) - ns * G.dot(xt))**2) + mu * np.sum(G * M) + eta * np.sum(sel(L - I0)**2) + return np.sum((xs1.dot(L) - ns * G.dot(xt))**2) + mu * \ + np.sum(G * M) + eta * np.sum(sel(L - I0)**2) def solve_L(G): """ solve L problem with fixed G (least square)""" @@ -557,7 +558,8 @@ def joint_OT_mapping_kernel(xs, xt, mu=1, eta=0.001, kerneltype='gaussian', def loss(L, G): """Compute full loss""" - return np.sum((K1.dot(L) - ns * G.dot(xt))**2) + mu * np.sum(G * M) + eta * np.trace(L.T.dot(Kreg).dot(L)) + return np.sum((K1.dot(L) - ns * G.dot(xt))**2) + mu * \ + np.sum(G * M) + eta * np.trace(L.T.dot(Kreg).dot(L)) def solve_L_nobias(G): """ solve L problem with fixed G (least square)""" @@ -634,25 +636,26 @@ def df(G): return G, L -def OT_mapping_linear(xs, xt, reg=1e-6,ws=None,wt=None,bias=True,log=False): +def OT_mapping_linear(xs, xt, reg=1e-6, ws=None, + wt=None, bias=True, log=False): """ return OT linear operator between samples The function estimate the optimal linear operator that align the two - empirical distributions. This is equivalent to estimating the closed - form mapping between two Gaussian distribution :math:`N(\mu_s,\Sigma_s)` + empirical distributions. This is equivalent to estimating the closed + form mapping between two Gaussian distribution :math:`N(\mu_s,\Sigma_s)` and :math:`N(\mu_t,\Sigma_t)` as proposed in [14]. - + The linear operator from source to target :math:`M` .. math:: M(x)=Ax+b - + where : - + .. math:: A=\Sigma_s^{-1/2}(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2})^{1/2} \Sigma_s^{-1/2} - .. math:: + .. math:: b=\mu_t-A\mu_s Parameters @@ -666,7 +669,7 @@ def OT_mapping_linear(xs, xt, reg=1e-6,ws=None,wt=None,bias=True,log=False): ws : np.ndarray (ns,1), optional weights for the source samples wt : np.ndarray (ns,1), optional - weights for the target samples + weights for the target samples bias: boolean, optional estimate bias b else b=0 (default:True) log : bool, optional @@ -686,55 +689,52 @@ def OT_mapping_linear(xs, xt, reg=1e-6,ws=None,wt=None,bias=True,log=False): References ---------- - .. [14] Knott, M. and Smith, C. S. "On the optimal mapping of + .. [14] Knott, M. and Smith, C. S. "On the optimal mapping of distributions", Journal of Optimization Theory and Applications Vol 43, 1984 """ - d=xs.shape[1] - + d = xs.shape[1] + if bias: - mxs=xs.mean(0,keepdims=True) - mxt=xt.mean(0,keepdims=True) - - xs=xs-mxs - xt=xt-mxt + mxs = xs.mean(0, keepdims=True) + mxt = xt.mean(0, keepdims=True) + + xs = xs - mxs + xt = xt - mxt else: - mxs=np.zeros((1,d)) - mxt=np.zeros((1,d)) + mxs = np.zeros((1, d)) + mxt = np.zeros((1, d)) - if ws is None: - ws=np.ones((xs.shape[0],1))/xs.shape[0] - + ws = np.ones((xs.shape[0], 1)) / xs.shape[0] + if wt is None: - wt=np.ones((xt.shape[0],1))/xt.shape[0] - - Cs=(xs*ws).T.dot(xs)/ws.sum()+reg*np.eye(d) - Ct=(xt*wt).T.dot(xt)/wt.sum()+reg*np.eye(d) - - - Cs12=linalg.sqrtm(Cs) - Cs_12=linalg.inv(Cs12) - - M0=linalg.sqrtm(Cs12.dot(Ct.dot(Cs12))) - - A=Cs_12.dot(M0.dot(Cs_12)) - - b=mxt-mxs.dot(A) - + wt = np.ones((xt.shape[0], 1)) / xt.shape[0] + + Cs = (xs * ws).T.dot(xs) / ws.sum() + reg * np.eye(d) + Ct = (xt * wt).T.dot(xt) / wt.sum() + reg * np.eye(d) + + Cs12 = linalg.sqrtm(Cs) + Cs_12 = linalg.inv(Cs12) + + M0 = linalg.sqrtm(Cs12.dot(Ct.dot(Cs12))) + + A = Cs_12.dot(M0.dot(Cs_12)) + + b = mxt - mxs.dot(A) + if log: - log={} - log['Cs']=Cs - log['Ct']=Ct - log['Cs12']=Cs12 - log['Cs_12']=Cs_12 - return A,b,log + log = {} + log['Cs'] = Cs + log['Ct'] = Ct + log['Cs12'] = Cs12 + log['Cs_12'] = Cs_12 + return A, b, log else: - return A,b - + return A, b @deprecated("The class OTDA is deprecated in 0.3.1 and will be " @@ -1288,42 +1288,42 @@ class LinearTransport(BaseTransport): """ OT linear operator between empirical distributions The function estimate the optimal linear operator that align the two - empirical distributions. This is equivalent to estimating the closed - form mapping between two Gaussian distribution :math:`N(\mu_s,\Sigma_s)` + empirical distributions. This is equivalent to estimating the closed + form mapping between two Gaussian distribution :math:`N(\mu_s,\Sigma_s)` and :math:`N(\mu_t,\Sigma_t)` as proposed in [14]. - + The linear operator from source to target :math:`M` .. math:: M(x)=Ax+b - + where : - + .. math:: A=\Sigma_s^{-1/2}(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2})^{1/2} \Sigma_s^{-1/2} - .. math:: + .. math:: b=\mu_t-A\mu_s Parameters ---------- reg : float,optional - regularization added to the daigonals of convariances (>0) + regularization added to the daigonals of convariances (>0) bias: boolean, optional estimate bias b else b=0 (default:True) log : bool, optional record log if True - - + + """ - - def __init__(self, reg=1e-8,bias=True,log=False, + + def __init__(self, reg=1e-8, bias=True, log=False, distribution_estimation=distribution_estimation_uniform): - - self.bias=bias - self.log=log - self.reg=reg - self.distribution_estimation=distribution_estimation + + self.bias = bias + self.log = log + self.reg = reg + self.distribution_estimation = distribution_estimation def fit(self, Xs=None, ys=None, Xt=None, yt=None): """Build a coupling matrix from source and target sets of samples @@ -1349,17 +1349,15 @@ class label self : object Returns self. """ - + self.mu_s = self.distribution_estimation(Xs) self.mu_t = self.distribution_estimation(Xt) - - # coupling estimation - returned_ = OT_mapping_linear(Xs,Xt,reg=self.reg, - ws=self.mu_s.reshape((-1,1)), - wt=self.mu_t.reshape((-1,1)), - bias=self.bias,log=self.log) + returned_ = OT_mapping_linear(Xs, Xt, reg=self.reg, + ws=self.mu_s.reshape((-1, 1)), + wt=self.mu_t.reshape((-1, 1)), + bias=self.bias, log=self.log) # deal with the value of log if self.log: @@ -1367,10 +1365,10 @@ class label else: self.A_, self.B_, = returned_ self.log_ = dict() - + # re compute inverse mapping - self.A1_=linalg.inv(self.A_) - self.B1_=-self.B_.dot(self.A1_) + self.A1_ = linalg.inv(self.A_) + self.B1_ = -self.B_.dot(self.A1_) return self @@ -1403,7 +1401,7 @@ class label # check the necessary inputs parameters are here if check_params(Xs=Xs): - transp_Xs= Xs.dot(self.A_)+self.B_ + transp_Xs = Xs.dot(self.A_) + self.B_ return transp_Xs @@ -1437,12 +1435,11 @@ class label # check the necessary inputs parameters are here if check_params(Xt=Xt): - transp_Xt= Xt.dot(self.A1_)+self.B1_ + transp_Xt = Xt.dot(self.A1_) + self.B1_ return transp_Xt - class SinkhornTransport(BaseTransport): """Domain Adapatation OT method based on Sinkhorn Algorithm From 4fc9ccc7c6c96a43c48be54e89133c8f481d8bf4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 20 Mar 2018 16:39:24 +0100 Subject: [PATCH 04/24] better example+test --- Makefile | 2 +- examples/plot_otda_linear_mapping.py | 98 +++++++++++++++++++++------- 2 files changed, 77 insertions(+), 23 deletions(-) diff --git a/Makefile b/Makefile index 3f19e8ac8..d334163bd 100644 --- a/Makefile +++ b/Makefile @@ -41,7 +41,7 @@ pep8 : flake8 examples/ ot/ test/ test : FORCE pep8 - python -m py.test -v test/ --cov=ot --cov-report html:cov_html + python3 -m pytest -v test/ --cov=ot --cov-report html:cov_html pytest : FORCE python -m py.test -v test/ --cov=ot diff --git a/examples/plot_otda_linear_mapping.py b/examples/plot_otda_linear_mapping.py index 143f129e6..163f8f197 100644 --- a/examples/plot_otda_linear_mapping.py +++ b/examples/plot_otda_linear_mapping.py @@ -9,11 +9,11 @@ import numpy as np import pylab as pl import ot -import scipy.linalg as linalg - - -#%% +from scipy import ndimage +############################################################################## +# Generate data +# ------------- n = 1000 d = 2 @@ -37,49 +37,103 @@ b = np.array([[4, 2]]) xt = xt.dot(A) + b -#%% +############################################################################## +# Plot data +# --------- pl.figure(1, (5, 5)) pl.plot(xs[:, 0], xs[:, 1], '+') pl.plot(xt[:, 0], xt[:, 1], 'o') -#%% -Ae, be = ot.da.OT_mapping_linear(xs, xt) +############################################################################## +# Estimate linear mapping and transport +# ------------------------------------- -Ae1 = linalg.inv(Ae) -be1 = -be.dot(Ae1) +Ae, be = ot.da.OT_mapping_linear(xs, xt) xst = xs.dot(Ae) + be -xts = xt.dot(Ae1) + be1 -# %% + +############################################################################## +# Plot transported samples +# ------------------------ pl.figure(1, (5, 5)) pl.clf() pl.plot(xs[:, 0], xs[:, 1], '+') pl.plot(xt[:, 0], xt[:, 1], 'o') pl.plot(xst[:, 0], xst[:, 1], '+') -pl.plot(xts[:, 0], xts[:, 1], 'o') pl.show() +############################################################################## +# Mapping Class between images +# ---------------------------- + + +def im2mat(I): + """Converts and image to matrix (one pixel per line)""" + return I.reshape((I.shape[0] * I.shape[1], I.shape[2])) + + +def mat2im(X, shape): + """Converts back a matrix to an image""" + return X.reshape(shape) + + +def minmax(I): + return np.clip(I, 0, 1) + + +# Loading images +I1 = ndimage.imread('../data/ocean_day.jpg').astype(np.float64) / 256 +I2 = ndimage.imread('../data/ocean_sunset.jpg').astype(np.float64) / 256 -#%% Example class with on images + +X1 = im2mat(I1) +X2 = im2mat(I2) + +############################################################################## +# Estimate mapping and adapt +# ---------------------------- mapping = ot.da.LinearTransport() -mapping.fit(Xs=xs, Xt=xt) +mapping.fit(Xs=X1, Xt=X2) -xst = mapping.transform(Xs=xs) -xts = mapping.inverse_transform(Xt=xt) +xst = mapping.transform(Xs=X1) +xts = mapping.inverse_transform(Xt=X2) + +I1t = minmax(mat2im(xst, I1.shape)) +I2t = minmax(mat2im(xts, I2.shape)) # %% -pl.figure(1, (5, 5)) -pl.clf() -pl.plot(xs[:, 0], xs[:, 1], '+') -pl.plot(xt[:, 0], xt[:, 1], 'o') -pl.plot(xst[:, 0], xst[:, 1], '+') -pl.plot(xts[:, 0], xts[:, 1], 'o') + +############################################################################## +# Plot transformed images +# ----------------------- + +pl.figure(2, figsize=(10, 7)) + +pl.subplot(2, 2, 1) +pl.imshow(I1) +pl.axis('off') +pl.title('Im. 1') + +pl.subplot(2, 2, 2) +pl.imshow(I2) +pl.axis('off') +pl.title('Im. 2') + +pl.subplot(2, 2, 3) +pl.imshow(I1t) +pl.axis('off') +pl.title('Mapping Im. 1') + +pl.subplot(2, 2, 4) +pl.imshow(I2t) +pl.axis('off') +pl.title('Inverse mapping Im. 2') From 88a81c37d2f8c5419f88ccac620186e583fab08f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 20 Mar 2018 16:49:31 +0100 Subject: [PATCH 05/24] makefile update --- Makefile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index d334163bd..7e0c576c4 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,6 @@ -PYTHON=python +PYTHON=python3 help : @echo "The following make targets are available:" @@ -41,7 +41,7 @@ pep8 : flake8 examples/ ot/ test/ test : FORCE pep8 - python3 -m pytest -v test/ --cov=ot --cov-report html:cov_html + $(PYTHON) -m pytest -v test/ --cov=ot --cov-report html:cov_html pytest : FORCE python -m py.test -v test/ --cov=ot From 287c659ad35f5036ba2687caf73009ef455c7239 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 20 Mar 2018 16:57:35 +0100 Subject: [PATCH 06/24] update example --- examples/plot_otda_linear_mapping.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/plot_otda_linear_mapping.py b/examples/plot_otda_linear_mapping.py index 163f8f197..165fe72d8 100644 --- a/examples/plot_otda_linear_mapping.py +++ b/examples/plot_otda_linear_mapping.py @@ -68,8 +68,8 @@ pl.show() ############################################################################## -# Mapping Class between images -# ---------------------------- +# Load image data +# --------------- def im2mat(I): From 6fdf5de8fa27fa16d6b8910fe96eb67b7761aa0e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Wed, 21 Mar 2018 08:29:50 +0100 Subject: [PATCH 07/24] add linear mapping test + autopep8 --- examples/plot_otda_linear_mapping.py | 5 ++--- ot/bregman.py | 30 ++++++++++++++++++---------- ot/lp/__init__.py | 3 ++- ot/optim.py | 9 ++++++--- ot/utils.py | 2 +- test/test_da.py | 18 +++++++++++++++++ 6 files changed, 49 insertions(+), 18 deletions(-) diff --git a/examples/plot_otda_linear_mapping.py b/examples/plot_otda_linear_mapping.py index 165fe72d8..7a3b76154 100644 --- a/examples/plot_otda_linear_mapping.py +++ b/examples/plot_otda_linear_mapping.py @@ -9,7 +9,6 @@ import numpy as np import pylab as pl import ot -from scipy import ndimage ############################################################################## # Generate data @@ -87,8 +86,8 @@ def minmax(I): # Loading images -I1 = ndimage.imread('../data/ocean_day.jpg').astype(np.float64) / 256 -I2 = ndimage.imread('../data/ocean_sunset.jpg').astype(np.float64) / 256 +I1 = pl.imread('../data/ocean_day.jpg').astype(np.float64) / 256 +I2 = pl.imread('../data/ocean_sunset.jpg').astype(np.float64) / 256 X1 = im2mat(I1) diff --git a/ot/bregman.py b/ot/bregman.py index d63c51de1..07b8660e3 100644 --- a/ot/bregman.py +++ b/ot/bregman.py @@ -11,7 +11,8 @@ import numpy as np -def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs): +def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, + stopThr=1e-9, verbose=False, log=False, **kwargs): u""" Solve the entropic regularization optimal transport problem and return the OT matrix @@ -120,7 +121,8 @@ def sink(): return sink() -def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs): +def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000, + stopThr=1e-9, verbose=False, log=False, **kwargs): u""" Solve the entropic regularization optimal transport problem and return the loss @@ -233,7 +235,8 @@ def sink(): return sink() -def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, log=False, **kwargs): +def sinkhorn_knopp(a, b, M, reg, numItermax=1000, + stopThr=1e-9, verbose=False, log=False, **kwargs): """ Solve the entropic regularization optimal transport problem and return the OT matrix @@ -403,7 +406,8 @@ def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False, l return u.reshape((-1, 1)) * K * v.reshape((1, -1)) -def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=20, log=False, **kwargs): +def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, + warmstart=None, verbose=False, print_period=20, log=False, **kwargs): """ Solve the entropic regularization OT problem with log stabilization @@ -526,11 +530,13 @@ def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9, wa def get_K(alpha, beta): """log space computation""" - return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb))) / reg) + return np.exp(-(M - alpha.reshape((na, 1)) - + beta.reshape((1, nb))) / reg) def get_Gamma(alpha, beta, u, v): """log space gamma computation""" - return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb))) / reg + np.log(u.reshape((na, 1))) + np.log(v.reshape((1, nb)))) + return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb))) / + reg + np.log(u.reshape((na, 1))) + np.log(v.reshape((1, nb)))) # print(np.min(K)) @@ -620,7 +626,8 @@ def get_Gamma(alpha, beta, u, v): return get_Gamma(alpha, beta, u, v) -def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInnerItermax=100, tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=10, log=False, **kwargs): +def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInnerItermax=100, + tau=1e3, stopThr=1e-9, warmstart=None, verbose=False, print_period=10, log=False, **kwargs): """ Solve the entropic regularization optimal transport problem with log stabilization and epsilon scaling. @@ -739,7 +746,8 @@ def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4, numInne def get_K(alpha, beta): """log space computation""" - return np.exp(-(M - alpha.reshape((na, 1)) - beta.reshape((1, nb))) / reg) + return np.exp(-(M - alpha.reshape((na, 1)) - + beta.reshape((1, nb))) / reg) # print(np.min(K)) def get_reg(n): # exponential decreasing @@ -811,7 +819,8 @@ def projC(gamma, q): return np.multiply(gamma, q / np.maximum(np.sum(gamma, axis=0), 1e-10)) -def barycenter(A, M, reg, weights=None, numItermax=1000, stopThr=1e-4, verbose=False, log=False): +def barycenter(A, M, reg, weights=None, numItermax=1000, + stopThr=1e-4, verbose=False, log=False): """Compute the entropic regularized wasserstein barycenter of distributions A The function solves the following optimization problem: @@ -904,7 +913,8 @@ def barycenter(A, M, reg, weights=None, numItermax=1000, stopThr=1e-4, verbose=F return geometricBar(weights, UKv) -def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, stopThr=1e-3, verbose=False, log=False): +def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000, + stopThr=1e-3, verbose=False, log=False): """ Compute the unmixing of an observation with a given dictionary using Wasserstein distance diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 5c09da213..6371feba1 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -107,7 +107,8 @@ def emd(a, b, M, numItermax=100000, log=False): return G -def emd2(a, b, M, processes=multiprocessing.cpu_count(), numItermax=100000, log=False, return_matrix=False): +def emd2(a, b, M, processes=multiprocessing.cpu_count(), + numItermax=100000, log=False, return_matrix=False): """Solves the Earth Movers distance problem and returns the loss .. math:: diff --git a/ot/optim.py b/ot/optim.py index 1d09adcf0..f31fae2d1 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -15,7 +15,8 @@ # The corresponding scipy function does not work for matrices -def line_search_armijo(f, xk, pk, gfk, old_fval, args=(), c1=1e-4, alpha0=0.99): +def line_search_armijo(f, xk, pk, gfk, old_fval, + args=(), c1=1e-4, alpha0=0.99): """ Armijo linesearch function that works with matrices @@ -71,7 +72,8 @@ def phi(alpha1): return alpha, fc[0], phi1 -def cg(a, b, M, reg, f, df, G0=None, numItermax=200, stopThr=1e-9, verbose=False, log=False): +def cg(a, b, M, reg, f, df, G0=None, numItermax=200, + stopThr=1e-9, verbose=False, log=False): """ Solve the general regularized OT problem with conditional gradient @@ -202,7 +204,8 @@ def cost(G): return G -def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, numInnerItermax=200, stopThr=1e-9, verbose=False, log=False): +def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, + numInnerItermax=200, stopThr=1e-9, verbose=False, log=False): """ Solve the general regularized OT problem with the generalized conditional gradient diff --git a/ot/utils.py b/ot/utils.py index 9eab3fcac..16862ead8 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -316,7 +316,7 @@ def _is_deprecated(func): closures = [] is_deprecated = ('deprecated' in ''.join([c.cell_contents for c in closures - if isinstance(c.cell_contents, str)])) + if isinstance(c.cell_contents, str)])) return is_deprecated diff --git a/test/test_da.py b/test/test_da.py index 593dc537b..7b63daf6f 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -444,6 +444,24 @@ def test_mapping_transport_class(): assert len(otda.log_.keys()) != 0 +def test_linear_mapping(): + + ns = 150 + nt = 200 + + Xs, ys = get_data_classif('3gauss', ns) + Xt, yt = get_data_classif('3gauss2', nt) + + A, b = ot.da.OT_mapping_linear(Xs, Xt) + + Xst = Xs.dot(A) + b + + Ct = np.cov(Xt.T) + Cst = np.cov(Xst.T) + + np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2) + + def test_otda(): n_samples = 150 # nb samples From 5efdf008865ea347775708b637d933e048d663ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Wed, 21 Mar 2018 09:03:58 +0100 Subject: [PATCH 08/24] add test linear mapping class --- Makefile | 5 +++++ test/test_da.py | 24 ++++++++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/Makefile b/Makefile index 7e0c576c4..468d042c0 100644 --- a/Makefile +++ b/Makefile @@ -56,6 +56,11 @@ rdoc : notebook : ipython notebook --matplotlib=inline --notebook-dir=notebooks/ + +autopep8 : + autopep8 -ir test ot examples +aautopep8 : + autopep8 -air test ot examples FORCE : diff --git a/test/test_da.py b/test/test_da.py index 7b63daf6f..a9d6d349c 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -462,6 +462,30 @@ def test_linear_mapping(): np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2) +def test_linear_mapping_class(): + + ns = 150 + nt = 200 + + Xs, ys = get_data_classif('3gauss', ns) + Xt, yt = get_data_classif('3gauss2', nt) + + otmap = ot.da.LinearTransport() + + otmap.fit(Xs=Xs, Xt=Xt) + assert hasattr(otmap, "A_") + assert hasattr(otmap, "B_") + assert hasattr(otmap, "A1_") + assert hasattr(otmap, "B1_") + + Xst = otmap.transform(Xs=Xs) + + Ct = np.cov(Xt.T) + Cst = np.cov(Xst.T) + + np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2) + + def test_otda(): n_samples = 150 # nb samples From fc9923dea2706b65ffe15fc86428cd8b53b5feb1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Wed, 21 Mar 2018 09:33:57 +0100 Subject: [PATCH 09/24] add tests for ot.uils --- test/test_utils.py | 77 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/test/test_utils.py b/test/test_utils.py index 1bd37cdc5..b524ef6fe 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -7,6 +7,7 @@ import ot import numpy as np +import sys def test_parmap(): @@ -123,3 +124,79 @@ def test_clean_zeros(): assert len(a) == n - nz assert len(b) == n - nz2 + + +def test_cost_normalization(): + + C = np.random.rand(10, 10) + + # does nothing + M0 = ot.utils.cost_normalization(C) + np.testing.assert_allclose(C, M0) + + M = ot.utils.cost_normalization(C, 'median') + np.testing.assert_allclose(np.median(M), 1) + + M = ot.utils.cost_normalization(C, 'max') + np.testing.assert_allclose(M.max(), 1) + + M = ot.utils.cost_normalization(C, 'log') + np.testing.assert_allclose(M.max(), np.log(1 + C).max()) + + M = ot.utils.cost_normalization(C, 'loglog') + np.testing.assert_allclose(M.max(), np.log(1 + np.log(1 + C)).max()) + + +def test_check_params(): + + res1 = ot.utils.check_params(first='OK', second=20) + assert res1 is True + + res0 = ot.utils.check_params(first='OK', second=None) + assert res0 is False + + +def test_deprecated_func(): + + @ot.utils.deprecated('deprecated text for fun') + def fun(): + pass + + def fun2(): + pass + + @ot.utils.deprecated('deprecated text for class') + class Class(): + pass + + if sys.version_info < (3, 5): + print('Not tested') + else: + assert ot.utils._is_deprecated(fun) is True + + assert ot.utils._is_deprecated(fun2) is False + + +def test_BaseEstimator(): + + class Class(ot.utils.BaseEstimator): + + def __init__(self, first='spam', second='eggs'): + + self.first = first + self.second = second + + cl = Class() + + names = cl._get_param_names() + assert 'first' in names + assert 'second' in names + + params = cl.get_params() + assert 'first' in params + assert 'second' in params + + params['first'] = 'spam again' + cl.set_params(**params) + + assert cl.first == 'spam again' From 927395b40dae98bcf027b601b6df48a4318cfef2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Wed, 21 Mar 2018 10:07:47 +0100 Subject: [PATCH 10/24] add externals for function signature --- ot/externals/__init__.py | 0 ot/externals/funcsigs.py | 815 +++++++++++++++++++++++++++++++++++++++ ot/gromov.py | 2 +- ot/utils.py | 10 +- 4 files changed, 821 insertions(+), 6 deletions(-) create mode 100644 ot/externals/__init__.py create mode 100644 ot/externals/funcsigs.py diff --git a/ot/externals/__init__.py b/ot/externals/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/ot/externals/funcsigs.py b/ot/externals/funcsigs.py new file mode 100644 index 000000000..4e684690b --- /dev/null +++ b/ot/externals/funcsigs.py @@ -0,0 +1,815 @@ +# Copyright 2001-2013 Python Software Foundation; All Rights Reserved +"""Function signature objects for callables + +Back port of Python 3.3's function signature tools from the inspect module, +modified to be compatible with Python 2.7 and 3.2+. +""" +from __future__ import absolute_import, division, print_function +import itertools +import functools +import re +import types + +from collections import OrderedDict + +__version__ = "0.4" + +__all__ = ['BoundArguments', 'Parameter', 'Signature', 'signature'] + + +_WrapperDescriptor = type(type.__call__) +_MethodWrapper = type(all.__call__) + +_NonUserDefinedCallables = (_WrapperDescriptor, + _MethodWrapper, + types.BuiltinFunctionType) + + +def formatannotation(annotation, base_module=None): + if isinstance(annotation, type): + if annotation.__module__ in ('builtins', '__builtin__', base_module): + return annotation.__name__ + return annotation.__module__+'.'+annotation.__name__ + return repr(annotation) + + +def _get_user_defined_method(cls, method_name, *nested): + try: + if cls is type: + return + meth = getattr(cls, method_name) + for name in nested: + meth = getattr(meth, name, meth) + except AttributeError: + return + else: + if not isinstance(meth, _NonUserDefinedCallables): + # Once '__signature__' will be added to 'C'-level + # callables, this check won't be necessary + return meth + + +def signature(obj): + '''Get a signature object for the passed callable.''' + + if not callable(obj): + raise TypeError('{0!r} is not a callable object'.format(obj)) + + if isinstance(obj, types.MethodType): + sig = signature(obj.__func__) + if obj.__self__ is None: + # Unbound method: the first parameter becomes positional-only + if sig.parameters: + first = sig.parameters.values()[0].replace( + kind=_POSITIONAL_ONLY) + return sig.replace( + parameters=(first,) + tuple(sig.parameters.values())[1:]) + else: + return sig + else: + # In this case we skip the first parameter of the underlying + # function (usually `self` or `cls`). + return sig.replace(parameters=tuple(sig.parameters.values())[1:]) + + try: + sig = obj.__signature__ + except AttributeError: + pass + else: + if sig is not None: + return sig + + try: + # Was this function wrapped by a decorator? + wrapped = obj.__wrapped__ + except AttributeError: + pass + else: + return signature(wrapped) + + if isinstance(obj, types.FunctionType): + return Signature.from_function(obj) + + if isinstance(obj, functools.partial): + sig = signature(obj.func) + + new_params = OrderedDict(sig.parameters.items()) + + partial_args = obj.args or () + partial_keywords = obj.keywords or {} + try: + ba = sig.bind_partial(*partial_args, **partial_keywords) + except TypeError as ex: + msg = 'partial object {0!r} has incorrect arguments'.format(obj) + raise ValueError(msg) + + for arg_name, arg_value in ba.arguments.items(): + param = new_params[arg_name] + if arg_name in partial_keywords: + # We set a new default value, because the following code + # is correct: + # + # >>> def foo(a): print(a) + # >>> print(partial(partial(foo, a=10), a=20)()) + # 20 + # >>> print(partial(partial(foo, a=10), a=20)(a=30)) + # 30 + # + # So, with 'partial' objects, passing a keyword argument is + # like setting a new default value for the corresponding + # parameter + # + # We also mark this parameter with '_partial_kwarg' + # flag. Later, in '_bind', the 'default' value of this + # parameter will be added to 'kwargs', to simulate + # the 'functools.partial' real call. + new_params[arg_name] = param.replace(default=arg_value, + _partial_kwarg=True) + + elif (param.kind not in (_VAR_KEYWORD, _VAR_POSITIONAL) and + not param._partial_kwarg): + new_params.pop(arg_name) + + return sig.replace(parameters=new_params.values()) + + sig = None + if isinstance(obj, type): + # obj is a class or a metaclass + + # First, let's see if it has an overloaded __call__ defined + # in its metaclass + call = _get_user_defined_method(type(obj), '__call__') + if call is not None: + sig = signature(call) + else: + # Now we check if the 'obj' class has a '__new__' method + new = _get_user_defined_method(obj, '__new__') + if new is not None: + sig = signature(new) + else: + # Finally, we should have at least __init__ implemented + init = _get_user_defined_method(obj, '__init__') + if init is not None: + sig = signature(init) + elif not isinstance(obj, _NonUserDefinedCallables): + # An object with __call__ + # We also check that the 'obj' is not an instance of + # _WrapperDescriptor or _MethodWrapper to avoid + # infinite recursion (and even potential segfault) + call = _get_user_defined_method(type(obj), '__call__', 'im_func') + if call is not None: + sig = signature(call) + + if sig is not None: + # For classes and objects we skip the first parameter of their + # __call__, __new__, or __init__ methods + return sig.replace(parameters=tuple(sig.parameters.values())[1:]) + + if isinstance(obj, types.BuiltinFunctionType): + # Raise a nicer error message for builtins + msg = 'no signature found for builtin function {0!r}'.format(obj) + raise ValueError(msg) + + raise ValueError('callable {0!r} is not supported by signature'.format(obj)) + + +class _void(object): + '''A private marker - used in Parameter & Signature''' + + +class _empty(object): + pass + + +class _ParameterKind(int): + def __new__(self, *args, **kwargs): + obj = int.__new__(self, *args) + obj._name = kwargs['name'] + return obj + + def __str__(self): + return self._name + + def __repr__(self): + return '<_ParameterKind: {0!r}>'.format(self._name) + + +_POSITIONAL_ONLY = _ParameterKind(0, name='POSITIONAL_ONLY') +_POSITIONAL_OR_KEYWORD = _ParameterKind(1, name='POSITIONAL_OR_KEYWORD') +_VAR_POSITIONAL = _ParameterKind(2, name='VAR_POSITIONAL') +_KEYWORD_ONLY = _ParameterKind(3, name='KEYWORD_ONLY') +_VAR_KEYWORD = _ParameterKind(4, name='VAR_KEYWORD') + + +class Parameter(object): + '''Represents a parameter in a function signature. + + Has the following public attributes: + + * name : str + The name of the parameter as a string. + * default : object + The default value for the parameter if specified. If the + parameter has no default value, this attribute is not set. + * annotation + The annotation for the parameter if specified. If the + parameter has no annotation, this attribute is not set. + * kind : str + Describes how argument values are bound to the parameter. + Possible values: `Parameter.POSITIONAL_ONLY`, + `Parameter.POSITIONAL_OR_KEYWORD`, `Parameter.VAR_POSITIONAL`, + `Parameter.KEYWORD_ONLY`, `Parameter.VAR_KEYWORD`. + ''' + + __slots__ = ('_name', '_kind', '_default', '_annotation', '_partial_kwarg') + + POSITIONAL_ONLY = _POSITIONAL_ONLY + POSITIONAL_OR_KEYWORD = _POSITIONAL_OR_KEYWORD + VAR_POSITIONAL = _VAR_POSITIONAL + KEYWORD_ONLY = _KEYWORD_ONLY + VAR_KEYWORD = _VAR_KEYWORD + + empty = _empty + + def __init__(self, name, kind, default=_empty, annotation=_empty, + _partial_kwarg=False): + + if kind not in (_POSITIONAL_ONLY, _POSITIONAL_OR_KEYWORD, + _VAR_POSITIONAL, _KEYWORD_ONLY, _VAR_KEYWORD): + raise ValueError("invalid value for 'Parameter.kind' attribute") + self._kind = kind + + if default is not _empty: + if kind in (_VAR_POSITIONAL, _VAR_KEYWORD): + msg = '{0} parameters cannot have default values'.format(kind) + raise ValueError(msg) + self._default = default + self._annotation = annotation + + if name is None: + if kind != _POSITIONAL_ONLY: + raise ValueError("None is not a valid name for a " + "non-positional-only parameter") + self._name = name + else: + name = str(name) + if kind != _POSITIONAL_ONLY and not re.match(r'[a-z_]\w*$', name, re.I): + msg = '{0!r} is not a valid parameter name'.format(name) + raise ValueError(msg) + self._name = name + + self._partial_kwarg = _partial_kwarg + + @property + def name(self): + return self._name + + @property + def default(self): + return self._default + + @property + def annotation(self): + return self._annotation + + @property + def kind(self): + return self._kind + + def replace(self, name=_void, kind=_void, annotation=_void, + default=_void, _partial_kwarg=_void): + '''Creates a customized copy of the Parameter.''' + + if name is _void: + name = self._name + + if kind is _void: + kind = self._kind + + if annotation is _void: + annotation = self._annotation + + if default is _void: + default = self._default + + if _partial_kwarg is _void: + _partial_kwarg = self._partial_kwarg + + return type(self)(name, kind, default=default, annotation=annotation, + _partial_kwarg=_partial_kwarg) + + def __str__(self): + kind = self.kind + + formatted = self._name + if kind == _POSITIONAL_ONLY: + if formatted is None: + formatted = '' + formatted = '<{0}>'.format(formatted) + + # Add annotation and default value + if self._annotation is not _empty: + formatted = '{0}:{1}'.format(formatted, + formatannotation(self._annotation)) + + if self._default is not _empty: + formatted = '{0}={1}'.format(formatted, repr(self._default)) + + if kind == _VAR_POSITIONAL: + formatted = '*' + formatted + elif kind == _VAR_KEYWORD: + formatted = '**' + formatted + + return formatted + + def __repr__(self): + return '<{0} at {1:#x} {2!r}>'.format(self.__class__.__name__, + id(self), self.name) + + def __hash__(self): + msg = "unhashable type: '{0}'".format(self.__class__.__name__) + raise TypeError(msg) + + def __eq__(self, other): + return (issubclass(other.__class__, Parameter) and + self._name == other._name and + self._kind == other._kind and + self._default == other._default and + self._annotation == other._annotation) + + def __ne__(self, other): + return not self.__eq__(other) + + +class BoundArguments(object): + '''Result of `Signature.bind` call. Holds the mapping of arguments + to the function's parameters. + + Has the following public attributes: + + * arguments : OrderedDict + An ordered mutable mapping of parameters' names to arguments' values. + Does not contain arguments' default values. + * signature : Signature + The Signature object that created this instance. + * args : tuple + Tuple of positional arguments values. + * kwargs : dict + Dict of keyword arguments values. + ''' + + def __init__(self, signature, arguments): + self.arguments = arguments + self._signature = signature + + @property + def signature(self): + return self._signature + + @property + def args(self): + args = [] + for param_name, param in self._signature.parameters.items(): + if (param.kind in (_VAR_KEYWORD, _KEYWORD_ONLY) or + param._partial_kwarg): + # Keyword arguments mapped by 'functools.partial' + # (Parameter._partial_kwarg is True) are mapped + # in 'BoundArguments.kwargs', along with VAR_KEYWORD & + # KEYWORD_ONLY + break + + try: + arg = self.arguments[param_name] + except KeyError: + # We're done here. Other arguments + # will be mapped in 'BoundArguments.kwargs' + break + else: + if param.kind == _VAR_POSITIONAL: + # *args + args.extend(arg) + else: + # plain argument + args.append(arg) + + return tuple(args) + + @property + def kwargs(self): + kwargs = {} + kwargs_started = False + for param_name, param in self._signature.parameters.items(): + if not kwargs_started: + if (param.kind in (_VAR_KEYWORD, _KEYWORD_ONLY) or + param._partial_kwarg): + kwargs_started = True + else: + if param_name not in self.arguments: + kwargs_started = True + continue + + if not kwargs_started: + continue + + try: + arg = self.arguments[param_name] + except KeyError: + pass + else: + if param.kind == _VAR_KEYWORD: + # **kwargs + kwargs.update(arg) + else: + # plain keyword argument + kwargs[param_name] = arg + + return kwargs + + def __hash__(self): + msg = "unhashable type: '{0}'".format(self.__class__.__name__) + raise TypeError(msg) + + def __eq__(self, other): + return (issubclass(other.__class__, BoundArguments) and + self.signature == other.signature and + self.arguments == other.arguments) + + def __ne__(self, other): + return not self.__eq__(other) + + +class Signature(object): + '''A Signature object represents the overall signature of a function. + It stores a Parameter object for each parameter accepted by the + function, as well as information specific to the function itself. + + A Signature object has the following public attributes and methods: + + * parameters : OrderedDict + An ordered mapping of parameters' names to the corresponding + Parameter objects (keyword-only arguments are in the same order + as listed in `code.co_varnames`). + * return_annotation : object + The annotation for the return type of the function if specified. + If the function has no annotation for its return type, this + attribute is not set. + * bind(*args, **kwargs) -> BoundArguments + Creates a mapping from positional and keyword arguments to + parameters. + * bind_partial(*args, **kwargs) -> BoundArguments + Creates a partial mapping from positional and keyword arguments + to parameters (simulating 'functools.partial' behavior.) + ''' + + __slots__ = ('_return_annotation', '_parameters') + + _parameter_cls = Parameter + _bound_arguments_cls = BoundArguments + + empty = _empty + + def __init__(self, parameters=None, return_annotation=_empty, + __validate_parameters__=True): + '''Constructs Signature from the given list of Parameter + objects and 'return_annotation'. All arguments are optional. + ''' + + if parameters is None: + params = OrderedDict() + else: + if __validate_parameters__: + params = OrderedDict() + top_kind = _POSITIONAL_ONLY + + for idx, param in enumerate(parameters): + kind = param.kind + if kind < top_kind: + msg = 'wrong parameter order: {0} before {1}' + msg = msg.format(top_kind, param.kind) + raise ValueError(msg) + else: + top_kind = kind + + name = param.name + if name is None: + name = str(idx) + param = param.replace(name=name) + + if name in params: + msg = 'duplicate parameter name: {0!r}'.format(name) + raise ValueError(msg) + params[name] = param + else: + params = OrderedDict(((param.name, param) + for param in parameters)) + + self._parameters = params + self._return_annotation = return_annotation + + @classmethod + def from_function(cls, func): + '''Constructs Signature for the given python function''' + + if not isinstance(func, types.FunctionType): + raise TypeError('{0!r} is not a Python function'.format(func)) + + Parameter = cls._parameter_cls + + # Parameter information. + func_code = func.__code__ + pos_count = func_code.co_argcount + arg_names = func_code.co_varnames + positional = tuple(arg_names[:pos_count]) + keyword_only_count = getattr(func_code, 'co_kwonlyargcount', 0) + keyword_only = arg_names[pos_count:(pos_count + keyword_only_count)] + annotations = getattr(func, '__annotations__', {}) + defaults = func.__defaults__ + kwdefaults = getattr(func, '__kwdefaults__', None) + + if defaults: + pos_default_count = len(defaults) + else: + pos_default_count = 0 + + parameters = [] + + # Non-keyword-only parameters w/o defaults. + non_default_count = pos_count - pos_default_count + for name in positional[:non_default_count]: + annotation = annotations.get(name, _empty) + parameters.append(Parameter(name, annotation=annotation, + kind=_POSITIONAL_OR_KEYWORD)) + + # ... w/ defaults. + for offset, name in enumerate(positional[non_default_count:]): + annotation = annotations.get(name, _empty) + parameters.append(Parameter(name, annotation=annotation, + kind=_POSITIONAL_OR_KEYWORD, + default=defaults[offset])) + + # *args + if func_code.co_flags & 0x04: + name = arg_names[pos_count + keyword_only_count] + annotation = annotations.get(name, _empty) + parameters.append(Parameter(name, annotation=annotation, + kind=_VAR_POSITIONAL)) + + # Keyword-only parameters. + for name in keyword_only: + default = _empty + if kwdefaults is not None: + default = kwdefaults.get(name, _empty) + + annotation = annotations.get(name, _empty) + parameters.append(Parameter(name, annotation=annotation, + kind=_KEYWORD_ONLY, + default=default)) + # **kwargs + if func_code.co_flags & 0x08: + index = pos_count + keyword_only_count + if func_code.co_flags & 0x04: + index += 1 + + name = arg_names[index] + annotation = annotations.get(name, _empty) + parameters.append(Parameter(name, annotation=annotation, + kind=_VAR_KEYWORD)) + + return cls(parameters, + return_annotation=annotations.get('return', _empty), + __validate_parameters__=False) + + @property + def parameters(self): + try: + return types.MappingProxyType(self._parameters) + except AttributeError: + return OrderedDict(self._parameters.items()) + + @property + def return_annotation(self): + return self._return_annotation + + def replace(self, parameters=_void, return_annotation=_void): + '''Creates a customized copy of the Signature. + Pass 'parameters' and/or 'return_annotation' arguments + to override them in the new copy. + ''' + + if parameters is _void: + parameters = self.parameters.values() + + if return_annotation is _void: + return_annotation = self._return_annotation + + return type(self)(parameters, + return_annotation=return_annotation) + + def __hash__(self): + msg = "unhashable type: '{0}'".format(self.__class__.__name__) + raise TypeError(msg) + + def __eq__(self, other): + if (not issubclass(type(other), Signature) or + self.return_annotation != other.return_annotation or + len(self.parameters) != len(other.parameters)): + return False + + other_positions = dict((param, idx) + for idx, param in enumerate(other.parameters.keys())) + + for idx, (param_name, param) in enumerate(self.parameters.items()): + if param.kind == _KEYWORD_ONLY: + try: + other_param = other.parameters[param_name] + except KeyError: + return False + else: + if param != other_param: + return False + else: + try: + other_idx = other_positions[param_name] + except KeyError: + return False + else: + if (idx != other_idx or + param != other.parameters[param_name]): + return False + + return True + + def __ne__(self, other): + return not self.__eq__(other) + + def _bind(self, args, kwargs, partial=False): + '''Private method. Don't use directly.''' + + arguments = OrderedDict() + + parameters = iter(self.parameters.values()) + parameters_ex = () + arg_vals = iter(args) + + if partial: + # Support for binding arguments to 'functools.partial' objects. + # See 'functools.partial' case in 'signature()' implementation + # for details. + for param_name, param in self.parameters.items(): + if (param._partial_kwarg and param_name not in kwargs): + # Simulating 'functools.partial' behavior + kwargs[param_name] = param.default + + while True: + # Let's iterate through the positional arguments and corresponding + # parameters + try: + arg_val = next(arg_vals) + except StopIteration: + # No more positional arguments + try: + param = next(parameters) + except StopIteration: + # No more parameters. That's it. Just need to check that + # we have no `kwargs` after this while loop + break + else: + if param.kind == _VAR_POSITIONAL: + # That's OK, just empty *args. Let's start parsing + # kwargs + break + elif param.name in kwargs: + if param.kind == _POSITIONAL_ONLY: + msg = '{arg!r} parameter is positional only, ' \ + 'but was passed as a keyword' + msg = msg.format(arg=param.name) + raise TypeError(msg) + parameters_ex = (param,) + break + elif (param.kind == _VAR_KEYWORD or + param.default is not _empty): + # That's fine too - we have a default value for this + # parameter. So, lets start parsing `kwargs`, starting + # with the current parameter + parameters_ex = (param,) + break + else: + if partial: + parameters_ex = (param,) + break + else: + msg = '{arg!r} parameter lacking default value' + msg = msg.format(arg=param.name) + raise TypeError(msg) + else: + # We have a positional argument to process + try: + param = next(parameters) + except StopIteration: + raise TypeError('too many positional arguments') + else: + if param.kind in (_VAR_KEYWORD, _KEYWORD_ONLY): + # Looks like we have no parameter for this positional + # argument + raise TypeError('too many positional arguments') + + if param.kind == _VAR_POSITIONAL: + # We have an '*args'-like argument, let's fill it with + # all positional arguments we have left and move on to + # the next phase + values = [arg_val] + values.extend(arg_vals) + arguments[param.name] = tuple(values) + break + + if param.name in kwargs: + raise TypeError('multiple values for argument ' + '{arg!r}'.format(arg=param.name)) + + arguments[param.name] = arg_val + + # Now, we iterate through the remaining parameters to process + # keyword arguments + kwargs_param = None + for param in itertools.chain(parameters_ex, parameters): + if param.kind == _POSITIONAL_ONLY: + # This should never happen in case of a properly built + # Signature object (but let's have this check here + # to ensure correct behaviour just in case) + raise TypeError('{arg!r} parameter is positional only, ' + 'but was passed as a keyword'. \ + format(arg=param.name)) + + if param.kind == _VAR_KEYWORD: + # Memorize that we have a '**kwargs'-like parameter + kwargs_param = param + continue + + param_name = param.name + try: + arg_val = kwargs.pop(param_name) + except KeyError: + # We have no value for this parameter. It's fine though, + # if it has a default value, or it is an '*args'-like + # parameter, left alone by the processing of positional + # arguments. + if (not partial and param.kind != _VAR_POSITIONAL and + param.default is _empty): + raise TypeError('{arg!r} parameter lacking default value'. \ + format(arg=param_name)) + + else: + arguments[param_name] = arg_val + + if kwargs: + if kwargs_param is not None: + # Process our '**kwargs'-like parameter + arguments[kwargs_param.name] = kwargs + else: + raise TypeError('too many keyword arguments') + + return self._bound_arguments_cls(self, arguments) + + def bind(self, *args, **kwargs): + '''Get a BoundArguments object, that maps the passed `args` + and `kwargs` to the function's signature. Raises `TypeError` + if the passed arguments can not be bound. + ''' + return self._bind(args, kwargs) + + def bind_partial(self, *args, **kwargs): + '''Get a BoundArguments object, that partially maps the + passed `args` and `kwargs` to the function's signature. + Raises `TypeError` if the passed arguments can not be bound. + ''' + return self._bind(args, kwargs, partial=True) + + def __str__(self): + result = [] + render_kw_only_separator = True + for idx, param in enumerate(self.parameters.values()): + formatted = str(param) + + kind = param.kind + if kind == _VAR_POSITIONAL: + # OK, we have an '*args'-like parameter, so we won't need + # a '*' to separate keyword-only arguments + render_kw_only_separator = False + elif kind == _KEYWORD_ONLY and render_kw_only_separator: + # We have a keyword-only parameter to render and we haven't + # rendered an '*args'-like parameter before, so add a '*' + # separator to the parameters list ("foo(arg1, *, arg2)" case) + result.append('*') + # This condition should be only triggered once, so + # reset the flag + render_kw_only_separator = False + + result.append(formatted) + + rendered = '({0})'.format(', '.join(result)) + + if self.return_annotation is not _empty: + anno = formatannotation(self.return_annotation) + rendered += ' -> {0}'.format(anno) + + return rendered diff --git a/ot/gromov.py b/ot/gromov.py index 2a2387315..e03fa5b0c 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -595,7 +595,7 @@ def entropic_gromov_wasserstein2(C1, C2, p, q, loss_fun, epsilon, gw, logv = entropic_gromov_wasserstein( C1, C2, p, q, loss_fun, epsilon, max_iter, tol, verbose, log=True) - log['T'] = gw + logv['T'] = gw if log: return logv['gw_dist'], logv diff --git a/ot/utils.py b/ot/utils.py index 16862ead8..17983f2e1 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -15,7 +15,10 @@ from scipy.spatial.distance import cdist import sys import warnings - +try: + from inspect import signature +except ImportError: + from .externals.funcsigs import signature __time_tic_toc = time.time() @@ -335,10 +338,7 @@ class BaseEstimator(object): @classmethod def _get_param_names(cls): """Get parameter names for the estimator""" - try: - from inspect import signature - except ImportError: - from .externals.funcsigs import signature + # fetch the constructor or the original constructor before # deprecation wrapping if any init = getattr(cls.__init__, 'deprecated_original', cls.__init__) From 55aaf7874c651235d44c34b89337df7694e55014 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Wed, 21 Mar 2018 10:08:17 +0100 Subject: [PATCH 11/24] add test gromov + debug sklearn Basestimator --- test/test_da.py | 4 ++-- test/test_gromov.py | 25 +++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/test/test_da.py b/test/test_da.py index a9d6d349c..3022721c4 100644 --- a/test/test_da.py +++ b/test/test_da.py @@ -326,8 +326,8 @@ def test_mapping_transport_class(): """test_mapping_transport """ - ns = 150 - nt = 200 + ns = 60 + nt = 120 Xs, ys = get_data_classif('3gauss', ns) Xt, yt = get_data_classif('3gauss2', nt) diff --git a/test/test_gromov.py b/test/test_gromov.py index 625e62a96..0384ee11b 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -36,6 +36,18 @@ def test_gromov(): np.testing.assert_allclose( q, G.sum(0), atol=1e-04) # cf convergence gromov + gw, log = ot.gromov.gromov_wasserstein2(C1, C2, p, q, 'kl_loss', log=True) + + G = log['T'] + + np.testing.assert_allclose(gw, 0, atol=1e-04, rtol=1e-4) + + # check constratints + np.testing.assert_allclose( + p, G.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, G.sum(0), atol=1e-04) # cf convergence gromov + def test_entropic_gromov(): n_samples = 50 # nb samples @@ -64,3 +76,16 @@ def test_entropic_gromov(): p, G.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( q, G.sum(0), atol=1e-04) # cf convergence gromov + + gw, log = ot.gromov.entropic_gromov_wasserstein2( + C1, C2, p, q, 'kl_loss', epsilon=1e-2, log=True) + + G = log['T'] + + np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e1) + + # check constratints + np.testing.assert_allclose( + p, G.sum(1), atol=1e-04) # cf convergence gromov + np.testing.assert_allclose( + q, G.sum(0), atol=1e-04) # cf convergence gromov From 64ef33d09906a1aebd3c8294ffd7720475ab926b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Wed, 21 Mar 2018 10:13:16 +0100 Subject: [PATCH 12/24] aupdate gromov + autopep8 externals --- ot/externals/funcsigs.py | 56 +++++++++++++++++++++------------------- test/test_gromov.py | 2 +- 2 files changed, 30 insertions(+), 28 deletions(-) diff --git a/ot/externals/funcsigs.py b/ot/externals/funcsigs.py index 4e684690b..82a55dac5 100644 --- a/ot/externals/funcsigs.py +++ b/ot/externals/funcsigs.py @@ -29,7 +29,7 @@ def formatannotation(annotation, base_module=None): if isinstance(annotation, type): if annotation.__module__ in ('builtins', '__builtin__', base_module): return annotation.__name__ - return annotation.__module__+'.'+annotation.__name__ + return annotation.__module__ + '.' + annotation.__name__ return repr(annotation) @@ -127,7 +127,7 @@ def signature(obj): _partial_kwarg=True) elif (param.kind not in (_VAR_KEYWORD, _VAR_POSITIONAL) and - not param._partial_kwarg): + not param._partial_kwarg): new_params.pop(arg_name) return sig.replace(parameters=new_params.values()) @@ -170,7 +170,8 @@ def signature(obj): msg = 'no signature found for builtin function {0!r}'.format(obj) raise ValueError(msg) - raise ValueError('callable {0!r} is not supported by signature'.format(obj)) + raise ValueError( + 'callable {0!r} is not supported by signature'.format(obj)) class _void(object): @@ -194,11 +195,11 @@ def __repr__(self): return '<_ParameterKind: {0!r}>'.format(self._name) -_POSITIONAL_ONLY = _ParameterKind(0, name='POSITIONAL_ONLY') -_POSITIONAL_OR_KEYWORD = _ParameterKind(1, name='POSITIONAL_OR_KEYWORD') -_VAR_POSITIONAL = _ParameterKind(2, name='VAR_POSITIONAL') -_KEYWORD_ONLY = _ParameterKind(3, name='KEYWORD_ONLY') -_VAR_KEYWORD = _ParameterKind(4, name='VAR_KEYWORD') +_POSITIONAL_ONLY = _ParameterKind(0, name='POSITIONAL_ONLY') +_POSITIONAL_OR_KEYWORD = _ParameterKind(1, name='POSITIONAL_OR_KEYWORD') +_VAR_POSITIONAL = _ParameterKind(2, name='VAR_POSITIONAL') +_KEYWORD_ONLY = _ParameterKind(3, name='KEYWORD_ONLY') +_VAR_KEYWORD = _ParameterKind(4, name='VAR_KEYWORD') class Parameter(object): @@ -223,11 +224,11 @@ class Parameter(object): __slots__ = ('_name', '_kind', '_default', '_annotation', '_partial_kwarg') - POSITIONAL_ONLY = _POSITIONAL_ONLY - POSITIONAL_OR_KEYWORD = _POSITIONAL_OR_KEYWORD - VAR_POSITIONAL = _VAR_POSITIONAL - KEYWORD_ONLY = _KEYWORD_ONLY - VAR_KEYWORD = _VAR_KEYWORD + POSITIONAL_ONLY = _POSITIONAL_ONLY + POSITIONAL_OR_KEYWORD = _POSITIONAL_OR_KEYWORD + VAR_POSITIONAL = _VAR_POSITIONAL + KEYWORD_ONLY = _KEYWORD_ONLY + VAR_KEYWORD = _VAR_KEYWORD empty = _empty @@ -253,7 +254,8 @@ def __init__(self, name, kind, default=_empty, annotation=_empty, self._name = name else: name = str(name) - if kind != _POSITIONAL_ONLY and not re.match(r'[a-z_]\w*$', name, re.I): + if kind != _POSITIONAL_ONLY and not re.match( + r'[a-z_]\w*$', name, re.I): msg = '{0!r} is not a valid parameter name'.format(name) raise ValueError(msg) self._name = name @@ -310,7 +312,7 @@ def __str__(self): # Add annotation and default value if self._annotation is not _empty: formatted = '{0}:{1}'.format(formatted, - formatannotation(self._annotation)) + formatannotation(self._annotation)) if self._default is not _empty: formatted = '{0}={1}'.format(formatted, repr(self._default)) @@ -324,7 +326,7 @@ def __str__(self): def __repr__(self): return '<{0} at {1:#x} {2!r}>'.format(self.__class__.__name__, - id(self), self.name) + id(self), self.name) def __hash__(self): msg = "unhashable type: '{0}'".format(self.__class__.__name__) @@ -371,7 +373,7 @@ def args(self): args = [] for param_name, param in self._signature.parameters.items(): if (param.kind in (_VAR_KEYWORD, _KEYWORD_ONLY) or - param._partial_kwarg): + param._partial_kwarg): # Keyword arguments mapped by 'functools.partial' # (Parameter._partial_kwarg is True) are mapped # in 'BoundArguments.kwargs', along with VAR_KEYWORD & @@ -401,7 +403,7 @@ def kwargs(self): for param_name, param in self._signature.parameters.items(): if not kwargs_started: if (param.kind in (_VAR_KEYWORD, _KEYWORD_ONLY) or - param._partial_kwarg): + param._partial_kwarg): kwargs_started = True else: if param_name not in self.arguments: @@ -501,7 +503,7 @@ def __init__(self, parameters=None, return_annotation=_empty, params[name] = param else: params = OrderedDict(((param.name, param) - for param in parameters)) + for param in parameters)) self._parameters = params self._return_annotation = return_annotation @@ -611,12 +613,12 @@ def __hash__(self): def __eq__(self, other): if (not issubclass(type(other), Signature) or - self.return_annotation != other.return_annotation or - len(self.parameters) != len(other.parameters)): + self.return_annotation != other.return_annotation or + len(self.parameters) != len(other.parameters)): return False other_positions = dict((param, idx) - for idx, param in enumerate(other.parameters.keys())) + for idx, param in enumerate(other.parameters.keys())) for idx, (param_name, param) in enumerate(self.parameters.items()): if param.kind == _KEYWORD_ONLY: @@ -634,7 +636,7 @@ def __eq__(self, other): return False else: if (idx != other_idx or - param != other.parameters[param_name]): + param != other.parameters[param_name]): return False return True @@ -687,7 +689,7 @@ def _bind(self, args, kwargs, partial=False): parameters_ex = (param,) break elif (param.kind == _VAR_KEYWORD or - param.default is not _empty): + param.default is not _empty): # That's fine too - we have a default value for this # parameter. So, lets start parsing `kwargs`, starting # with the current parameter @@ -737,7 +739,7 @@ def _bind(self, args, kwargs, partial=False): # Signature object (but let's have this check here # to ensure correct behaviour just in case) raise TypeError('{arg!r} parameter is positional only, ' - 'but was passed as a keyword'. \ + 'but was passed as a keyword'. format(arg=param.name)) if param.kind == _VAR_KEYWORD: @@ -754,8 +756,8 @@ def _bind(self, args, kwargs, partial=False): # parameter, left alone by the processing of positional # arguments. if (not partial and param.kind != _VAR_POSITIONAL and - param.default is _empty): - raise TypeError('{arg!r} parameter lacking default value'. \ + param.default is _empty): + raise TypeError('{arg!r} parameter lacking default value'. format(arg=param_name)) else: diff --git a/test/test_gromov.py b/test/test_gromov.py index 0384ee11b..0dfd54e24 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -40,7 +40,7 @@ def test_gromov(): G = log['T'] - np.testing.assert_allclose(gw, 0, atol=1e-04, rtol=1e-4) + np.testing.assert_allclose(gw, 0, atol=1e-2, rtol=1e-2) # check constratints np.testing.assert_allclose( From 7095e03eb339bcf32d91c5a8857ecc3f3d0c45c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Wed, 21 Mar 2018 10:28:19 +0100 Subject: [PATCH 13/24] gtomov barycenter tests --- test/test_gromov.py | 31 +++++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/test/test_gromov.py b/test/test_gromov.py index 0dfd54e24..d865380be 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -40,7 +40,7 @@ def test_gromov(): G = log['T'] - np.testing.assert_allclose(gw, 0, atol=1e-2, rtol=1e-2) + np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1) # check constratints np.testing.assert_allclose( @@ -82,10 +82,37 @@ def test_entropic_gromov(): G = log['T'] - np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e1) + np.testing.assert_allclose(gw, 0, atol=1e-1, rtol=1e-1) # check constratints np.testing.assert_allclose( p, G.sum(1), atol=1e-04) # cf convergence gromov np.testing.assert_allclose( q, G.sum(0), atol=1e-04) # cf convergence gromov + + +def test_gromov_barycenter(): + + ns = 50 + nt = 60 + + Xs, ys = ot.datasets.get_data_classif('3gauss', ns) + Xt, yt = ot.datasets.get_data_classif('3gauss2', nt) + + C1 = ot.dist(Xs) + C2 = ot.dist(Xt) + + n_samples = 3 + Cb = ot.gromov.gromov_barycenters(n_samples, [C1, C2], + [ot.unif(ns), ot.unif(nt) + ], ot.unif(n_samples), [.5, .5], + 'square_loss', # 5e-4, + max_iter=100, tol=1e-3) + np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) + + Cb2 = ot.gromov.gromov_barycenters(n_samples, [C1, C2], + [ot.unif(ns), ot.unif(nt) + ], ot.unif(n_samples), [.5, .5], + 'kl_loss', # 5e-4, + max_iter=100, tol=1e-3) + np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples)) From 63fd11e8bfd45b163b313c7ad874ef608587fb68 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Wed, 21 Mar 2018 10:32:18 +0100 Subject: [PATCH 14/24] add entropic gromov test for 90+% corerage --- ot/gromov.py | 2 +- test/test_gromov.py | 27 +++++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/ot/gromov.py b/ot/gromov.py index e03fa5b0c..65b2e29f9 100644 --- a/ot/gromov.py +++ b/ot/gromov.py @@ -613,7 +613,7 @@ def entropic_gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, epsilon, The function solves the following optimization problem: .. math:: - C = argmin_C\in R^NxN \sum_s \lambda_s GW(C,Cs,p,ps) + C = argmin_C\in R^{NxN} \sum_s \lambda_s GW(C,Cs,p,ps) Where : diff --git a/test/test_gromov.py b/test/test_gromov.py index d865380be..bb23469f4 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -116,3 +116,30 @@ def test_gromov_barycenter(): 'kl_loss', # 5e-4, max_iter=100, tol=1e-3) np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples)) + + +def test_gromov_entropic_barycenter(): + + ns = 50 + nt = 60 + + Xs, ys = ot.datasets.get_data_classif('3gauss', ns) + Xt, yt = ot.datasets.get_data_classif('3gauss2', nt) + + C1 = ot.dist(Xs) + C2 = ot.dist(Xt) + + n_samples = 3 + Cb = ot.gromov.entropic_gromov_barycenters(n_samples, [C1, C2], + [ot.unif(ns), ot.unif(nt) + ], ot.unif(n_samples), [.5, .5], + 'square_loss', 1e-3, + max_iter=100, tol=1e-3) + np.testing.assert_allclose(Cb.shape, (n_samples, n_samples)) + + Cb2 = ot.gromov.entropic_gromov_barycenters(n_samples, [C1, C2], + [ot.unif(ns), ot.unif(nt) + ], ot.unif(n_samples), [.5, .5], + 'kl_loss', 1e-3, + max_iter=100, tol=1e-3) + np.testing.assert_allclose(Cb2.shape, (n_samples, n_samples)) From 1262563ef24c9ab0213f616ef01e1c80eb977176 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Wed, 21 Mar 2018 10:50:16 +0100 Subject: [PATCH 15/24] update readme + doc --- ot/da.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/ot/da.py b/ot/da.py index f789396a6..5d62d43ea 100644 --- a/ot/da.py +++ b/ot/da.py @@ -643,7 +643,7 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None, The function estimate the optimal linear operator that align the two empirical distributions. This is equivalent to estimating the closed form mapping between two Gaussian distribution :math:`N(\mu_s,\Sigma_s)` - and :math:`N(\mu_t,\Sigma_t)` as proposed in [14]. + and :math:`N(\mu_t,\Sigma_t)` as proposed in [14] and discussed in remark 2.29 in [15]. The linear operator from source to target :math:`M` @@ -692,6 +692,9 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None, .. [14] Knott, M. and Smith, C. S. "On the optimal mapping of distributions", Journal of Optimization Theory and Applications Vol 43, 1984 + + .. [15] Peyré, G., & Cuturi, M. (2017). "Computational Optimal + Transport", 2018. """ @@ -1290,7 +1293,8 @@ class LinearTransport(BaseTransport): The function estimate the optimal linear operator that align the two empirical distributions. This is equivalent to estimating the closed form mapping between two Gaussian distribution :math:`N(\mu_s,\Sigma_s)` - and :math:`N(\mu_t,\Sigma_t)` as proposed in [14]. + and :math:`N(\mu_t,\Sigma_t)` as proposed in [14] and discussed in + remark 2.29 in [15]. The linear operator from source to target :math:`M` @@ -1314,6 +1318,15 @@ class LinearTransport(BaseTransport): log : bool, optional record log if True + References + ---------- + + .. [14] Knott, M. and Smith, C. S. "On the optimal mapping of + distributions", Journal of Optimization Theory and Applications + Vol 43, 1984 + + .. [15] Peyré, G., & Cuturi, M. (2017). "Computational Optimal + Transport", 2018. """ From 0ce1a5ed14e44cd0c596fc0393eceeb8199d20d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Wed, 21 Mar 2018 10:50:45 +0100 Subject: [PATCH 16/24] update doc --- README.md | 4 ++++ docs/source/readme.rst | 17 ++++++++++++----- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 8a9d7faa5..fb7ab2a93 100644 --- a/README.md +++ b/README.md @@ -206,3 +206,7 @@ You can also post bug reports and feature requests in Github issues. Make sure t [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, [Gromov-Wasserstein averaging of kernel and distance matrices](http://proceedings.mlr.press/v48/peyre16.html) International Conference on Machine Learning (ICML). 2016. [13] Mémoli, Facundo. [Gromov–Wasserstein distances and the metric approach to object matching](https://media.adelaide.edu.au/acvt/Publications/2011/2011-Gromov%E2%80%93Wasserstein%20Distances%20and%20the%20Metric%20Approach%20to%20Object%20Matching.pdf). Foundations of computational mathematics 11.4 (2011): 417-487. + +[14] Knott, M. and Smith, C. S. [On the optimal mapping of distributions](https://link.springer.com/article/10.1007/BF00934745), Journal of Optimization Theory and Applications Vol 43, 1984. + +[15] Peyré, G., & Cuturi, M. (2017). [Computational Optimal Transport](https://arxiv.org/pdf/1803.00567.pdf) , 2018. diff --git a/docs/source/readme.rst b/docs/source/readme.rst index 347bde20c..647f7e85b 100644 --- a/docs/source/readme.rst +++ b/docs/source/readme.rst @@ -1,8 +1,8 @@ POT: Python Optimal Transport ============================= -|PyPI version| |Build Status| |Documentation Status| |Anaconda Cloud| -|License| |Anaconda downloads| +|PyPI version| |Anaconda Cloud| |Build Status| |Documentation Status| +|Anaconda downloads| |License| This open source Python library provide several solvers for optimization problems related to Optimal Transport for signal, image processing and @@ -311,15 +311,22 @@ approach to object matching `__. Foundations of computational mathematics 11.4 (2011): 417-487. +[14] Knott, M. and Smith, C. S. `On the optimal mapping of +distributions `__, +Journal of Optimization Theory and Applications Vol 43, 1984. + +[15] Peyré, G., & Cuturi, M. (2017). `Computational Optimal +Transport `__ , 2018. + .. |PyPI version| image:: https://badge.fury.io/py/POT.svg :target: https://badge.fury.io/py/POT +.. |Anaconda Cloud| image:: https://anaconda.org/conda-forge/pot/badges/version.svg + :target: https://anaconda.org/conda-forge/pot .. |Build Status| image:: https://travis-ci.org/rflamary/POT.svg?branch=master :target: https://travis-ci.org/rflamary/POT .. |Documentation Status| image:: https://readthedocs.org/projects/pot/badge/?version=latest :target: http://pot.readthedocs.io/en/latest/?badge=latest -.. |Anaconda Cloud| image:: https://anaconda.org/conda-forge/pot/badges/version.svg +.. |Anaconda downloads| image:: https://anaconda.org/conda-forge/pot/badges/downloads.svg :target: https://anaconda.org/conda-forge/pot .. |License| image:: https://anaconda.org/conda-forge/pot/badges/license.svg :target: https://github.com/rflamary/POT/blob/master/LICENSE -.. |Anaconda downloads| image:: https://anaconda.org/conda-forge/pot/badges/downloads.svg - :target: https://anaconda.org/conda-forge/pot From 83c706cb6b1c9eb6ca033c58532b85c13b5d40f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Wed, 21 Mar 2018 10:54:00 +0100 Subject: [PATCH 17/24] pep cleanup --- ot/da.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ot/da.py b/ot/da.py index 5d62d43ea..cdebc91bc 100644 --- a/ot/da.py +++ b/ot/da.py @@ -692,8 +692,8 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None, .. [14] Knott, M. and Smith, C. S. "On the optimal mapping of distributions", Journal of Optimization Theory and Applications Vol 43, 1984 - - .. [15] Peyré, G., & Cuturi, M. (2017). "Computational Optimal + + .. [15] Peyré, G., & Cuturi, M. (2017). "Computational Optimal Transport", 2018. @@ -1293,7 +1293,7 @@ class LinearTransport(BaseTransport): The function estimate the optimal linear operator that align the two empirical distributions. This is equivalent to estimating the closed form mapping between two Gaussian distribution :math:`N(\mu_s,\Sigma_s)` - and :math:`N(\mu_t,\Sigma_t)` as proposed in [14] and discussed in + and :math:`N(\mu_t,\Sigma_t)` as proposed in [14] and discussed in remark 2.29 in [15]. The linear operator from source to target :math:`M` @@ -1324,8 +1324,8 @@ class LinearTransport(BaseTransport): .. [14] Knott, M. and Smith, C. S. "On the optimal mapping of distributions", Journal of Optimization Theory and Applications Vol 43, 1984 - - .. [15] Peyré, G., & Cuturi, M. (2017). "Computational Optimal + + .. [15] Peyré, G., & Cuturi, M. (2017). "Computational Optimal Transport", 2018. """ From 69c7d1cb64a5628c69a3c1533991741bcd91f96b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Wed, 21 Mar 2018 10:58:20 +0100 Subject: [PATCH 18/24] pep8 unused variable --- ot/externals/funcsigs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ot/externals/funcsigs.py b/ot/externals/funcsigs.py index 82a55dac5..c73fdc96d 100644 --- a/ot/externals/funcsigs.py +++ b/ot/externals/funcsigs.py @@ -99,7 +99,7 @@ def signature(obj): partial_keywords = obj.keywords or {} try: ba = sig.bind_partial(*partial_args, **partial_keywords) - except TypeError as ex: + except TypeError: msg = 'partial object {0!r} has incorrect arguments'.format(obj) raise ValueError(msg) From 7681db5c19817cfd003cea9ffdd95fedb9b00650 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Wed, 21 Mar 2018 11:03:06 +0100 Subject: [PATCH 19/24] update reame --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index fb7ab2a93..353fe32df 100644 --- a/README.md +++ b/README.md @@ -13,14 +13,14 @@ This open source Python library provide several solvers for optimization problem It provides the following solvers: -* OT solver for the linear program/ Earth Movers Distance [1]. +* OT Network Flow solver for the linear program/ Earth Movers Distance [1]. * Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2] and stabilized version [9][10] with optional GPU implementation (required cudamat). * Bregman projections for Wasserstein barycenter [3] and unmixing [4]. * Optimal transport for domain adaptation with group lasso regularization [5] * Conditional gradient [6] and Generalized conditional gradient for regularized OT [7]. -* Joint OT matrix and mapping estimation [8]. +* Linear OT [14] and Joint OT matrix and mapping estimation [8]. * Wasserstein Discriminant Analysis [11] (requires autograd + pymanopt). -* Gromov-Wasserstein distances and barycenters [12] +* Gromov-Wasserstein distances and barycenters ([13] and regularized [12]) Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder. From c30519a4c67a056c85d7897f198e2fb34c584755 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Wed, 9 May 2018 13:06:02 +0200 Subject: [PATCH 20/24] cleanup Makefile --- Makefile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 468d042c0..95714b8ad 100644 --- a/Makefile +++ b/Makefile @@ -44,11 +44,11 @@ test : FORCE pep8 $(PYTHON) -m pytest -v test/ --cov=ot --cov-report html:cov_html pytest : FORCE - python -m py.test -v test/ --cov=ot + $(PYTHON) -m py.test -v test/ --cov=ot uploadpypi : #python setup.py register - python setup.py sdist upload -r pypi + $(PYTHON) setup.py sdist upload -r pypi rdoc : pandoc --from=markdown --to=rst --output=docs/source/readme.rst README.md From d5ea28b22ab94a13f676ffc0ed862887921c2efc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Wed, 9 May 2018 13:06:46 +0200 Subject: [PATCH 21/24] correct ref 15 in readme --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 353fe32df..65ee7100f 100644 --- a/README.md +++ b/README.md @@ -209,4 +209,4 @@ You can also post bug reports and feature requests in Github issues. Make sure t [14] Knott, M. and Smith, C. S. [On the optimal mapping of distributions](https://link.springer.com/article/10.1007/BF00934745), Journal of Optimization Theory and Applications Vol 43, 1984. -[15] Peyré, G., & Cuturi, M. (2017). [Computational Optimal Transport](https://arxiv.org/pdf/1803.00567.pdf) , 2018. +[15] Peyré, G., & Cuturi, M. (2018). [Computational Optimal Transport](https://arxiv.org/pdf/1803.00567.pdf) . From e26e69f11498a85148f0df9776c7fb0fca4545f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Wed, 9 May 2018 13:07:01 +0200 Subject: [PATCH 22/24] update documentation wrt readme file --- docs/source/readme.rst | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/docs/source/readme.rst b/docs/source/readme.rst index 647f7e85b..d73d293e1 100644 --- a/docs/source/readme.rst +++ b/docs/source/readme.rst @@ -10,7 +10,8 @@ machine learning. It provides the following solvers: -- OT solver for the linear program/ Earth Movers Distance [1]. +- OT Network Flow solver for the linear program/ Earth Movers Distance + [1]. - Entropic regularization OT solver with Sinkhorn Knopp Algorithm [2] and stabilized version [9][10] with optional GPU implementation (required cudamat). @@ -19,10 +20,11 @@ It provides the following solvers: regularization [5] - Conditional gradient [6] and Generalized conditional gradient for regularized OT [7]. -- Joint OT matrix and mapping estimation [8]. +- Linear OT [14] and Joint OT matrix and mapping estimation [8]. - Wasserstein Discriminant Analysis [11] (requires autograd + pymanopt). -- Gromov-Wasserstein distances and barycenters [12] +- Gromov-Wasserstein distances and barycenters ([13] and regularized + [12]) Some demonstrations (both in Python and Jupyter Notebook format) are available in the examples folder. @@ -315,8 +317,8 @@ Foundations of computational mathematics 11.4 (2011): 417-487. distributions `__, Journal of Optimization Theory and Applications Vol 43, 1984. -[15] Peyré, G., & Cuturi, M. (2017). `Computational Optimal -Transport `__ , 2018. +[15] Peyré, G., & Cuturi, M. (2018). `Computational Optimal +Transport `__ . .. |PyPI version| image:: https://badge.fury.io/py/POT.svg :target: https://badge.fury.io/py/POT From 0a9763ce0e83106daa322566398218aa4a297fe1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Wed, 9 May 2018 13:08:53 +0200 Subject: [PATCH 23/24] cleanup reference years in readme --- README.md | 8 ++++---- docs/source/readme.rst | 20 ++++++++++---------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 65ee7100f..6b7cff03c 100644 --- a/README.md +++ b/README.md @@ -195,7 +195,7 @@ You can also post bug reports and feature requests in Github issues. Make sure t [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). [Generalized conditional gradient: analysis of convergence and applications](https://arxiv.org/pdf/1510.06567.pdf). arXiv preprint arXiv:1510.06567. -[8] M. Perrot, N. Courty, R. Flamary, A. Habrard, [Mapping estimation for discrete optimal transport](http://remi.flamary.com/biblio/perrot2016mapping.pdf), Neural Information Processing Systems (NIPS), 2016. +[8] M. Perrot, N. Courty, R. Flamary, A. Habrard (2016), [Mapping estimation for discrete optimal transport](http://remi.flamary.com/biblio/perrot2016mapping.pdf), Neural Information Processing Systems (NIPS). [9] Schmitzer, B. (2016). [Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport Problems](https://arxiv.org/pdf/1610.06519.pdf). arXiv preprint arXiv:1610.06519. @@ -203,10 +203,10 @@ You can also post bug reports and feature requests in Github issues. Make sure t [11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016). [Wasserstein Discriminant Analysis](https://arxiv.org/pdf/1608.08063.pdf). arXiv preprint arXiv:1608.08063. -[12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, [Gromov-Wasserstein averaging of kernel and distance matrices](http://proceedings.mlr.press/v48/peyre16.html) International Conference on Machine Learning (ICML). 2016. +[12] Gabriel Peyré, Marco Cuturi, and Justin Solomon (2016), [Gromov-Wasserstein averaging of kernel and distance matrices](http://proceedings.mlr.press/v48/peyre16.html) International Conference on Machine Learning (ICML). -[13] Mémoli, Facundo. [Gromov–Wasserstein distances and the metric approach to object matching](https://media.adelaide.edu.au/acvt/Publications/2011/2011-Gromov%E2%80%93Wasserstein%20Distances%20and%20the%20Metric%20Approach%20to%20Object%20Matching.pdf). Foundations of computational mathematics 11.4 (2011): 417-487. +[13] Mémoli, Facundo (2011). [Gromov–Wasserstein distances and the metric approach to object matching](https://media.adelaide.edu.au/acvt/Publications/2011/2011-Gromov%E2%80%93Wasserstein%20Distances%20and%20the%20Metric%20Approach%20to%20Object%20Matching.pdf). Foundations of computational mathematics 11.4 : 417-487. -[14] Knott, M. and Smith, C. S. [On the optimal mapping of distributions](https://link.springer.com/article/10.1007/BF00934745), Journal of Optimization Theory and Applications Vol 43, 1984. +[14] Knott, M. and Smith, C. S. (1984).[On the optimal mapping of distributions](https://link.springer.com/article/10.1007/BF00934745), Journal of Optimization Theory and Applications Vol 43. [15] Peyré, G., & Cuturi, M. (2018). [Computational Optimal Transport](https://arxiv.org/pdf/1803.00567.pdf) . diff --git a/docs/source/readme.rst b/docs/source/readme.rst index d73d293e1..725c207e0 100644 --- a/docs/source/readme.rst +++ b/docs/source/readme.rst @@ -283,10 +283,10 @@ conditional gradient: analysis of convergence and applications `__. arXiv preprint arXiv:1510.06567. -[8] M. Perrot, N. Courty, R. Flamary, A. Habrard, `Mapping estimation -for discrete optimal +[8] M. Perrot, N. Courty, R. Flamary, A. Habrard (2016), `Mapping +estimation for discrete optimal transport `__, -Neural Information Processing Systems (NIPS), 2016. +Neural Information Processing Systems (NIPS). [9] Schmitzer, B. (2016). `Stabilized Sparse Scaling Algorithms for Entropy Regularized Transport @@ -303,19 +303,19 @@ arXiv:1607.05816. Analysis `__. arXiv preprint arXiv:1608.08063. -[12] Gabriel Peyré, Marco Cuturi, and Justin Solomon, +[12] Gabriel Peyré, Marco Cuturi, and Justin Solomon (2016), `Gromov-Wasserstein averaging of kernel and distance matrices `__ -International Conference on Machine Learning (ICML). 2016. +International Conference on Machine Learning (ICML). -[13] Mémoli, Facundo. `Gromov–Wasserstein distances and the metric -approach to object +[13] Mémoli, Facundo (2011). `Gromov–Wasserstein distances and the +metric approach to object matching `__. -Foundations of computational mathematics 11.4 (2011): 417-487. +Foundations of computational mathematics 11.4 : 417-487. -[14] Knott, M. and Smith, C. S. `On the optimal mapping of +[14] Knott, M. and Smith, C. S. (1984).`On the optimal mapping of distributions `__, -Journal of Optimization Theory and Applications Vol 43, 1984. +Journal of Optimization Theory and Applications Vol 43. [15] Peyré, G., & Cuturi, M. (2018). `Computational Optimal Transport `__ . From 0496e2b1b2c2f4ea2d7f313ccf58c612efaa70bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Wed, 9 May 2018 13:12:06 +0200 Subject: [PATCH 24/24] doc typos in linear map function --- ot/da.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ot/da.py b/ot/da.py index cdebc91bc..b83d67ec5 100644 --- a/ot/da.py +++ b/ot/da.py @@ -640,9 +640,9 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None, wt=None, bias=True, log=False): """ return OT linear operator between samples - The function estimate the optimal linear operator that align the two + The function estimates the optimal linear operator that aligns the two empirical distributions. This is equivalent to estimating the closed - form mapping between two Gaussian distribution :math:`N(\mu_s,\Sigma_s)` + form mapping between two Gaussian distributions :math:`N(\mu_s,\Sigma_s)` and :math:`N(\mu_t,\Sigma_t)` as proposed in [14] and discussed in remark 2.29 in [15]. The linear operator from source to target :math:`M` @@ -665,7 +665,7 @@ def OT_mapping_linear(xs, xt, reg=1e-6, ws=None, xt : np.ndarray (nt,d) samples in the target domain reg : float,optional - regularization added to the daigonals of convariances (>0) + regularization added to the diagonals of convariances (>0) ws : np.ndarray (ns,1), optional weights for the source samples wt : np.ndarray (ns,1), optional @@ -1290,9 +1290,9 @@ class label class LinearTransport(BaseTransport): """ OT linear operator between empirical distributions - The function estimate the optimal linear operator that align the two + The function estimates the optimal linear operator that aligns the two empirical distributions. This is equivalent to estimating the closed - form mapping between two Gaussian distribution :math:`N(\mu_s,\Sigma_s)` + form mapping between two Gaussian distributions :math:`N(\mu_s,\Sigma_s)` and :math:`N(\mu_t,\Sigma_t)` as proposed in [14] and discussed in remark 2.29 in [15].