diff --git a/diffsynth_engine/__init__.py b/diffsynth_engine/__init__.py index 7c282502..87896394 100644 --- a/diffsynth_engine/__init__.py +++ b/diffsynth_engine/__init__.py @@ -3,13 +3,14 @@ SDXLPipelineConfig, FluxPipelineConfig, WanPipelineConfig, + ControlNetParams, + ControlType, ) from .pipelines import ( FluxImagePipeline, SDXLImagePipeline, SDImagePipeline, WanVideoPipeline, - ControlNetParams, ) from .models.flux import FluxControlNet, FluxIPAdapter, FluxRedux from .models.sd import SDControlNet @@ -44,6 +45,7 @@ "FluxReplaceByControlTool", "FluxReduxRefTool", "ControlNetParams", + "ControlType", "fetch_model", "fetch_modelscope_model", "fetch_civitai_model", diff --git a/diffsynth_engine/configs/__init__.py b/diffsynth_engine/configs/__init__.py index ae1a7ca5..9e9dad3a 100644 --- a/diffsynth_engine/configs/__init__.py +++ b/diffsynth_engine/configs/__init__.py @@ -8,7 +8,7 @@ FluxPipelineConfig, WanPipelineConfig, ) -from .controlnet import ControlType +from .controlnet import ControlType, ControlNetParams __all__ = [ "BaseConfig", @@ -20,4 +20,5 @@ "FluxPipelineConfig", "WanPipelineConfig", "ControlType", + "ControlNetParams", ] diff --git a/diffsynth_engine/configs/controlnet.py b/diffsynth_engine/configs/controlnet.py index af120423..3fd13131 100644 --- a/diffsynth_engine/configs/controlnet.py +++ b/diffsynth_engine/configs/controlnet.py @@ -1,5 +1,13 @@ +from dataclasses import dataclass from enum import Enum +import torch +import torch.nn as nn +from typing import List, Union, Optional +from PIL import Image + +ImageType = Union[Image.Image, torch.Tensor, List[Image.Image], List[torch.Tensor]] + # FLUX ControlType class ControlType(Enum): @@ -15,3 +23,14 @@ def get_in_channel(self): return 128 elif self == ControlType.bfl_fill: return 384 + + +@dataclass +class ControlNetParams: + image: ImageType + scale: float = 1.0 + model: Optional[nn.Module] = None + mask: Optional[ImageType] = None + control_start: float = 0 + control_end: float = 1 + processor_name: Optional[str] = None # only used for sdxl controlnet union now diff --git a/diffsynth_engine/pipelines/__init__.py b/diffsynth_engine/pipelines/__init__.py index 615ecfe7..54c925ce 100644 --- a/diffsynth_engine/pipelines/__init__.py +++ b/diffsynth_engine/pipelines/__init__.py @@ -1,5 +1,4 @@ from .base import BasePipeline, LoRAStateDictConverter -from .controlnet_helper import ControlNetParams from .flux_image import FluxImagePipeline from .sdxl_image import SDXLImagePipeline from .sd_image import SDImagePipeline @@ -13,5 +12,4 @@ "SDXLImagePipeline", "SDImagePipeline", "WanVideoPipeline", - "ControlNetParams", ] diff --git a/diffsynth_engine/pipelines/controlnet_helper.py b/diffsynth_engine/pipelines/controlnet_helper.py deleted file mode 100644 index d219bcde..00000000 --- a/diffsynth_engine/pipelines/controlnet_helper.py +++ /dev/null @@ -1,26 +0,0 @@ -import torch -import torch.nn as nn -from typing import List, Union, Optional -from PIL import Image -from dataclasses import dataclass - -ImageType = Union[Image.Image, torch.Tensor, List[Image.Image], List[torch.Tensor]] - - -@dataclass -class ControlNetParams: - image: ImageType - scale: float = 1.0 - model: Optional[nn.Module] = None - mask: Optional[ImageType] = None - control_start: float = 0 - control_end: float = 1 - processor_name: Optional[str] = None # only used for sdxl controlnet union now - - -def accumulate(result, new_item): - if result is None: - return new_item - for i, item in enumerate(new_item): - result[i] += item - return result diff --git a/diffsynth_engine/pipelines/flux_image.py b/diffsynth_engine/pipelines/flux_image.py index 7f31ed5a..251b5ad3 100644 --- a/diffsynth_engine/pipelines/flux_image.py +++ b/diffsynth_engine/pipelines/flux_image.py @@ -17,10 +17,10 @@ flux_dit_config, flux_text_encoder_config, ) -from diffsynth_engine.configs import FluxPipelineConfig, ControlType +from diffsynth_engine.configs import FluxPipelineConfig, ControlType, ControlNetParams from diffsynth_engine.models.basic.lora import LoRAContext from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter -from diffsynth_engine.pipelines.controlnet_helper import ControlNetParams, accumulate +from diffsynth_engine.pipelines.utils import accumulate from diffsynth_engine.tokenizers import CLIPTokenizer, T5TokenizerFast from diffsynth_engine.algorithm.noise_scheduler import RecifitedFlowScheduler from diffsynth_engine.algorithm.sampler import FlowMatchEulerSampler diff --git a/diffsynth_engine/pipelines/sd_image.py b/diffsynth_engine/pipelines/sd_image.py index be912787..09239713 100644 --- a/diffsynth_engine/pipelines/sd_image.py +++ b/diffsynth_engine/pipelines/sd_image.py @@ -6,12 +6,12 @@ from tqdm import tqdm from PIL import Image, ImageOps -from diffsynth_engine.configs import SDPipelineConfig +from diffsynth_engine.configs import SDPipelineConfig, ControlNetParams from diffsynth_engine.models.base import split_suffix from diffsynth_engine.models.basic.lora import LoRAContext from diffsynth_engine.models.sd import SDTextEncoder, SDVAEDecoder, SDVAEEncoder, SDUNet, sd_unet_config from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter -from diffsynth_engine.pipelines.controlnet_helper import ControlNetParams, accumulate +from diffsynth_engine.pipelines.utils import accumulate from diffsynth_engine.tokenizers import CLIPTokenizer from diffsynth_engine.algorithm.noise_scheduler import ScaledLinearScheduler from diffsynth_engine.algorithm.sampler import EulerSampler diff --git a/diffsynth_engine/pipelines/sdxl_image.py b/diffsynth_engine/pipelines/sdxl_image.py index 1b3bc6d3..6ca6ee87 100644 --- a/diffsynth_engine/pipelines/sdxl_image.py +++ b/diffsynth_engine/pipelines/sdxl_image.py @@ -6,7 +6,7 @@ from tqdm import tqdm from PIL import Image, ImageOps -from diffsynth_engine.configs import SDXLPipelineConfig +from diffsynth_engine.configs import SDXLPipelineConfig, ControlNetParams from diffsynth_engine.models.base import split_suffix from diffsynth_engine.models.basic.lora import LoRAContext from diffsynth_engine.models.basic.timestep import TemporalTimesteps @@ -19,7 +19,7 @@ sdxl_unet_config, ) from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter -from diffsynth_engine.pipelines.controlnet_helper import ControlNetParams, accumulate +from diffsynth_engine.pipelines.utils import accumulate from diffsynth_engine.tokenizers import CLIPTokenizer from diffsynth_engine.algorithm.noise_scheduler import ScaledLinearScheduler from diffsynth_engine.algorithm.sampler import EulerSampler diff --git a/diffsynth_engine/pipelines/utils.py b/diffsynth_engine/pipelines/utils.py new file mode 100644 index 00000000..d3356fc6 --- /dev/null +++ b/diffsynth_engine/pipelines/utils.py @@ -0,0 +1,6 @@ +def accumulate(result, new_item): + if result is None: + return new_item + for i, item in enumerate(new_item): + result[i] += item + return result diff --git a/tests/test_pipelines/test_flux_bfl_image.py b/tests/test_pipelines/test_flux_bfl_image.py index 8a48ee4b..4fbf74dd 100644 --- a/tests/test_pipelines/test_flux_bfl_image.py +++ b/tests/test_pipelines/test_flux_bfl_image.py @@ -1,9 +1,8 @@ import unittest from tests.common.test_case import ImageTestCase -from diffsynth_engine.configs import FluxPipelineConfig +from diffsynth_engine.configs import FluxPipelineConfig, ControlType, ControlNetParams from diffsynth_engine.pipelines import FluxImagePipeline -from diffsynth_engine.pipelines.flux_image import ControlType, ControlNetParams from diffsynth_engine.processor.canny_processor import CannyProcessor from diffsynth_engine.processor.depth_processor import DepthProcessor