[2/4] Diffusion Quantized ckpt export #810
base: main
Conversation
Signed-off-by: Jingyu Xin <[email protected]>
📝 Walkthrough

This PR adds comprehensive support for quantizing LTX-2 video models within a diffusion-based quantization framework. Changes include new model registration and configuration, multi-stage pipeline creation with dynamic class lookup, LTX-2-specific quantization handlers, duck-typed pipeline export paths, and infrastructure updates to support both traditional diffusers and LTX-2 pipelines, alongside module forward caching for preservation of pre-quantization methods.
Sequence Diagram(s)

```mermaid
sequenceDiagram
participant User
participant Main as main()
participant PM as PipelineManager
participant Cal as Calibrator
participant Quan as Quantizer
participant EM as ExportManager
User->>Main: invoke with LTX-2 config + extra_params
Main->>Main: parse_extra_params() from CLI
Main->>Main: create ModelConfig with extra_params
Main->>PM: __init__(config)
PM->>PM: create_pipeline() with extra_params
PM->>PM: _create_ltx2_pipeline() for LTX-2
PM->>PM: _ensure_ltx2_transformer_cached()
PM-->>Main: pipeline ready
Main->>PM: get_backbone()
PM-->>Main: cached LTX-2 transformer
Main->>Cal: run_calibration()
Cal->>Cal: _run_ltx2_calibration() dispatch
Cal->>PM: forward pipeline with LTX-2 prompts
Cal-->>Main: calibration complete
Main->>Quan: quantize_model(backbone, quant_config)
Quan->>Quan: register_ltx2_quant_linear()
Quan->>Quan: apply quantization with FP8 upcast
Quan-->>Main: quantized backbone
Main->>EM: export_hf_ckpt(pipeline)
EM->>EM: generate_diffusion_dummy_forward_fn()
EM->>EM: get_diffusion_components() with duck-typing
EM->>EM: _export_diffusers_checkpoint() with Any type
EM-->>Main: exported checkpoint
```

```mermaid
sequenceDiagram
participant PM as PipelineManager
participant DM as DynamicModule
participant QM as QuantInputBase
participant DL as _QuantLTX2Linear
PM->>PM: create_pipeline() calls DiffusionPipeline
PM->>DM: convert() wraps forward methods
DM->>DM: bind_forward_method_if_needed()
DM->>DM: cache original in _forward_pre_dm
DM-->>PM: pipeline with cached forwards
PM->>QM: forward() during inference
QM->>QM: check _forward_pre_dm exists
QM->>QM: invoke _forward_pre_dm() if cached
QM->>DL: _get_quantized_weight() override
DL->>DL: upcast FP8 to bfloat16 if needed
DL-->>QM: quantized weight
QM->>QM: apply output quantization
QM-->>PM: quantized output
```
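The second diagram's `_get_quantized_weight()` override upcasts FP8-stored weights before quantization. A minimal sketch of that idea, assuming PyTorch with FP8 dtypes available; the class and method names come from the walkthrough, the rest is illustrative and may not match the actual implementation:

```python
import torch
import torch.nn as nn


# Sketch only: illustrates the "upcast FP8 to bfloat16 if needed" step from the
# diagram above. The real _QuantLTX2Linear lives in the quantization framework.
class QuantLTX2LinearSketch(nn.Linear):
    def _get_quantized_weight(self) -> torch.Tensor:
        weight = self.weight
        # LTX-2 checkpoints may store weights in FP8; quantizer kernels generally
        # expect a higher-precision dtype, so upcast before (re)quantization.
        if weight.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
            weight = weight.to(torch.bfloat16)
        return weight


layer = QuantLTX2LinearSketch(8, 8, bias=False)
layer.weight.data = layer.weight.data.to(torch.float8_e4m3fn)  # simulate an FP8 checkpoint
print(layer._get_quantized_weight().dtype)  # torch.bfloat16
```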
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks: ❌ 1 failed check (1 warning), ✅ 2 passed checks
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.
Signed-off-by: Jingyu Xin <[email protected]>
Codecov Report

❌ Patch coverage is

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main     #810      +/-   ##
==========================================
+ Coverage   73.31%   73.33%   +0.01%
==========================================
  Files         192      192
  Lines       19613    19631      +18
==========================================
+ Hits        14380    14396      +16
- Misses       5233     5235       +2
```

☔ View full report in Codecov by Sentry.
```python
else:
    cpu_state_dict = {
        k: v.detach().contiguous().cpu() for k, v in component.state_dict().items()
    }
    save_file(cpu_state_dict, str(component_export_dir / "model.safetensors"))
    with open(component_export_dir / "config.json", "w") as f:
        json.dump(
            {
                "_class_name": type(component).__name__,
                "_export_format": "safetensors_state_dict",
            },
            f,
            indent=4,
        )
```
Can we combine these with L851 to L863? They look duplicated.
Why do we need to offload tensors to CPU before saving?
If we always save with safetensors, keeping the .cpu() is the safe/default choice. This is also how transformers/diffusers save_pretrained saves the tensors to a safetensors file.
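For illustration, a minimal, self-contained version of the pattern being discussed (assumes torch and safetensors are installed; the tensor names are made up):

```python
import torch
from safetensors.torch import save_file

# safetensors serializes raw tensor buffers, so tensors are detached, made
# contiguous, and moved to CPU first -- the same normalization that
# transformers/diffusers save_pretrained applies before writing safetensors.
device = "cuda" if torch.cuda.is_available() else "cpu"
state_dict = {"layer.weight": torch.randn(8, 8, device=device)}
cpu_state_dict = {k: v.detach().contiguous().cpu() for k, v in state_dict.items()}
save_file(cpu_state_dict, "model.safetensors")
```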
Could you clarify more?

> Can we combine these with L851 to L863? They look duplicated.

Line 880 saves the state dict to a safetensors file, and line 884 saves the quant config to config.json. We use these two functions only if the model is not diffusers-based.
```python
cpu_state_dict = {
    k: v.detach().contiguous().cpu() for k, v in component.state_dict().items()
}
save_file(cpu_state_dict, str(component_export_dir / "model.safetensors"))
with open(component_export_dir / "config.json", "w") as f:
    json.dump(
        {
            "_class_name": type(component).__name__,
            "_export_format": "safetensors_state_dict",
        },
        f,
        indent=4,
    )
```
I mean this code block appears twice in the same script.
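One way to address the duplication would be a small shared helper that both call sites reuse. A sketch only; the helper name is hypothetical, and the body simply mirrors the block quoted above:

```python
import json
from pathlib import Path

import torch.nn as nn
from safetensors.torch import save_file


def _save_component_state_dict(component: nn.Module, component_export_dir: Path) -> None:
    """Shared helper for the duplicated block: state dict -> model.safetensors + config.json."""
    cpu_state_dict = {
        k: v.detach().contiguous().cpu() for k, v in component.state_dict().items()
    }
    save_file(cpu_state_dict, str(component_export_dir / "model.safetensors"))
    with open(component_export_dir / "config.json", "w") as f:
        json.dump(
            {
                "_class_name": type(component).__name__,
                "_export_format": "safetensors_state_dict",
            },
            f,
            indent=4,
        )
```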
Signed-off-by: Jingyu Xin <[email protected]>
ef4f814 to 9f0e998 (Compare)
Signed-off-by: Jingyu Xin <[email protected]>
```python
# overridden it by binding the dynamic forward onto the instance (to follow the MRO).
# On final export, restore the original forward to avoid leaking a dynamic forward
# (e.g., DistillationModel.forward) onto the exported (non-dynamic) module instance.
if hasattr(self, "_forward_pre_dm"):
```
```python
    # accelerate patched module
    bind_forward_method(self, self.__class__.forward)
else:
    if not hasattr(self, "_forward_pre_dm"):
```
| """Quantize the input before calling the original forward method.""" | ||
| input = self.input_quantizer(input) | ||
| output = super().forward(input, *args, **kwargs) | ||
| if hasattr(self, "_forward_pre_dm"): |
I recommend submitting a dedicated PR for the changes to the dynamic module.
ChenhanYu left a comment
Commented on the dynamic module part.
Edwardf0t1 left a comment
LGTM, left a few more comments.
What does this PR do?
Type of change: New feature
Overview:
This PR adds HuggingFace checkpoint export support for LTX‑2 by treating TI2VidTwoStagesPipeline as a diffusion-like pipeline, exporting only the stage‑1 transformer (with QKV-fusion-enabled dummy inputs) and falling back to writing model.safetensors when save_pretrained isn't available. It also preserves the original forward in DynamicModule patching (_forward_pre_dm) so downstream callers can still invoke the pre-patched forward implementation.
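A rough sketch of the duck-typed export path described above. The guarded ltx_pipelines import and the save_pretrained fallback come from this description; _is_ltx2_pipeline and _export_component are hypothetical names used only for illustration (the actual routing lives in export_hf_checkpoint() / _export_diffusers_checkpoint()):

```python
from pathlib import Path
from typing import Any

from safetensors.torch import save_file


def _is_ltx2_pipeline(pipeline: Any) -> bool:
    """Import-guarded check, so there is no hard dependency on the LTX-2 packages."""
    try:
        from ltx_pipelines.ti2vid_two_stages import TI2VidTwoStagesPipeline
    except ImportError:
        return False
    return isinstance(pipeline, TI2VidTwoStagesPipeline)


def _export_component(component: Any, export_dir: Path) -> None:
    """Prefer save_pretrained; otherwise fall back to a raw safetensors state dict."""
    export_dir.mkdir(parents=True, exist_ok=True)
    if hasattr(component, "save_pretrained"):
        component.save_pretrained(export_dir)
    else:
        cpu_state_dict = {
            k: v.detach().contiguous().cpu() for k, v in component.state_dict().items()
        }
        save_file(cpu_state_dict, str(export_dir / "model.safetensors"))
```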
Changes
- DynamicModule patching: when patching forward, we now stash the pre-patched implementation in self._forward_pre_dm (once) so downstream code can still call the original forward, then re-bind forward to the class implementation. This is needed for the LTX-2 FP8 calibration (see the sketch after this list).
- export_hf_checkpoint() now also treats ltx_pipelines.ti2vid_two_stages.TI2VidTwoStagesPipeline as a "diffusion-like" object and routes it through _export_diffusers_checkpoint() (import guarded; no hard dependency).
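A minimal sketch of the forward-caching behavior on a plain nn.Module; only the _forward_pre_dm attribute name comes from this PR, the rest of the names are illustrative:

```python
import types

import torch
import torch.nn as nn


def patch_forward(module: nn.Module, dynamic_forward) -> None:
    """Illustrative stand-in for the DynamicModule patching step."""
    # Stash the pre-patched forward exactly once, so downstream code (e.g. the
    # LTX-2 FP8 calibration path) can still reach the original implementation.
    if not hasattr(module, "_forward_pre_dm"):
        module._forward_pre_dm = module.forward
    # Re-bind the dynamic forward onto the instance so it takes precedence.
    module.forward = types.MethodType(dynamic_forward, module)


def dynamic_forward(self, x):
    # Placeholder dynamic forward; real code would quantize inputs/weights here
    # and then delegate to the cached original forward.
    return self._forward_pre_dm(x)


layer = nn.Linear(4, 4)
patch_forward(layer, dynamic_forward)
out = layer(torch.randn(2, 4))  # routed through dynamic_forward -> _forward_pre_dm
```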
Plans

Usage
Testing
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
Release Notes
New Features
- --extra-param CLI option for flexible model configuration and parameter passing.

Chores