Dev/d opsd#1122
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces a new D-OPSD trainer implementation, including support for dual LoRA adapters, trajectory visualization, and updated configurations for the flux2_klein model. The review feedback identifies several critical and medium-severity issues, including missing arguments in latent unpacking, incorrect timestep scaling for the transformer, FSDP2 checkpointing failures for frozen parameters, and a potential AttributeError in the inference pipeline. All identified issues include actionable code suggestions to ensure stability and correctness.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| def decode_packed_x0_to_images(self, packed_x0, latent_ids): | ||
| # height/width must be latent token grid sizes from img_ids, not pixel sizes. | ||
| unpatchified = Flux2KleinPipeline._unpack_latents_with_ids(packed_x0, latent_ids) | ||
| return self.decode_latent(unpatchified) |
There was a problem hiding this comment.
In decode_packed_x0_to_images, Flux2KleinPipeline._unpack_latents_with_ids is called without passing the height and width arguments. Since _unpack_latents_with_ids requires the latent token grid height and width to correctly unpack the latents, omitting them will cause a TypeError or incorrect unpacking. We can dynamically compute the latent grid height and width from latent_ids and pass them to the function.
@torch.no_grad()
def decode_packed_x0_to_images(self, packed_x0, latent_ids):
# height/width must be latent token grid sizes from img_ids, not pixel sizes.
latent_h = int(latent_ids[..., 1].max()) + 1
latent_w = int(latent_ids[..., 2].max()) + 1
unpatchified = Flux2KleinPipeline._unpack_latents_with_ids(
packed_x0, latent_ids, height=latent_h, width=latent_w
)
return self.decode_latent(unpatchified)| v_pred = self.transformer( | ||
| hidden_states=hidden_states, | ||
| timestep=timestep, | ||
| guidance=None, | ||
| encoder_hidden_states=condition["prompt_embed"], | ||
| txt_ids=condition["text_ids"], | ||
| img_ids=img_ids, | ||
| joint_attention_kwargs={}, | ||
| return_dict=False, | ||
| )[0] |
There was a problem hiding this comment.
In predict_velocity, the timestep argument passed to self.transformer is in the range [0, 1] (since it is divided by t_scale in the trainer). However, the Flux transformer expects the timestep to be scaled to [0, 1000]. Passing a value in [0, 1] will cause the transformer to interpret it as an extremely small timestep (near 0), leading to incorrect velocity predictions and poor generation quality. We should scale the timestep by 1000 before passing it to self.transformer.
| v_pred = self.transformer( | |
| hidden_states=hidden_states, | |
| timestep=timestep, | |
| guidance=None, | |
| encoder_hidden_states=condition["prompt_embed"], | |
| txt_ids=condition["text_ids"], | |
| img_ids=img_ids, | |
| joint_attention_kwargs={}, | |
| return_dict=False, | |
| )[0] | |
| v_pred = self.transformer( | |
| hidden_states=hidden_states, | |
| timestep=timestep * 1000, | |
| guidance=None, | |
| encoder_hidden_states=condition["prompt_embed"], | |
| txt_ids=condition["text_ids"], | |
| img_ids=img_ids, | |
| joint_attention_kwargs={}, | |
| return_dict=False, | |
| )[0] |
| def _load_resume_state(self, resume_ckpt_path): | ||
| if self.model.is_fsdp2_wrapped(): | ||
| self._load_distributed_state(resume_ckpt_path) | ||
| else: | ||
| self._load_single_process_state(resume_ckpt_path) | ||
|
|
||
| teacher_weights_path = self._teacher_lora_checkpoint_path(resume_ckpt_path) | ||
| if os.path.exists(teacher_weights_path): | ||
| self.model.load_lora_weights_for_resume( | ||
| resume_ckpt_path, | ||
| adapter_name=self.teacher_adapter, | ||
| weights_subdir="teacher", | ||
| ) | ||
| logger.info("Restored teacher EMA LoRA from {}", teacher_weights_path) | ||
| else: | ||
| self.model.copy_lora_adapter_weights(self.student_adapter, self.teacher_adapter) | ||
| logger.warning( | ||
| "Teacher LoRA not found in checkpoint {}; initialized teacher from student", | ||
| resume_ckpt_path, | ||
| ) |
There was a problem hiding this comment.
When FSDP2 is wrapped, calling load_lora_weights_for_resume directly on the sharded module will fail due to shape and key mismatches. Furthermore, because the teacher LoRA is frozen (requires_grad = False), it is excluded from the distributed checkpoint when ignore_frozen_params=True is used in dcp.save. To resolve this cleanly, we can temporarily set requires_grad = True for the teacher LoRA parameters during checkpoint saving and loading so that they are automatically saved and restored via dcp, and skip calling load_lora_weights_for_resume for the teacher adapter when FSDP2 is wrapped.
def _load_resume_state(self, resume_ckpt_path):
if self.model.is_fsdp2_wrapped():
# Temporarily set teacher LoRA requires_grad to True so it can be loaded by FSDP2 dcp
for name, param in self.model.denoiser_module().named_parameters():
if self.teacher_adapter in name and "lora" in name:
param.requires_grad = True
self._load_distributed_state(resume_ckpt_path)
# Restore teacher LoRA requires_grad to False
for name, param in self.model.denoiser_module().named_parameters():
if self.teacher_adapter in name and "lora" in name:
param.requires_grad = False
else:
self._load_single_process_state(resume_ckpt_path)
teacher_weights_path = self._teacher_lora_checkpoint_path(resume_ckpt_path)
if os.path.exists(teacher_weights_path):
self.model.load_lora_weights_for_resume(
resume_ckpt_path,
adapter_name=self.teacher_adapter,
weights_subdir="teacher",
)
logger.info("Restored teacher EMA LoRA from {}", teacher_weights_path)
else:
self.model.copy_lora_adapter_weights(self.student_adapter, self.teacher_adapter)
logger.warning(
"Teacher LoRA not found in checkpoint {}; initialized teacher from student",
resume_ckpt_path,
)| def _save_distributed_state(self, save_dir, iteration): | ||
| dist_state_path = os.path.join(save_dir, "dist_state") | ||
| if is_main_process(): | ||
| os.makedirs(dist_state_path, exist_ok=True) | ||
| torch.save( | ||
| { | ||
| "iteration": iteration, | ||
| "world_size": get_world_size(), | ||
| "lr_scheduler": self.lr_scheduler.state_dict(), | ||
| }, | ||
| os.path.join(save_dir, "trainer_state.pt"), | ||
| ) | ||
| barrier() | ||
|
|
||
| options = StateDictOptions(ignore_frozen_params=True, strict=False) | ||
| model_state, optim_state = get_state_dict(self.model.fsdp2_state_module(), self.optimizer, options=options) | ||
| dcp.save( | ||
| {"model": model_state, "optimizer": optim_state}, | ||
| checkpoint_id=dist_state_path, | ||
| ) |
There was a problem hiding this comment.
In FSDP2, because the teacher LoRA parameters are frozen (requires_grad = False), they are excluded from the distributed checkpoint when ignore_frozen_params=True is used in get_state_dict. To ensure the teacher LoRA weights are saved in dist_state, we should temporarily set requires_grad = True for the teacher LoRA parameters before calling get_state_dict, and then restore them to False afterwards.
def _save_distributed_state(self, save_dir, iteration):
dist_state_path = os.path.join(save_dir, "dist_state")
if is_main_process():
os.makedirs(dist_state_path, exist_ok=True)
torch.save(
{
"iteration": iteration,
"world_size": get_world_size(),
"lr_scheduler": self.lr_scheduler.state_dict(),
},
os.path.join(save_dir, "trainer_state.pt"),
)
barrier()
# Temporarily set teacher LoRA requires_grad to True so it is saved by FSDP2 dcp
for name, param in self.model.denoiser_module().named_parameters():
if self.teacher_adapter in name and "lora" in name:
param.requires_grad = True
options = StateDictOptions(ignore_frozen_params=True, strict=False)
model_state, optim_state = get_state_dict(self.model.fsdp2_state_module(), self.optimizer, options=options)
dcp.save(
{"model": model_state, "optimizer": optim_state},
checkpoint_id=dist_state_path,
)
# Restore teacher LoRA requires_grad to False
for name, param in self.model.denoiser_module().named_parameters():
if self.teacher_adapter in name and "lora" in name:
param.requires_grad = False|
|
||
| @torch.no_grad() | ||
| def _run_teacher_inference(self, current_iter, iter_output_dir): | ||
| dataset = self.inferencer.dataloader_eval.dataset |
There was a problem hiding this comment.
If dataloader_eval is not configured or is None, self.inferencer.dataloader_eval will be None, causing _run_teacher_inference to crash with an AttributeError when accessing .dataset. We should add a defensive check to safely return if the evaluation dataloader is not set.
| dataset = self.inferencer.dataloader_eval.dataset | |
| if self.inferencer.dataloader_eval is None: | |
| logger.warning("[train] Skipping teacher inference because dataloader_eval is not set.") | |
| return | |
| dataset = self.inferencer.dataloader_eval.dataset |
No description provided.