Skip to content

Dev/d opsd#1122

Open
wangshankun wants to merge 2 commits into
mainfrom
dev/d-opsd
Open

Dev/d opsd#1122
wangshankun wants to merge 2 commits into
mainfrom
dev/d-opsd

Conversation

@wangshankun
Copy link
Copy Markdown
Collaborator

No description provided.

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a new D-OPSD trainer implementation, including support for dual LoRA adapters, trajectory visualization, and updated configurations for the flux2_klein model. The review feedback identifies several critical and medium-severity issues, including missing arguments in latent unpacking, incorrect timestep scaling for the transformer, FSDP2 checkpointing failures for frozen parameters, and a potential AttributeError in the inference pipeline. All identified issues include actionable code suggestions to ensure stability and correctness.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +137 to +140
def decode_packed_x0_to_images(self, packed_x0, latent_ids):
# height/width must be latent token grid sizes from img_ids, not pixel sizes.
unpatchified = Flux2KleinPipeline._unpack_latents_with_ids(packed_x0, latent_ids)
return self.decode_latent(unpatchified)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

In decode_packed_x0_to_images, Flux2KleinPipeline._unpack_latents_with_ids is called without passing the height and width arguments. Since _unpack_latents_with_ids requires the latent token grid height and width to correctly unpack the latents, omitting them will cause a TypeError or incorrect unpacking. We can dynamically compute the latent grid height and width from latent_ids and pass them to the function.

    @torch.no_grad()
    def decode_packed_x0_to_images(self, packed_x0, latent_ids):
        # height/width must be latent token grid sizes from img_ids, not pixel sizes.
        latent_h = int(latent_ids[..., 1].max()) + 1
        latent_w = int(latent_ids[..., 2].max()) + 1
        unpatchified = Flux2KleinPipeline._unpack_latents_with_ids(
            packed_x0, latent_ids, height=latent_h, width=latent_w
        )
        return self.decode_latent(unpatchified)

Comment on lines +196 to +205
v_pred = self.transformer(
hidden_states=hidden_states,
timestep=timestep,
guidance=None,
encoder_hidden_states=condition["prompt_embed"],
txt_ids=condition["text_ids"],
img_ids=img_ids,
joint_attention_kwargs={},
return_dict=False,
)[0]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

In predict_velocity, the timestep argument passed to self.transformer is in the range [0, 1] (since it is divided by t_scale in the trainer). However, the Flux transformer expects the timestep to be scaled to [0, 1000]. Passing a value in [0, 1] will cause the transformer to interpret it as an extremely small timestep (near 0), leading to incorrect velocity predictions and poor generation quality. We should scale the timestep by 1000 before passing it to self.transformer.

Suggested change
v_pred = self.transformer(
hidden_states=hidden_states,
timestep=timestep,
guidance=None,
encoder_hidden_states=condition["prompt_embed"],
txt_ids=condition["text_ids"],
img_ids=img_ids,
joint_attention_kwargs={},
return_dict=False,
)[0]
v_pred = self.transformer(
hidden_states=hidden_states,
timestep=timestep * 1000,
guidance=None,
encoder_hidden_states=condition["prompt_embed"],
txt_ids=condition["text_ids"],
img_ids=img_ids,
joint_attention_kwargs={},
return_dict=False,
)[0]

Comment on lines +137 to +156
def _load_resume_state(self, resume_ckpt_path):
if self.model.is_fsdp2_wrapped():
self._load_distributed_state(resume_ckpt_path)
else:
self._load_single_process_state(resume_ckpt_path)

teacher_weights_path = self._teacher_lora_checkpoint_path(resume_ckpt_path)
if os.path.exists(teacher_weights_path):
self.model.load_lora_weights_for_resume(
resume_ckpt_path,
adapter_name=self.teacher_adapter,
weights_subdir="teacher",
)
logger.info("Restored teacher EMA LoRA from {}", teacher_weights_path)
else:
self.model.copy_lora_adapter_weights(self.student_adapter, self.teacher_adapter)
logger.warning(
"Teacher LoRA not found in checkpoint {}; initialized teacher from student",
resume_ckpt_path,
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

When FSDP2 is wrapped, calling load_lora_weights_for_resume directly on the sharded module will fail due to shape and key mismatches. Furthermore, because the teacher LoRA is frozen (requires_grad = False), it is excluded from the distributed checkpoint when ignore_frozen_params=True is used in dcp.save. To resolve this cleanly, we can temporarily set requires_grad = True for the teacher LoRA parameters during checkpoint saving and loading so that they are automatically saved and restored via dcp, and skip calling load_lora_weights_for_resume for the teacher adapter when FSDP2 is wrapped.

    def _load_resume_state(self, resume_ckpt_path):
        if self.model.is_fsdp2_wrapped():
            # Temporarily set teacher LoRA requires_grad to True so it can be loaded by FSDP2 dcp
            for name, param in self.model.denoiser_module().named_parameters():
                if self.teacher_adapter in name and "lora" in name:
                    param.requires_grad = True
            self._load_distributed_state(resume_ckpt_path)
            # Restore teacher LoRA requires_grad to False
            for name, param in self.model.denoiser_module().named_parameters():
                if self.teacher_adapter in name and "lora" in name:
                    param.requires_grad = False
        else:
            self._load_single_process_state(resume_ckpt_path)
            teacher_weights_path = self._teacher_lora_checkpoint_path(resume_ckpt_path)
            if os.path.exists(teacher_weights_path):
                self.model.load_lora_weights_for_resume(
                    resume_ckpt_path,
                    adapter_name=self.teacher_adapter,
                    weights_subdir="teacher",
                )
                logger.info("Restored teacher EMA LoRA from {}", teacher_weights_path)
            else:
                self.model.copy_lora_adapter_weights(self.student_adapter, self.teacher_adapter)
                logger.warning(
                    "Teacher LoRA not found in checkpoint {}; initialized teacher from student",
                    resume_ckpt_path,
                )

Comment on lines +631 to +650
def _save_distributed_state(self, save_dir, iteration):
dist_state_path = os.path.join(save_dir, "dist_state")
if is_main_process():
os.makedirs(dist_state_path, exist_ok=True)
torch.save(
{
"iteration": iteration,
"world_size": get_world_size(),
"lr_scheduler": self.lr_scheduler.state_dict(),
},
os.path.join(save_dir, "trainer_state.pt"),
)
barrier()

options = StateDictOptions(ignore_frozen_params=True, strict=False)
model_state, optim_state = get_state_dict(self.model.fsdp2_state_module(), self.optimizer, options=options)
dcp.save(
{"model": model_state, "optimizer": optim_state},
checkpoint_id=dist_state_path,
)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

In FSDP2, because the teacher LoRA parameters are frozen (requires_grad = False), they are excluded from the distributed checkpoint when ignore_frozen_params=True is used in get_state_dict. To ensure the teacher LoRA weights are saved in dist_state, we should temporarily set requires_grad = True for the teacher LoRA parameters before calling get_state_dict, and then restore them to False afterwards.

    def _save_distributed_state(self, save_dir, iteration):
        dist_state_path = os.path.join(save_dir, "dist_state")
        if is_main_process():
            os.makedirs(dist_state_path, exist_ok=True)
            torch.save(
                {
                    "iteration": iteration,
                    "world_size": get_world_size(),
                    "lr_scheduler": self.lr_scheduler.state_dict(),
                },
                os.path.join(save_dir, "trainer_state.pt"),
            )
        barrier()

        # Temporarily set teacher LoRA requires_grad to True so it is saved by FSDP2 dcp
        for name, param in self.model.denoiser_module().named_parameters():
            if self.teacher_adapter in name and "lora" in name:
                param.requires_grad = True

        options = StateDictOptions(ignore_frozen_params=True, strict=False)
        model_state, optim_state = get_state_dict(self.model.fsdp2_state_module(), self.optimizer, options=options)
        dcp.save(
            {"model": model_state, "optimizer": optim_state},
            checkpoint_id=dist_state_path,
        )

        # Restore teacher LoRA requires_grad to False
        for name, param in self.model.denoiser_module().named_parameters():
            if self.teacher_adapter in name and "lora" in name:
                param.requires_grad = False


@torch.no_grad()
def _run_teacher_inference(self, current_iter, iter_output_dir):
dataset = self.inferencer.dataloader_eval.dataset
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If dataloader_eval is not configured or is None, self.inferencer.dataloader_eval will be None, causing _run_teacher_inference to crash with an AttributeError when accessing .dataset. We should add a defensive check to safely return if the evaluation dataloader is not set.

Suggested change
dataset = self.inferencer.dataloader_eval.dataset
if self.inferencer.dataloader_eval is None:
logger.warning("[train] Skipping teacher inference because dataloader_eval is not set.")
return
dataset = self.inferencer.dataloader_eval.dataset

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant