-
Notifications
You must be signed in to change notification settings - Fork 547
Linear mapping + tests #42
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
24 commits
Select commit
Hold shift + click to select a range
cb739f6
add linear mapping function
rflamary 8fc9fce
add class LinearTransport
rflamary c104623
passing tests
rflamary 4fc9ccc
better example+test
rflamary 88a81c3
makefile update
rflamary 287c659
update example
rflamary 6fdf5de
add linear mapping test + autopep8
rflamary 5efdf00
add test linear mapping class
rflamary fc9923d
add tests for ot.uils
rflamary 927395b
add externals for function signature
rflamary 55aaf78
add test gromov + debug sklearn Basestimator
rflamary 64ef33d
aupdate gromov + autopep8 externals
rflamary 7095e03
gtomov barycenter tests
rflamary 63fd11e
add entropic gromov test for 90+% corerage
rflamary 1262563
update readme + doc
rflamary 0ce1a5e
update doc
rflamary 83c706c
pep cleanup
rflamary 69c7d1c
pep8 unused variable
rflamary 7681db5
update reame
rflamary c30519a
cleanup Makefile
rflamary d5ea28b
correct ref 15 in readme
rflamary e26e69f
update documentation wrt readme file
rflamary 0a9763c
cleanup reference years in readme
rflamary 0496e2b
doc typos in linear map function
rflamary File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,138 @@ | ||
| #!/usr/bin/env python3 | ||
| # -*- coding: utf-8 -*- | ||
| """ | ||
| Created on Tue Mar 20 14:31:15 2018 | ||
|
|
||
| @author: rflamary | ||
| """ | ||
|
|
||
| import numpy as np | ||
| import pylab as pl | ||
| import ot | ||
|
|
||
| ############################################################################## | ||
| # Generate data | ||
| # ------------- | ||
|
|
||
| n = 1000 | ||
| d = 2 | ||
| sigma = .1 | ||
|
|
||
| # source samples | ||
| angles = np.random.rand(n, 1) * 2 * np.pi | ||
| xs = np.concatenate((np.sin(angles), np.cos(angles)), | ||
| axis=1) + sigma * np.random.randn(n, 2) | ||
| xs[:n // 2, 1] += 2 | ||
|
|
||
|
|
||
| # target samples | ||
| anglet = np.random.rand(n, 1) * 2 * np.pi | ||
| xt = np.concatenate((np.sin(anglet), np.cos(anglet)), | ||
| axis=1) + sigma * np.random.randn(n, 2) | ||
| xt[:n // 2, 1] += 2 | ||
|
|
||
|
|
||
| A = np.array([[1.5, .7], [.7, 1.5]]) | ||
| b = np.array([[4, 2]]) | ||
| xt = xt.dot(A) + b | ||
|
|
||
| ############################################################################## | ||
| # Plot data | ||
| # --------- | ||
|
|
||
| pl.figure(1, (5, 5)) | ||
| pl.plot(xs[:, 0], xs[:, 1], '+') | ||
| pl.plot(xt[:, 0], xt[:, 1], 'o') | ||
|
|
||
|
|
||
| ############################################################################## | ||
| # Estimate linear mapping and transport | ||
| # ------------------------------------- | ||
|
|
||
| Ae, be = ot.da.OT_mapping_linear(xs, xt) | ||
|
|
||
| xst = xs.dot(Ae) + be | ||
|
|
||
|
|
||
| ############################################################################## | ||
| # Plot transported samples | ||
| # ------------------------ | ||
|
|
||
| pl.figure(1, (5, 5)) | ||
| pl.clf() | ||
| pl.plot(xs[:, 0], xs[:, 1], '+') | ||
| pl.plot(xt[:, 0], xt[:, 1], 'o') | ||
| pl.plot(xst[:, 0], xst[:, 1], '+') | ||
|
|
||
| pl.show() | ||
|
|
||
| ############################################################################## | ||
| # Load image data | ||
| # --------------- | ||
|
|
||
|
|
||
| def im2mat(I): | ||
| """Converts and image to matrix (one pixel per line)""" | ||
| return I.reshape((I.shape[0] * I.shape[1], I.shape[2])) | ||
|
|
||
|
|
||
| def mat2im(X, shape): | ||
| """Converts back a matrix to an image""" | ||
| return X.reshape(shape) | ||
|
|
||
|
|
||
| def minmax(I): | ||
| return np.clip(I, 0, 1) | ||
|
|
||
|
|
||
| # Loading images | ||
| I1 = pl.imread('../data/ocean_day.jpg').astype(np.float64) / 256 | ||
| I2 = pl.imread('../data/ocean_sunset.jpg').astype(np.float64) / 256 | ||
|
|
||
|
|
||
| X1 = im2mat(I1) | ||
| X2 = im2mat(I2) | ||
|
|
||
| ############################################################################## | ||
| # Estimate mapping and adapt | ||
| # ---------------------------- | ||
|
|
||
| mapping = ot.da.LinearTransport() | ||
|
|
||
| mapping.fit(Xs=X1, Xt=X2) | ||
|
|
||
|
|
||
| xst = mapping.transform(Xs=X1) | ||
| xts = mapping.inverse_transform(Xt=X2) | ||
|
|
||
| I1t = minmax(mat2im(xst, I1.shape)) | ||
| I2t = minmax(mat2im(xts, I2.shape)) | ||
|
|
||
| # %% | ||
|
|
||
|
|
||
| ############################################################################## | ||
| # Plot transformed images | ||
| # ----------------------- | ||
|
|
||
| pl.figure(2, figsize=(10, 7)) | ||
|
|
||
| pl.subplot(2, 2, 1) | ||
| pl.imshow(I1) | ||
| pl.axis('off') | ||
| pl.title('Im. 1') | ||
|
|
||
| pl.subplot(2, 2, 2) | ||
| pl.imshow(I2) | ||
| pl.axis('off') | ||
| pl.title('Im. 2') | ||
|
|
||
| pl.subplot(2, 2, 3) | ||
| pl.imshow(I1t) | ||
| pl.axis('off') | ||
| pl.title('Mapping Im. 1') | ||
|
|
||
| pl.subplot(2, 2, 4) | ||
| pl.imshow(I2t) | ||
| pl.axis('off') | ||
| pl.title('Inverse mapping Im. 2') |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
regularized what ?