
feat(wan): Add prior-based diffusion step skip for ~70% fewer inference steps#1325

Open
Efrat-Taig wants to merge 2 commits into modelscope:main from Efrat-Taig:feature/prior-based-step-skip

Conversation

@Efrat-Taig

Summary

Adds a prior-based diffusion step skip for Wan video models: ~70% fewer inference steps with the same quality and zero retraining.

Changes

  • Pipeline (wan_video.py): Added step_callback, prior_latents, prior_timesteps, prior_sigmas, start_from_step for resuming from saved latents
  • Example scripts (examples/wanvideo/prior_based_step_skip/): generate_prior.py, infer_from_prior.py, prior_utils.py, README.md
  • Docs: Added Prior-based step skip section to Wan docs (en + zh)
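
The resume mechanics these parameters enable can be sketched with a toy denoising loop. The parameter names mirror the PR (step_callback, prior_latents, start_from_step), but the loop and its update rule are stand-ins for illustration, not the real WanVideoPipeline:

```python
# Toy stand-in for a denoising loop with the PR's new parameters.
# The "denoising" update is fake; only the save/resume flow is real.

def denoise(latents, timesteps, start_from_step=0, prior_latents=None,
            step_callback=None):
    if prior_latents is not None:
        latents = prior_latents          # resume from the saved state
    for i, t in enumerate(timesteps):
        if i < start_from_step:
            continue                     # skip steps covered by the prior
        latents = latents - 0.1 * t      # stand-in for one denoising step
        if step_callback is not None:
            step_callback(i, t, latents) # e.g. save latents to disk
    return latents

# Phase 1: full run, recording every intermediate latent.
saved = {}
timesteps = [1.0, 0.8, 0.6, 0.4, 0.2]
full = denoise(10.0, timesteps,
               step_callback=lambda i, t, x: saved.__setitem__(i, x))

# Phase 2: resume from step 3 using the latent saved after step 2.
resumed = denoise(10.0, timesteps, start_from_step=3, prior_latents=saved[2])
assert resumed == full                   # identical tail of the trajectory
```

Because the resumed run starts from the exact latent the full run produced at that step, the remaining steps reproduce the full trajectory bit-for-bit.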

Use case

When identity/scene is fixed and only motion varies (e.g. lip-sync, different actions), early diffusion steps are largely redundant.

Quick start

python examples/wanvideo/prior_based_step_skip/generate_prior.py --download_example --output_dir ./prior_output --num_inference_steps 10
python examples/wanvideo/prior_based_step_skip/infer_from_prior.py --prior_dir ./prior_output/run_<id> --start_step 6 --image data/examples/wan/input_image.jpg
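
Conceptually, generate_prior.py records the run's scheduler settings alongside the saved latents, and infer_from_prior.py validates them before resuming. A minimal sketch of that bookkeeping (function names and the metadata layout are illustrative, not the PR's exact prior_utils API):

```python
# Hypothetical sketch of prior_utils-style bookkeeping: save the scheduler
# settings next to the latents, and refuse to resume when they don't match.
import json
import os
import tempfile

def save_run_metadata(run_dir, num_inference_steps, timesteps):
    os.makedirs(run_dir, exist_ok=True)
    with open(os.path.join(run_dir, "metadata.json"), "w") as f:
        json.dump({"num_inference_steps": num_inference_steps,
                   "timesteps": timesteps}, f)

def load_and_validate(run_dir, num_inference_steps, start_step):
    with open(os.path.join(run_dir, "metadata.json")) as f:
        meta = json.load(f)
    # Resuming only makes sense if both runs use the same schedule
    # and the requested start step falls inside the recorded schedule.
    if meta["num_inference_steps"] != num_inference_steps:
        raise ValueError("step count differs from the prior run")
    if not 0 < start_step < meta["num_inference_steps"]:
        raise ValueError("start_step outside the recorded schedule")
    return meta["timesteps"][start_step:]   # timesteps left to run

run_dir = os.path.join(tempfile.mkdtemp(), "run_demo")
save_run_metadata(run_dir, 10, [float(t) for t in range(10, 0, -1)])
remaining = load_and_validate(run_dir, 10, start_step=6)
print(remaining)  # the 4 timesteps still to be denoised
```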

Made with Cursor

…ce steps

- Add step_callback, prior_latents, prior_timesteps, prior_sigmas, start_from_step
  to WanVideoPipeline for resuming from saved latents
- Add examples/wanvideo/prior_based_step_skip/ with generate_prior.py,
  infer_from_prior.py, prior_utils.py, and README
- Add --download_example to generate_prior.py for easy onboarding
- Add Prior-based step skip section to Wan docs (en + zh)
- Supports fixed identity/scene with varying motion (e.g. lip-sync, different actions)

Made-with: Cursor
@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant optimization for Wan video models by enabling a prior-based diffusion step skip. This feature allows users to drastically reduce the number of inference steps required for video generation, particularly when the identity or scene remains constant and only motion varies. By saving and reusing intermediate latent states, the system can skip redundant early diffusion steps, leading to faster generation times without compromising video quality.

Highlights

  • Prior-Based Diffusion Step Skip: Implemented a prior-based diffusion step skip mechanism for Wan video models, allowing for significantly fewer inference steps (approximately 70% reduction) while maintaining the same output quality, without requiring any model retraining.
  • Pipeline Enhancements: Modified the WanVideoPipeline to accept new parameters such as step_callback, prior_latents, prior_timesteps, prior_sigmas, and start_from_step, enabling the pipeline to save intermediate latents and resume inference from a specific step.
  • Example Scripts and Documentation: Introduced new example scripts (generate_prior.py, infer_from_prior.py, prior_utils.py) to demonstrate the usage of the prior-based step skip feature, along with updated documentation in both English and Chinese.


Changelog
  • diffsynth/pipelines/wan_video.py
    • Added Callable import to support new callback functionality.
    • Introduced new parameters (step_callback, prior_latents, prior_timesteps, prior_sigmas, start_from_step) to the __call__ method for managing prior-based inference.
    • Implemented logic to override scheduler timesteps and sigmas when resuming from prior latents.
    • Modified the denoising loop to allow starting from a specified step and to inject prior latents.
    • Integrated an optional step_callback to be invoked after each denoising step, facilitating the saving of intermediate latents.
  • docs/en/Model_Details/Wan.md
    • Added a new section titled 'Prior-Based Step Skip' explaining the concept and benefits of the feature.
  • docs/zh/Model_Details/Wan.md
    • Added a new section titled '基于先验的步长跳过' (Prior-Based Step Skip) explaining the concept and benefits in Chinese.
  • examples/wanvideo/prior_based_step_skip/README.md
    • Added a new README file detailing the prior-based diffusion step skip feature, including quick start instructions, an explanation of how it works, and descriptions of the associated scripts and options.
  • examples/wanvideo/prior_based_step_skip/generate_prior.py
    • Added a new Python script to perform full inference and save latent tensors at each denoising step, creating a 'prior' for later accelerated inference.
  • examples/wanvideo/prior_based_step_skip/infer_from_prior.py
    • Added a new Python script to load a saved prior latent and resume inference from a specified step, enabling accelerated video generation.
  • examples/wanvideo/prior_based_step_skip/prior_utils.py
    • Added a new utility Python script containing functions to build step callbacks, save and load run metadata, and validate scheduler parameters for compatibility between prior generation and inference.
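
The scheduler timestep/sigma override described in the Changelog can be illustrated with a toy scheduler. The class and its attributes are hypothetical; the real scheduler API differs, but the idea is the same: when resuming, the saved schedule replaces the freshly computed one so step indices line up with the prior run.

```python
# Toy scheduler override: replace a freshly computed schedule with the one
# saved by the prior run, then start the loop at start_from_step.

class TinyScheduler:
    def set_timesteps(self, n):
        self.timesteps = [1.0 - i / n for i in range(n)]
        self.sigmas = [t * 0.5 for t in self.timesteps]

scheduler = TinyScheduler()
scheduler.set_timesteps(8)        # fresh 8-step schedule: would NOT line up
                                  # with a prior saved from a 10-step run

prior_timesteps = [1.0 - i / 10 for i in range(10)]
prior_sigmas = [t * 0.5 for t in prior_timesteps]

# Override, as the pipeline does when prior latents are supplied:
scheduler.timesteps = prior_timesteps
scheduler.sigmas = prior_sigmas

start_from_step = 6
for i, t in enumerate(scheduler.timesteps):
    if i < start_from_step:
        continue                   # steps already covered by the prior
    _ = (t, scheduler.sigmas[i])   # stand-in for one denoising step
print(len(scheduler.timesteps) - start_from_step)  # steps actually run: 4
```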



@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a prior-based diffusion step skipping feature to accelerate inference for Wan video models. The changes include modifications to the WanVideoPipeline to support resuming from a saved latent state, along with new example scripts and documentation. My review focuses on ensuring the correctness of the new logic and improving the maintainability of the example code. I've identified a critical issue with an inconsistent condition that could lead to incorrect behavior, and a couple of medium-severity issues related to code duplication and style in the new example files.

Comment on lines +301 to +302
if prior_latents is not None and start_from_step is not None:
    inputs_shared["latents"] = prior_latents.to(dtype=self.torch_dtype, device=self.device)

critical

The condition for replacing latents with the prior is missing a check for prior_timesteps. The check at line 262 correctly requires prior_latents, prior_timesteps, and start_from_step. If prior_timesteps is not provided here, the scheduler will use incorrect timesteps with the loaded prior latents, which can lead to incorrect generation results. To ensure consistency and prevent bugs, the condition should be the same as the one at line 262.

Suggested change
-if prior_latents is not None and start_from_step is not None:
-    inputs_shared["latents"] = prior_latents.to(dtype=self.torch_dtype, device=self.device)
+if prior_latents is not None and prior_timesteps is not None and start_from_step is not None:
+    inputs_shared["latents"] = prior_latents.to(dtype=self.torch_dtype, device=self.device)

Comment on lines +107 to +166
if args.model == "I2V-480P":
    pipe = WanVideoPipeline.from_pretrained(
        torch_dtype=torch.bfloat16,
        device="cuda",
        model_configs=[
            ModelConfig(
                model_id="Wan-AI/Wan2.1-I2V-14B-480P",
                origin_file_pattern="diffusion_pytorch_model*.safetensors",
                **vram_config,
            ),
            ModelConfig(
                model_id="Wan-AI/Wan2.1-I2V-14B-480P",
                origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth",
                **vram_config,
            ),
            ModelConfig(
                model_id="Wan-AI/Wan2.1-I2V-14B-480P",
                origin_file_pattern="Wan2.1_VAE.pth",
                **vram_config,
            ),
            ModelConfig(
                model_id="Wan-AI/Wan2.1-I2V-14B-480P",
                origin_file_pattern="models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
                **vram_config,
            ),
        ],
        tokenizer_config=ModelConfig(
            model_id="Wan-AI/Wan2.1-T2V-1.3B",
            origin_file_pattern="google/umt5-xxl/",
        ),
        vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024**3) - 2,
    )
else:
    # T2V-1.3B (no image)
    pipe = WanVideoPipeline.from_pretrained(
        torch_dtype=torch.bfloat16,
        device="cuda",
        model_configs=[
            ModelConfig(
                model_id="Wan-AI/Wan2.1-T2V-1.3B",
                origin_file_pattern="diffusion_pytorch_model*.safetensors",
                **vram_config,
            ),
            ModelConfig(
                model_id="Wan-AI/Wan2.1-T2V-1.3B",
                origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth",
                **vram_config,
            ),
            ModelConfig(
                model_id="Wan-AI/Wan2.1-T2V-1.3B",
                origin_file_pattern="Wan2.1_VAE.pth",
                **vram_config,
            ),
        ],
        tokenizer_config=ModelConfig(
            model_id="Wan-AI/Wan2.1-T2V-1.3B",
            origin_file_pattern="google/umt5-xxl/",
        ),
        vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024**3) - 2,
    )

medium

The model loading logic for I2V-480P and T2V-1.3B is largely duplicated. This can be refactored to improve maintainability and readability. By extracting the common patterns and parameterizing the differences, you can make the code more concise and easier to extend with new models in the future.

    model_file_patterns = {
        "I2V-480P": [
            "diffusion_pytorch_model*.safetensors",
            "models_t5_umt5-xxl-enc-bf16.pth",
            "Wan2.1_VAE.pth",
            "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
        ],
        "T2V-1.3B": [
            "diffusion_pytorch_model*.safetensors",
            "models_t5_umt5-xxl-enc-bf16.pth",
            "Wan2.1_VAE.pth",
        ],
    }
    model_ids = {
        "I2V-480P": "Wan-AI/Wan2.1-I2V-14B-480P",
        "T2V-1.3B": "Wan-AI/Wan2.1-T2V-1.3B",
    }
    model_id = model_ids[args.model]
    file_patterns = model_file_patterns[args.model]
    model_configs = [
        ModelConfig(
            model_id=model_id,
            origin_file_pattern=pattern,
            **vram_config,
        )
        for pattern in file_patterns
    ]

    pipe = WanVideoPipeline.from_pretrained(
        torch_dtype=torch.bfloat16,
        device="cuda",
        model_configs=model_configs,
        tokenizer_config=ModelConfig(
            model_id="Wan-AI/Wan2.1-T2V-1.3B",
            origin_file_pattern="google/umt5-xxl/",
        ),
        vram_limit=torch.cuda.mem_get_info("cuda")[1] / (1024**3) - 2,
    )

Comment on lines +137 to +138
import imageio
import numpy as np

medium

According to the PEP 8 style guide, imports should be at the top of the file. This makes dependencies clear and avoids potential issues with repeated imports. Please move import imageio and import numpy as np to the top of the file.

References
  1. PEP 8: E402 module level import not at top of file (link)

