Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions modelscope/cli/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

from modelscope.cli.base import CLICommand
from modelscope.hub.api import HubApi
from modelscope.hub.constants import (Licenses, ModelVisibility, Visibility,
VisibilityMap)
from modelscope.hub.constants import (GatedMode, Licenses, ModelVisibility,
Visibility, VisibilityMap)
from modelscope.hub.utils.aigc import AigcModel
from modelscope.hub.utils.utils import resolve_endpoint
from modelscope.utils.constant import REPO_TYPE_MODEL, REPO_TYPE_SUPPORT
Expand Down Expand Up @@ -82,6 +82,20 @@ def define_args(parsers: _SubParsersAction):
help=
'If True, do not raise error when repo already exists. Defaults to False.',
)
parser.add_argument(
'--gated',
dest='gated_mode',
action='store_true',
default=None,
help=
'Enable gated mode (application-based download) for private repos.',
)
parser.add_argument(
'--no-gated',
dest='gated_mode',
action='store_false',
help='Disable gated mode for private repos (normal private).',
)
parser.add_argument(
'--endpoint',
type=str,
Expand Down Expand Up @@ -176,6 +190,7 @@ def _create_regular_repo(self):
exist_ok=self.args.exist_ok,
create_default_config=True,
endpoint=endpoint,
gated_mode=self.args.gated_mode,
)

def _create_aigc_model(self):
Expand Down Expand Up @@ -225,7 +240,8 @@ def _create_aigc_model(self):
visibility=visibility_idx,
license=self.args.license,
chinese_name=self.args.chinese_name,
aigc_model=aigc_model)
aigc_model=aigc_model,
gated_mode=self.args.gated_mode)
print(f'Successfully created AIGC model: {model_url}')
except Exception as e:
print(f'Error creating AIGC model: {e}')
48 changes: 43 additions & 5 deletions modelscope/hub/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,8 @@ def create_model(self,
original_model_id: Optional[str] = '',
endpoint: Optional[str] = None,
token: Optional[str] = None,
aigc_model: Optional['AigcModel'] = None) -> str:
aigc_model: Optional['AigcModel'] = None,
gated_mode: Optional[bool] = None) -> str:
"""Create model repo at ModelScope Hub.

Args:
Expand All @@ -343,6 +344,9 @@ def create_model(self,
aigc_model (AigcModel, optional): AigcModel instance for AIGC model creation.
If provided, will create an AIGC model with automatic file upload.
Refer to modelscope.hub.utils.aigc.AigcModel for details.
gated_mode (bool, optional): Gated mode for private repos.
True = gated (application-based download), False = off (normal private).
Only effective when visibility is PRIVATE (1).

Returns:
str: URL of the created model repository
Expand Down Expand Up @@ -374,6 +378,12 @@ def create_model(self,
'TrainId': os.environ.get('MODELSCOPE_TRAIN_ID', '')
}

if gated_mode is not None:
if visibility != ModelVisibility.PRIVATE:
logger.warning('gated_mode is only effective when visibility is PRIVATE, ignored.')
else:
body['ProtectedMode'] = 1 if gated_mode else 2

# Set path based on model type
if aigc_model is not None:
# Use AIGC model endpoint
Expand Down Expand Up @@ -1428,7 +1438,8 @@ def create_dataset(self,
visibility: Optional[int] = DatasetVisibility.PUBLIC,
description: Optional[str] = '',
endpoint: Optional[str] = None,
token: Optional[str] = None) -> str:
token: Optional[str] = None,
gated_mode: Optional[bool] = None) -> str:
"""
Create a dataset in ModelScope.

Expand All @@ -1441,6 +1452,9 @@ def create_dataset(self,
description (str, optional): The description of the dataset. Defaults to ''.
endpoint (str, optional): The endpoint to use. If not provided, the default endpoint is used.
token (str, optional): The access token for authentication.
gated_mode (bool, optional): Gated mode for private repos.
True = gated (application-based download), False = off (normal private).
Only effective when visibility is PRIVATE (1).

Returns:
str: The URL of the created dataset repository.
Expand All @@ -1462,6 +1476,12 @@ def create_dataset(self,
'Description': (None, description)
}

if gated_mode is not None:
if visibility != DatasetVisibility.PRIVATE:
logger.warning('gated_mode is only effective when visibility is PRIVATE, ignored.')
else:
files['ProtectedMode'] = (None, 1 if gated_mode else 2)

r = self.session.post(
path,
files=files,
Expand Down Expand Up @@ -2200,6 +2220,7 @@ def create_repo(
exist_ok: Optional[bool] = False,
create_default_config: Optional[bool] = True,
aigc_model: Optional[AigcModel] = None,
gated_mode: Optional[bool] = None,
**kwargs,
) -> str:
"""
Expand All @@ -2217,6 +2238,9 @@ def create_repo(
In the format of `https://www.modelscope.cn` or 'https://www.modelscope.ai'
exist_ok (Optional[bool]): If the repo exists, whether to return the repo url directly.
create_default_config (Optional[bool]): If True, create a default configuration file in the model repo.
gated_mode (Optional[bool]): Gated mode for private repos.
True = gated (application-based download), False = off (normal private).
Only effective when visibility is ``private``.
**kwargs: The additional arguments.

Returns:
Expand Down Expand Up @@ -2256,6 +2280,7 @@ def create_repo(
aigc_model=aigc_model,
token=token,
endpoint=endpoint,
gated_mode=gated_mode,
)
if create_default_config:
with tempfile.TemporaryDirectory() as temp_cache_dir:
Expand Down Expand Up @@ -2290,6 +2315,7 @@ def create_repo(
visibility=visibility,
token=token,
endpoint=endpoint,
gated_mode=gated_mode,
)
print(f'New dataset created successfully at {repo_url}.', flush=True)

Expand Down Expand Up @@ -3882,7 +3908,8 @@ def set_repo_visibility(self,
repo_id: str,
repo_type: Literal['model', 'dataset'],
visibility: Literal['private', 'public'],
token: Union[str, None] = None
token: Union[str, None] = None,
gated_mode: Optional[bool] = None,
) -> dict:
"""
Set the visibility of a repo.
Expand All @@ -3893,6 +3920,9 @@ def set_repo_visibility(self,
visibility (Literal['private', 'public']): The visibility to set, `private` or `public`.
token (Union[str, None]): The access token. If None, will use the cookies from the local cache.
See `https://modelscope.cn/my/myaccesstoken` to get your token.
gated_mode (Optional[bool]): Gated mode for private repos.
True = gated (application-based download), False = off (normal private).
Only effective when visibility is ``private``.

Returns:
dict: The response from the server.
Expand All @@ -3907,6 +3937,10 @@ def set_repo_visibility(self,
visibility_code: int = visibility_map.get(visibility, 5)
cookies = self.get_cookies(access_token=token, cookies_required=True)

if gated_mode is not None and visibility != 'private':
logger.warning('gated_mode is only effective when visibility is private, ignored.')
gated_mode = None

if repo_type == REPO_TYPE_MODEL:
model_info = self.get_model(model_id=repo_id, token=token)
path = f'{self.endpoint}/api/v1/models/{repo_id}'
Expand All @@ -3916,11 +3950,15 @@ def set_repo_visibility(self,
first = tasks[0]
if isinstance(first, dict) and first:
model_tasks = first.get('name')
if gated_mode is not None:
pm = 1 if gated_mode else 2
else:
pm = model_info.get('ProtectedMode', 2)
payload = {
'ChineseName': model_info.get('ChineseName', ''),
'ModelFramework': model_info.get('ModelFramework', 'Pytorch'),
'Visibility': visibility_code,
'ProtectedMode': 2,
'ProtectedMode': pm,
'ApprovalMode': model_info.get('ApprovalMode', 2),
'Description': model_info.get('Description', ''),
'AigcType': model_info.get('AigcType', ''),
Comment thread
wangxingjun778 marked this conversation as resolved.
Expand All @@ -3947,7 +3985,7 @@ def set_repo_visibility(self,
path = f'{self.endpoint}/api/v1/datasets/{dataset_idx}'
payload = {
'Visibility': visibility_code,
'ProtectedMode': 2,
'ProtectedMode': (1 if gated_mode else 2) if gated_mode is not None else 2,
}
Comment thread
wangxingjun778 marked this conversation as resolved.
else:
raise ValueError(f'Invalid repo type: {repo_type}, supported repos: {REPO_TYPE_SUPPORT}')
Expand Down
12 changes: 12 additions & 0 deletions modelscope/hub/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,18 @@ class Visibility(object):
PUBLIC = 'public'


class GatedMode(object):
"""Gated mode for private repositories.

Only effective when Visibility is PRIVATE.
API payload key: ``ProtectedMode``.
Values: True = gated (application-based download),
False = off (normal private).
"""
GATED = True
OFF = False


VisibilityMap = {
ModelVisibility.PRIVATE: Visibility.PRIVATE,
ModelVisibility.INTERNAL: Visibility.INTERNAL,
Expand Down
Loading