Skip to content

Commit 1804ac6

Browse files
[Feature] Add studio module (#1727)
1 parent 63ff6ec commit 1804ac6

13 files changed

Lines changed: 1452 additions & 40 deletions

File tree

modelscope/cli/cli.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from modelscope.cli.scancache import ScanCacheCMD
1515
from modelscope.cli.server import ServerCMD
1616
from modelscope.cli.skills import SkillsCMD
17+
from modelscope.cli.studio import StudioCMD
1718
from modelscope.cli.upload import UploadCMD
1819
from modelscope.hub.constants import MODELSCOPE_ASCII
1920
from modelscope.utils.logger import get_logger
@@ -47,6 +48,7 @@ def run_cmd():
4748
LoginCMD.define_args(subparsers)
4849
LlamafileCMD.define_args(subparsers)
4950
ScanCacheCMD.define_args(subparsers)
51+
StudioCMD.define_args(subparsers)
5052

5153
args = parser.parse_args()
5254

modelscope/cli/create.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
Visibility, VisibilityMap)
88
from modelscope.hub.utils.aigc import AigcModel
99
from modelscope.hub.utils.utils import resolve_endpoint
10-
from modelscope.utils.constant import REPO_TYPE_MODEL, REPO_TYPE_SUPPORT
10+
from modelscope.utils.constant import (REPO_TYPE_MODEL, REPO_TYPE_STUDIO,
11+
REPO_TYPE_SUPPORT, StudioHardware,
12+
StudioSDKType)
1113
from modelscope.utils.logger import get_logger
1214

1315
logger = get_logger()
@@ -107,6 +109,35 @@ def define_args(parsers: _SubParsersAction):
107109
'then defaults to https://www.modelscope.cn.',
108110
)
109111

112+
# Studio specific arguments (only meaningful when --repo_type studio)
113+
studio_group = parser.add_argument_group(
114+
'Studio Repo Creation',
115+
'Optional arguments used only when `--repo_type studio` is set.')
116+
studio_group.add_argument(
117+
'--sdk-type',
118+
dest='sdk_type',
119+
choices=StudioSDKType.SUPPORTED,
120+
default=None,
121+
help='Studio SDK type (only for studio repo-type).')
122+
studio_group.add_argument(
123+
'--sdk-version',
124+
dest='sdk_version',
125+
type=str,
126+
default=None,
127+
help='Studio SDK version (only for gradio).')
128+
studio_group.add_argument(
129+
'--base-image',
130+
dest='base_image',
131+
type=str,
132+
default=None,
133+
help='Studio base image (only for gradio/streamlit).')
134+
studio_group.add_argument(
135+
'--hardware',
136+
dest='hardware',
137+
choices=StudioHardware.SUPPORTED,
138+
default=None,
139+
help='Studio hardware configuration.')
140+
110141
# AIGC specific arguments
111142
aigc_group = parser.add_argument_group(
112143
'AIGC Model Creation',
@@ -179,6 +210,14 @@ def _create_regular_repo(self):
179210
endpoint = resolve_endpoint(self.args.endpoint)
180211
api = HubApi(endpoint=endpoint)
181212

213+
extra_kwargs = {}
214+
if self.args.repo_type == REPO_TYPE_STUDIO:
215+
# Pass studio-specific fields only when creating a studio repo.
216+
for field in ('sdk_type', 'sdk_version', 'base_image', 'hardware'):
217+
value = getattr(self.args, field, None)
218+
if value is not None:
219+
extra_kwargs[field] = value
220+
182221
# Create repo
183222
api.create_repo(
184223
repo_id=self.args.repo_id,
@@ -191,6 +230,7 @@ def _create_regular_repo(self):
191230
create_default_config=True,
192231
endpoint=endpoint,
193232
gated_mode=self.args.gated_mode,
233+
**extra_kwargs,
194234
)
195235

196236
def _create_aigc_model(self):

modelscope/cli/download.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@
1111
from modelscope.hub.snapshot_download import (dataset_snapshot_download,
1212
snapshot_download)
1313
from modelscope.hub.utils.utils import convert_patterns, resolve_endpoint
14-
from modelscope.utils.constant import DEFAULT_DATASET_REVISION
14+
from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
15+
REPO_TYPE_DATASET, REPO_TYPE_MODEL,
16+
REPO_TYPE_STUDIO, REPO_TYPE_SUPPORT)
1517
from modelscope.utils.logger import get_logger
1618

1719
logger = get_logger(log_level=logging.WARNING)
@@ -60,8 +62,8 @@ def define_args(parsers: ArgumentParser):
6062
)
6163
parser.add_argument(
6264
'--repo-type',
63-
choices=['model', 'dataset'],
64-
default='model',
65+
choices=REPO_TYPE_SUPPORT,
66+
default=REPO_TYPE_MODEL,
6567
help="Type of repo to download from (defaults to 'model').",
6668
)
6769
parser.add_argument(
@@ -135,9 +137,11 @@ def execute(self):
135137
self.args.files = [self.args.repo_id]
136138
else:
137139
if self.args.repo_id is not None:
138-
if self.args.repo_type == 'model':
140+
if self.args.repo_type in (REPO_TYPE_MODEL, REPO_TYPE_STUDIO):
141+
# studio repos share the same snapshot_download path
142+
# as model repos.
139143
self.args.model = self.args.repo_id
140-
elif self.args.repo_type == 'dataset':
144+
elif self.args.repo_type == REPO_TYPE_DATASET:
141145
self.args.dataset = self.args.repo_id
142146
else:
143147
raise Exception('Not support repo-type: %s'

0 commit comments

Comments
 (0)