diff --git a/diffsynth_engine/conf/models/flux/flux_dit.json b/diffsynth_engine/conf/models/flux/flux_dit.json index e6739364..b275e336 100644 --- a/diffsynth_engine/conf/models/flux/flux_dit.json +++ b/diffsynth_engine/conf/models/flux/flux_dit.json @@ -101,5 +101,24 @@ "proj_mlp": "proj_in_besides_attn", "proj_out": "proj_out" } - } + }, + "preferred_kontext_resolutions": [ + [672, 1568], + [688, 1504], + [720, 1456], + [752, 1392], + [800, 1328], + [832, 1248], + [880, 1184], + [944, 1104], + [1024, 1024], + [1104, 944], + [1184, 880], + [1248, 832], + [1328, 800], + [1392, 752], + [1456, 720], + [1504, 688], + [1568, 672] + ] } \ No newline at end of file diff --git a/diffsynth_engine/conf/models/flux/flux_vae.json b/diffsynth_engine/conf/models/flux/flux_vae.json index 10b7e503..f0afbebb 100644 --- a/diffsynth_engine/conf/models/flux/flux_vae.json +++ b/diffsynth_engine/conf/models/flux/flux_vae.json @@ -5,6 +5,8 @@ "decoder.conv_in.weight": "decoder.conv_in.weight", "decoder.conv_out.bias": "decoder.conv_out.bias", "decoder.conv_out.weight": "decoder.conv_out.weight", + "decoder.norm_out.bias": "decoder.conv_norm_out.bias", + "decoder.norm_out.weight": "decoder.conv_norm_out.weight", "decoder.mid.attn_1.k.bias": "decoder.blocks.1.transformer_blocks.0.to_k.bias", "decoder.mid.attn_1.k.weight": "decoder.blocks.1.transformer_blocks.0.to_k.weight", "decoder.mid.attn_1.norm.bias": "decoder.blocks.1.norm.bias", @@ -31,8 +33,6 @@ "decoder.mid.block_2.norm1.weight": "decoder.blocks.2.norm1.weight", "decoder.mid.block_2.norm2.bias": "decoder.blocks.2.norm2.bias", "decoder.mid.block_2.norm2.weight": "decoder.blocks.2.norm2.weight", - "decoder.norm_out.bias": "decoder.conv_norm_out.bias", - "decoder.norm_out.weight": "decoder.conv_norm_out.weight", "decoder.up.0.block.0.conv1.bias": "decoder.blocks.15.conv1.bias", "decoder.up.0.block.0.conv1.weight": "decoder.blocks.15.conv1.weight", "decoder.up.0.block.0.conv2.bias": "decoder.blocks.15.conv2.bias", @@ -143,6 +143,8 @@ "encoder.conv_in.weight": "encoder.conv_in.weight", "encoder.conv_out.bias": "encoder.conv_out.bias", "encoder.conv_out.weight": "encoder.conv_out.weight", + "encoder.norm_out.bias": "encoder.conv_norm_out.bias", + "encoder.norm_out.weight": "encoder.conv_norm_out.weight", "encoder.down.0.block.0.conv1.bias": "encoder.blocks.0.conv1.bias", "encoder.down.0.block.0.conv1.weight": "encoder.blocks.0.conv1.weight", "encoder.down.0.block.0.conv2.bias": "encoder.blocks.0.conv2.bias", @@ -242,9 +244,255 @@ "encoder.mid.block_2.norm1.bias": "encoder.blocks.13.norm1.bias", "encoder.mid.block_2.norm1.weight": "encoder.blocks.13.norm1.weight", "encoder.mid.block_2.norm2.bias": "encoder.blocks.13.norm2.bias", - "encoder.mid.block_2.norm2.weight": "encoder.blocks.13.norm2.weight", - "encoder.norm_out.bias": "encoder.conv_norm_out.bias", - "encoder.norm_out.weight": "encoder.conv_norm_out.weight" + "encoder.mid.block_2.norm2.weight": "encoder.blocks.13.norm2.weight" + } + }, + "diffusers": { + "rename_dict": { + "decoder.conv_in.bias": "decoder.conv_in.bias", + "decoder.conv_in.weight": "decoder.conv_in.weight", + "decoder.conv_out.bias": "decoder.conv_out.bias", + "decoder.conv_out.weight": "decoder.conv_out.weight", + "decoder.conv_norm_out.bias": "decoder.conv_norm_out.bias", + "decoder.conv_norm_out.weight": "decoder.conv_norm_out.weight", + "decoder.mid_block.attentions.0.to_k.bias": "decoder.blocks.1.transformer_blocks.0.to_k.bias", + "decoder.mid_block.attentions.0.to_k.weight": "decoder.blocks.1.transformer_blocks.0.to_k.weight", + "decoder.mid_block.attentions.0.group_norm.bias": "decoder.blocks.1.norm.bias", + "decoder.mid_block.attentions.0.group_norm.weight": "decoder.blocks.1.norm.weight", + "decoder.mid_block.attentions.0.to_out.0.bias": "decoder.blocks.1.transformer_blocks.0.to_out.bias", + "decoder.mid_block.attentions.0.to_out.0.weight": "decoder.blocks.1.transformer_blocks.0.to_out.weight", + "decoder.mid_block.attentions.0.to_q.bias": "decoder.blocks.1.transformer_blocks.0.to_q.bias", + "decoder.mid_block.attentions.0.to_q.weight": "decoder.blocks.1.transformer_blocks.0.to_q.weight", + "decoder.mid_block.attentions.0.to_v.bias": "decoder.blocks.1.transformer_blocks.0.to_v.bias", + "decoder.mid_block.attentions.0.to_v.weight": "decoder.blocks.1.transformer_blocks.0.to_v.weight", + "decoder.mid_block.resnets.0.conv1.bias": "decoder.blocks.0.conv1.bias", + "decoder.mid_block.resnets.0.conv1.weight": "decoder.blocks.0.conv1.weight", + "decoder.mid_block.resnets.0.conv2.bias": "decoder.blocks.0.conv2.bias", + "decoder.mid_block.resnets.0.conv2.weight": "decoder.blocks.0.conv2.weight", + "decoder.mid_block.resnets.0.norm1.bias": "decoder.blocks.0.norm1.bias", + "decoder.mid_block.resnets.0.norm1.weight": "decoder.blocks.0.norm1.weight", + "decoder.mid_block.resnets.0.norm2.bias": "decoder.blocks.0.norm2.bias", + "decoder.mid_block.resnets.0.norm2.weight": "decoder.blocks.0.norm2.weight", + "decoder.mid_block.resnets.1.conv1.bias": "decoder.blocks.2.conv1.bias", + "decoder.mid_block.resnets.1.conv1.weight": "decoder.blocks.2.conv1.weight", + "decoder.mid_block.resnets.1.conv2.bias": "decoder.blocks.2.conv2.bias", + "decoder.mid_block.resnets.1.conv2.weight": "decoder.blocks.2.conv2.weight", + "decoder.mid_block.resnets.1.norm1.bias": "decoder.blocks.2.norm1.bias", + "decoder.mid_block.resnets.1.norm1.weight": "decoder.blocks.2.norm1.weight", + "decoder.mid_block.resnets.1.norm2.bias": "decoder.blocks.2.norm2.bias", + "decoder.mid_block.resnets.1.norm2.weight": "decoder.blocks.2.norm2.weight", + "decoder.up_blocks.0.resnets.0.conv1.bias": "decoder.blocks.3.conv1.bias", + "decoder.up_blocks.0.resnets.0.conv1.weight": "decoder.blocks.3.conv1.weight", + "decoder.up_blocks.0.resnets.0.conv2.bias": "decoder.blocks.3.conv2.bias", + "decoder.up_blocks.0.resnets.0.conv2.weight": "decoder.blocks.3.conv2.weight", + "decoder.up_blocks.0.resnets.0.norm1.bias": "decoder.blocks.3.norm1.bias", + "decoder.up_blocks.0.resnets.0.norm1.weight": "decoder.blocks.3.norm1.weight", + "decoder.up_blocks.0.resnets.0.norm2.bias": "decoder.blocks.3.norm2.bias", + "decoder.up_blocks.0.resnets.0.norm2.weight": "decoder.blocks.3.norm2.weight", + "decoder.up_blocks.0.resnets.1.conv1.bias": "decoder.blocks.4.conv1.bias", + "decoder.up_blocks.0.resnets.1.conv1.weight": "decoder.blocks.4.conv1.weight", + "decoder.up_blocks.0.resnets.1.conv2.bias": "decoder.blocks.4.conv2.bias", + "decoder.up_blocks.0.resnets.1.conv2.weight": "decoder.blocks.4.conv2.weight", + "decoder.up_blocks.0.resnets.1.norm1.bias": "decoder.blocks.4.norm1.bias", + "decoder.up_blocks.0.resnets.1.norm1.weight": "decoder.blocks.4.norm1.weight", + "decoder.up_blocks.0.resnets.1.norm2.bias": "decoder.blocks.4.norm2.bias", + "decoder.up_blocks.0.resnets.1.norm2.weight": "decoder.blocks.4.norm2.weight", + "decoder.up_blocks.0.resnets.2.conv1.bias": "decoder.blocks.5.conv1.bias", + "decoder.up_blocks.0.resnets.2.conv1.weight": "decoder.blocks.5.conv1.weight", + "decoder.up_blocks.0.resnets.2.conv2.bias": "decoder.blocks.5.conv2.bias", + "decoder.up_blocks.0.resnets.2.conv2.weight": "decoder.blocks.5.conv2.weight", + "decoder.up_blocks.0.resnets.2.norm1.bias": "decoder.blocks.5.norm1.bias", + "decoder.up_blocks.0.resnets.2.norm1.weight": "decoder.blocks.5.norm1.weight", + "decoder.up_blocks.0.resnets.2.norm2.bias": "decoder.blocks.5.norm2.bias", + "decoder.up_blocks.0.resnets.2.norm2.weight": "decoder.blocks.5.norm2.weight", + "decoder.up_blocks.0.upsamplers.0.conv.bias": "decoder.blocks.6.conv.bias", + "decoder.up_blocks.0.upsamplers.0.conv.weight": "decoder.blocks.6.conv.weight", + "decoder.up_blocks.1.resnets.0.conv1.bias": "decoder.blocks.7.conv1.bias", + "decoder.up_blocks.1.resnets.0.conv1.weight": "decoder.blocks.7.conv1.weight", + "decoder.up_blocks.1.resnets.0.conv2.bias": "decoder.blocks.7.conv2.bias", + "decoder.up_blocks.1.resnets.0.conv2.weight": "decoder.blocks.7.conv2.weight", + "decoder.up_blocks.1.resnets.0.norm1.bias": "decoder.blocks.7.norm1.bias", + "decoder.up_blocks.1.resnets.0.norm1.weight": "decoder.blocks.7.norm1.weight", + "decoder.up_blocks.1.resnets.0.norm2.bias": "decoder.blocks.7.norm2.bias", + "decoder.up_blocks.1.resnets.0.norm2.weight": "decoder.blocks.7.norm2.weight", + "decoder.up_blocks.1.resnets.1.conv1.bias": "decoder.blocks.8.conv1.bias", + "decoder.up_blocks.1.resnets.1.conv1.weight": "decoder.blocks.8.conv1.weight", + "decoder.up_blocks.1.resnets.1.conv2.bias": "decoder.blocks.8.conv2.bias", + "decoder.up_blocks.1.resnets.1.conv2.weight": "decoder.blocks.8.conv2.weight", + "decoder.up_blocks.1.resnets.1.norm1.bias": "decoder.blocks.8.norm1.bias", + "decoder.up_blocks.1.resnets.1.norm1.weight": "decoder.blocks.8.norm1.weight", + "decoder.up_blocks.1.resnets.1.norm2.bias": "decoder.blocks.8.norm2.bias", + "decoder.up_blocks.1.resnets.1.norm2.weight": "decoder.blocks.8.norm2.weight", + "decoder.up_blocks.1.resnets.2.conv1.bias": "decoder.blocks.9.conv1.bias", + "decoder.up_blocks.1.resnets.2.conv1.weight": "decoder.blocks.9.conv1.weight", + "decoder.up_blocks.1.resnets.2.conv2.bias": "decoder.blocks.9.conv2.bias", + "decoder.up_blocks.1.resnets.2.conv2.weight": "decoder.blocks.9.conv2.weight", + "decoder.up_blocks.1.resnets.2.norm1.bias": "decoder.blocks.9.norm1.bias", + "decoder.up_blocks.1.resnets.2.norm1.weight": "decoder.blocks.9.norm1.weight", + "decoder.up_blocks.1.resnets.2.norm2.bias": "decoder.blocks.9.norm2.bias", + "decoder.up_blocks.1.resnets.2.norm2.weight": "decoder.blocks.9.norm2.weight", + "decoder.up_blocks.1.upsamplers.0.conv.bias": "decoder.blocks.10.conv.bias", + "decoder.up_blocks.1.upsamplers.0.conv.weight": "decoder.blocks.10.conv.weight", + "decoder.up_blocks.2.resnets.0.conv1.bias": "decoder.blocks.11.conv1.bias", + "decoder.up_blocks.2.resnets.0.conv1.weight": "decoder.blocks.11.conv1.weight", + "decoder.up_blocks.2.resnets.0.conv2.bias": "decoder.blocks.11.conv2.bias", + "decoder.up_blocks.2.resnets.0.conv2.weight": "decoder.blocks.11.conv2.weight", + "decoder.up_blocks.2.resnets.0.conv_shortcut.bias": "decoder.blocks.11.conv_shortcut.bias", + "decoder.up_blocks.2.resnets.0.conv_shortcut.weight": "decoder.blocks.11.conv_shortcut.weight", + "decoder.up_blocks.2.resnets.0.norm1.bias": "decoder.blocks.11.norm1.bias", + "decoder.up_blocks.2.resnets.0.norm1.weight": "decoder.blocks.11.norm1.weight", + "decoder.up_blocks.2.resnets.0.norm2.bias": "decoder.blocks.11.norm2.bias", + "decoder.up_blocks.2.resnets.0.norm2.weight": "decoder.blocks.11.norm2.weight", + "decoder.up_blocks.2.resnets.1.conv1.bias": "decoder.blocks.12.conv1.bias", + "decoder.up_blocks.2.resnets.1.conv1.weight": "decoder.blocks.12.conv1.weight", + "decoder.up_blocks.2.resnets.1.conv2.bias": "decoder.blocks.12.conv2.bias", + "decoder.up_blocks.2.resnets.1.conv2.weight": "decoder.blocks.12.conv2.weight", + "decoder.up_blocks.2.resnets.1.norm1.bias": "decoder.blocks.12.norm1.bias", + "decoder.up_blocks.2.resnets.1.norm1.weight": "decoder.blocks.12.norm1.weight", + "decoder.up_blocks.2.resnets.1.norm2.bias": "decoder.blocks.12.norm2.bias", + "decoder.up_blocks.2.resnets.1.norm2.weight": "decoder.blocks.12.norm2.weight", + "decoder.up_blocks.2.resnets.2.conv1.bias": "decoder.blocks.13.conv1.bias", + "decoder.up_blocks.2.resnets.2.conv1.weight": "decoder.blocks.13.conv1.weight", + "decoder.up_blocks.2.resnets.2.conv2.bias": "decoder.blocks.13.conv2.bias", + "decoder.up_blocks.2.resnets.2.conv2.weight": "decoder.blocks.13.conv2.weight", + "decoder.up_blocks.2.resnets.2.norm1.bias": "decoder.blocks.13.norm1.bias", + "decoder.up_blocks.2.resnets.2.norm1.weight": "decoder.blocks.13.norm1.weight", + "decoder.up_blocks.2.resnets.2.norm2.bias": "decoder.blocks.13.norm2.bias", + "decoder.up_blocks.2.resnets.2.norm2.weight": "decoder.blocks.13.norm2.weight", + "decoder.up_blocks.2.upsamplers.0.conv.bias": "decoder.blocks.14.conv.bias", + "decoder.up_blocks.2.upsamplers.0.conv.weight": "decoder.blocks.14.conv.weight", + "decoder.up_blocks.3.resnets.0.conv1.bias": "decoder.blocks.15.conv1.bias", + "decoder.up_blocks.3.resnets.0.conv1.weight": "decoder.blocks.15.conv1.weight", + "decoder.up_blocks.3.resnets.0.conv2.bias": "decoder.blocks.15.conv2.bias", + "decoder.up_blocks.3.resnets.0.conv2.weight": "decoder.blocks.15.conv2.weight", + "decoder.up_blocks.3.resnets.0.conv_shortcut.bias": "decoder.blocks.15.conv_shortcut.bias", + "decoder.up_blocks.3.resnets.0.conv_shortcut.weight": "decoder.blocks.15.conv_shortcut.weight", + "decoder.up_blocks.3.resnets.0.norm1.bias": "decoder.blocks.15.norm1.bias", + "decoder.up_blocks.3.resnets.0.norm1.weight": "decoder.blocks.15.norm1.weight", + "decoder.up_blocks.3.resnets.0.norm2.bias": "decoder.blocks.15.norm2.bias", + "decoder.up_blocks.3.resnets.0.norm2.weight": "decoder.blocks.15.norm2.weight", + "decoder.up_blocks.3.resnets.1.conv1.bias": "decoder.blocks.16.conv1.bias", + "decoder.up_blocks.3.resnets.1.conv1.weight": "decoder.blocks.16.conv1.weight", + "decoder.up_blocks.3.resnets.1.conv2.bias": "decoder.blocks.16.conv2.bias", + "decoder.up_blocks.3.resnets.1.conv2.weight": "decoder.blocks.16.conv2.weight", + "decoder.up_blocks.3.resnets.1.norm1.bias": "decoder.blocks.16.norm1.bias", + "decoder.up_blocks.3.resnets.1.norm1.weight": "decoder.blocks.16.norm1.weight", + "decoder.up_blocks.3.resnets.1.norm2.bias": "decoder.blocks.16.norm2.bias", + "decoder.up_blocks.3.resnets.1.norm2.weight": "decoder.blocks.16.norm2.weight", + "decoder.up_blocks.3.resnets.2.conv1.bias": "decoder.blocks.17.conv1.bias", + "decoder.up_blocks.3.resnets.2.conv1.weight": "decoder.blocks.17.conv1.weight", + "decoder.up_blocks.3.resnets.2.conv2.bias": "decoder.blocks.17.conv2.bias", + "decoder.up_blocks.3.resnets.2.conv2.weight": "decoder.blocks.17.conv2.weight", + "decoder.up_blocks.3.resnets.2.norm1.bias": "decoder.blocks.17.norm1.bias", + "decoder.up_blocks.3.resnets.2.norm1.weight": "decoder.blocks.17.norm1.weight", + "decoder.up_blocks.3.resnets.2.norm2.bias": "decoder.blocks.17.norm2.bias", + "decoder.up_blocks.3.resnets.2.norm2.weight": "decoder.blocks.17.norm2.weight", + "encoder.conv_in.bias": "encoder.conv_in.bias", + "encoder.conv_in.weight": "encoder.conv_in.weight", + "encoder.conv_out.bias": "encoder.conv_out.bias", + "encoder.conv_out.weight": "encoder.conv_out.weight", + "encoder.conv_norm_out.bias": "encoder.conv_norm_out.bias", + "encoder.conv_norm_out.weight": "encoder.conv_norm_out.weight", + "encoder.down_blocks.0.resnets.0.conv1.bias": "encoder.blocks.0.conv1.bias", + "encoder.down_blocks.0.resnets.0.conv1.weight": "encoder.blocks.0.conv1.weight", + "encoder.down_blocks.0.resnets.0.conv2.bias": "encoder.blocks.0.conv2.bias", + "encoder.down_blocks.0.resnets.0.conv2.weight": "encoder.blocks.0.conv2.weight", + "encoder.down_blocks.0.resnets.0.norm1.bias": "encoder.blocks.0.norm1.bias", + "encoder.down_blocks.0.resnets.0.norm1.weight": "encoder.blocks.0.norm1.weight", + "encoder.down_blocks.0.resnets.0.norm2.bias": "encoder.blocks.0.norm2.bias", + "encoder.down_blocks.0.resnets.0.norm2.weight": "encoder.blocks.0.norm2.weight", + "encoder.down_blocks.0.resnets.1.conv1.bias": "encoder.blocks.1.conv1.bias", + "encoder.down_blocks.0.resnets.1.conv1.weight": "encoder.blocks.1.conv1.weight", + "encoder.down_blocks.0.resnets.1.conv2.bias": "encoder.blocks.1.conv2.bias", + "encoder.down_blocks.0.resnets.1.conv2.weight": "encoder.blocks.1.conv2.weight", + "encoder.down_blocks.0.resnets.1.norm1.bias": "encoder.blocks.1.norm1.bias", + "encoder.down_blocks.0.resnets.1.norm1.weight": "encoder.blocks.1.norm1.weight", + "encoder.down_blocks.0.resnets.1.norm2.bias": "encoder.blocks.1.norm2.bias", + "encoder.down_blocks.0.resnets.1.norm2.weight": "encoder.blocks.1.norm2.weight", + "encoder.down_blocks.0.downsamplers.0.conv.bias": "encoder.blocks.2.conv.bias", + "encoder.down_blocks.0.downsamplers.0.conv.weight": "encoder.blocks.2.conv.weight", + "encoder.down_blocks.1.resnets.0.conv1.bias": "encoder.blocks.3.conv1.bias", + "encoder.down_blocks.1.resnets.0.conv1.weight": "encoder.blocks.3.conv1.weight", + "encoder.down_blocks.1.resnets.0.conv2.bias": "encoder.blocks.3.conv2.bias", + "encoder.down_blocks.1.resnets.0.conv2.weight": "encoder.blocks.3.conv2.weight", + "encoder.down_blocks.1.resnets.0.conv_shortcut.bias": "encoder.blocks.3.conv_shortcut.bias", + "encoder.down_blocks.1.resnets.0.conv_shortcut.weight": "encoder.blocks.3.conv_shortcut.weight", + "encoder.down_blocks.1.resnets.0.norm1.bias": "encoder.blocks.3.norm1.bias", + "encoder.down_blocks.1.resnets.0.norm1.weight": "encoder.blocks.3.norm1.weight", + "encoder.down_blocks.1.resnets.0.norm2.bias": "encoder.blocks.3.norm2.bias", + "encoder.down_blocks.1.resnets.0.norm2.weight": "encoder.blocks.3.norm2.weight", + "encoder.down_blocks.1.resnets.1.conv1.bias": "encoder.blocks.4.conv1.bias", + "encoder.down_blocks.1.resnets.1.conv1.weight": "encoder.blocks.4.conv1.weight", + "encoder.down_blocks.1.resnets.1.conv2.bias": "encoder.blocks.4.conv2.bias", + "encoder.down_blocks.1.resnets.1.conv2.weight": "encoder.blocks.4.conv2.weight", + "encoder.down_blocks.1.resnets.1.norm1.bias": "encoder.blocks.4.norm1.bias", + "encoder.down_blocks.1.resnets.1.norm1.weight": "encoder.blocks.4.norm1.weight", + "encoder.down_blocks.1.resnets.1.norm2.bias": "encoder.blocks.4.norm2.bias", + "encoder.down_blocks.1.resnets.1.norm2.weight": "encoder.blocks.4.norm2.weight", + "encoder.down_blocks.1.downsamplers.0.conv.bias": "encoder.blocks.5.conv.bias", + "encoder.down_blocks.1.downsamplers.0.conv.weight": "encoder.blocks.5.conv.weight", + "encoder.down_blocks.2.resnets.0.conv1.bias": "encoder.blocks.6.conv1.bias", + "encoder.down_blocks.2.resnets.0.conv1.weight": "encoder.blocks.6.conv1.weight", + "encoder.down_blocks.2.resnets.0.conv2.bias": "encoder.blocks.6.conv2.bias", + "encoder.down_blocks.2.resnets.0.conv2.weight": "encoder.blocks.6.conv2.weight", + "encoder.down_blocks.2.resnets.0.conv_shortcut.bias": "encoder.blocks.6.conv_shortcut.bias", + "encoder.down_blocks.2.resnets.0.conv_shortcut.weight": "encoder.blocks.6.conv_shortcut.weight", + "encoder.down_blocks.2.resnets.0.norm1.bias": "encoder.blocks.6.norm1.bias", + "encoder.down_blocks.2.resnets.0.norm1.weight": "encoder.blocks.6.norm1.weight", + "encoder.down_blocks.2.resnets.0.norm2.bias": "encoder.blocks.6.norm2.bias", + "encoder.down_blocks.2.resnets.0.norm2.weight": "encoder.blocks.6.norm2.weight", + "encoder.down_blocks.2.resnets.1.conv1.bias": "encoder.blocks.7.conv1.bias", + "encoder.down_blocks.2.resnets.1.conv1.weight": "encoder.blocks.7.conv1.weight", + "encoder.down_blocks.2.resnets.1.conv2.bias": "encoder.blocks.7.conv2.bias", + "encoder.down_blocks.2.resnets.1.conv2.weight": "encoder.blocks.7.conv2.weight", + "encoder.down_blocks.2.resnets.1.norm1.bias": "encoder.blocks.7.norm1.bias", + "encoder.down_blocks.2.resnets.1.norm1.weight": "encoder.blocks.7.norm1.weight", + "encoder.down_blocks.2.resnets.1.norm2.bias": "encoder.blocks.7.norm2.bias", + "encoder.down_blocks.2.resnets.1.norm2.weight": "encoder.blocks.7.norm2.weight", + "encoder.down_blocks.2.downsamplers.0.conv.bias": "encoder.blocks.8.conv.bias", + "encoder.down_blocks.2.downsamplers.0.conv.weight": "encoder.blocks.8.conv.weight", + "encoder.down_blocks.3.resnets.0.conv1.bias": "encoder.blocks.9.conv1.bias", + "encoder.down_blocks.3.resnets.0.conv1.weight": "encoder.blocks.9.conv1.weight", + "encoder.down_blocks.3.resnets.0.conv2.bias": "encoder.blocks.9.conv2.bias", + "encoder.down_blocks.3.resnets.0.conv2.weight": "encoder.blocks.9.conv2.weight", + "encoder.down_blocks.3.resnets.0.norm1.bias": "encoder.blocks.9.norm1.bias", + "encoder.down_blocks.3.resnets.0.norm1.weight": "encoder.blocks.9.norm1.weight", + "encoder.down_blocks.3.resnets.0.norm2.bias": "encoder.blocks.9.norm2.bias", + "encoder.down_blocks.3.resnets.0.norm2.weight": "encoder.blocks.9.norm2.weight", + "encoder.down_blocks.3.resnets.1.conv1.bias": "encoder.blocks.10.conv1.bias", + "encoder.down_blocks.3.resnets.1.conv1.weight": "encoder.blocks.10.conv1.weight", + "encoder.down_blocks.3.resnets.1.conv2.bias": "encoder.blocks.10.conv2.bias", + "encoder.down_blocks.3.resnets.1.conv2.weight": "encoder.blocks.10.conv2.weight", + "encoder.down_blocks.3.resnets.1.norm1.bias": "encoder.blocks.10.norm1.bias", + "encoder.down_blocks.3.resnets.1.norm1.weight": "encoder.blocks.10.norm1.weight", + "encoder.down_blocks.3.resnets.1.norm2.bias": "encoder.blocks.10.norm2.bias", + "encoder.down_blocks.3.resnets.1.norm2.weight": "encoder.blocks.10.norm2.weight", + "encoder.mid_block.attentions.0.to_k.bias": "encoder.blocks.12.transformer_blocks.0.to_k.bias", + "encoder.mid_block.attentions.0.to_k.weight": "encoder.blocks.12.transformer_blocks.0.to_k.weight", + "encoder.mid_block.attentions.0.group_norm.bias": "encoder.blocks.12.norm.bias", + "encoder.mid_block.attentions.0.group_norm.weight": "encoder.blocks.12.norm.weight", + "encoder.mid_block.attentions.0.to_out.0.bias": "encoder.blocks.12.transformer_blocks.0.to_out.bias", + "encoder.mid_block.attentions.0.to_out.0.weight": "encoder.blocks.12.transformer_blocks.0.to_out.weight", + "encoder.mid_block.attentions.0.to_q.bias": "encoder.blocks.12.transformer_blocks.0.to_q.bias", + "encoder.mid_block.attentions.0.to_q.weight": "encoder.blocks.12.transformer_blocks.0.to_q.weight", + "encoder.mid_block.attentions.0.to_v.bias": "encoder.blocks.12.transformer_blocks.0.to_v.bias", + "encoder.mid_block.attentions.0.to_v.weight": "encoder.blocks.12.transformer_blocks.0.to_v.weight", + "encoder.mid_block.resnets.0.conv1.bias": "encoder.blocks.11.conv1.bias", + "encoder.mid_block.resnets.0.conv1.weight": "encoder.blocks.11.conv1.weight", + "encoder.mid_block.resnets.0.conv2.bias": "encoder.blocks.11.conv2.bias", + "encoder.mid_block.resnets.0.conv2.weight": "encoder.blocks.11.conv2.weight", + "encoder.mid_block.resnets.0.norm1.bias": "encoder.blocks.11.norm1.bias", + "encoder.mid_block.resnets.0.norm1.weight": "encoder.blocks.11.norm1.weight", + "encoder.mid_block.resnets.0.norm2.bias": "encoder.blocks.11.norm2.bias", + "encoder.mid_block.resnets.0.norm2.weight": "encoder.blocks.11.norm2.weight", + "encoder.mid_block.resnets.1.conv1.bias": "encoder.blocks.13.conv1.bias", + "encoder.mid_block.resnets.1.conv1.weight": "encoder.blocks.13.conv1.weight", + "encoder.mid_block.resnets.1.conv2.bias": "encoder.blocks.13.conv2.bias", + "encoder.mid_block.resnets.1.conv2.weight": "encoder.blocks.13.conv2.weight", + "encoder.mid_block.resnets.1.norm1.bias": "encoder.blocks.13.norm1.bias", + "encoder.mid_block.resnets.1.norm1.weight": "encoder.blocks.13.norm1.weight", + "encoder.mid_block.resnets.1.norm2.bias": "encoder.blocks.13.norm2.bias", + "encoder.mid_block.resnets.1.norm2.weight": "encoder.blocks.13.norm2.weight" } } } \ No newline at end of file diff --git a/diffsynth_engine/models/flux/flux_controlnet.py b/diffsynth_engine/models/flux/flux_controlnet.py index 3a1d33aa..889c28ec 100644 --- a/diffsynth_engine/models/flux/flux_controlnet.py +++ b/diffsynth_engine/models/flux/flux_controlnet.py @@ -119,18 +119,16 @@ def patchify(self, hidden_states): def forward( self, - hidden_states, - control_condition, - control_scale, - timestep, - prompt_emb, - pooled_prompt_emb, - guidance, - image_ids, - text_ids, + hidden_states: torch.Tensor, + control_condition: torch.Tensor, + control_scale: float, + timestep: torch.Tensor, + prompt_emb: torch.Tensor, + pooled_prompt_emb: torch.Tensor, + image_ids: torch.Tensor, + text_ids: torch.Tensor, + guidance: torch.Tensor, ): - hidden_states = self.patchify(hidden_states) - control_condition = self.patchify(control_condition) hidden_states = self.x_embedder(hidden_states) + self.controlnet_x_embedder(control_condition) condition = ( self.time_embedder(timestep, hidden_states.dtype) diff --git a/diffsynth_engine/models/flux/flux_dit.py b/diffsynth_engine/models/flux/flux_dit.py index 6cc872a9..65333e02 100644 --- a/diffsynth_engine/models/flux/flux_dit.py +++ b/diffsynth_engine/models/flux/flux_dit.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn import numpy as np -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from einops import rearrange from diffsynth_engine.models.basic.transformer_helper import ( @@ -245,7 +245,7 @@ def __init__( self.ff_a = nn.Sequential( nn.Linear(dim, dim * 4, device=device, dtype=dtype), nn.GELU(approximate="tanh"), - nn.Linear(dim * 4, dim, device=device, dtype=dtype) + nn.Linear(dim * 4, dim, device=device, dtype=dtype), ) # Text self.norm_msa_b = AdaLayerNormZero(dim, device=device, dtype=dtype) @@ -395,21 +395,19 @@ def prepare_image_ids(latents: torch.Tensor): def forward( self, - hidden_states, - timestep, - prompt_emb, - pooled_prompt_emb, - image_emb, - guidance, - text_ids, - image_ids=None, - controlnet_double_block_output=None, - controlnet_single_block_output=None, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + prompt_emb: torch.Tensor, + pooled_prompt_emb: torch.Tensor, + image_ids: torch.Tensor, + text_ids: torch.Tensor, + guidance: torch.Tensor, + image_emb: torch.Tensor | None = None, + controlnet_double_block_output: List[torch.Tensor] | None = None, + controlnet_single_block_output: List[torch.Tensor] | None = None, **kwargs, ): - h, w = hidden_states.shape[-2:] - if image_ids is None: - image_ids = self.prepare_image_ids(hidden_states) + image_seq_len = hidden_states.shape[1] controlnet_double_block_output = ( controlnet_double_block_output if controlnet_double_block_output is not None else () ) @@ -428,10 +426,10 @@ def forward( timestep, prompt_emb, pooled_prompt_emb, - image_emb, - guidance, - text_ids, image_ids, + text_ids, + guidance, + image_emb, *controlnet_double_block_output, *controlnet_single_block_output, ), @@ -448,7 +446,6 @@ def forward( rope_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) text_rope_emb = rope_emb[:, :, : text_ids.size(1)] image_rope_emb = rope_emb[:, :, text_ids.size(1) :] - hidden_states = self.patchify(hidden_states) with sequence_parallel( ( @@ -489,9 +486,8 @@ def forward( hidden_states = hidden_states[:, prompt_emb.shape[1] :] hidden_states = self.final_norm_out(hidden_states, conditioning) hidden_states = self.final_proj_out(hidden_states) - (hidden_states,) = sequence_parallel_unshard((hidden_states,), seq_dims=(1,), seq_lens=(h * w // 4,)) + (hidden_states,) = sequence_parallel_unshard((hidden_states,), seq_dims=(1,), seq_lens=(image_seq_len,)) - hidden_states = self.unpatchify(hidden_states, h, w) (hidden_states,) = cfg_parallel_unshard((hidden_states,), use_cfg=use_cfg) return hidden_states diff --git a/diffsynth_engine/models/flux/flux_dit_fbcache.py b/diffsynth_engine/models/flux/flux_dit_fbcache.py index 85dc8a15..1b7d59d1 100644 --- a/diffsynth_engine/models/flux/flux_dit_fbcache.py +++ b/diffsynth_engine/models/flux/flux_dit_fbcache.py @@ -1,6 +1,6 @@ import torch import numpy as np -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from diffsynth_engine.utils.gguf import gguf_inference from diffsynth_engine.utils.fp8_linear import fp8_inference @@ -48,21 +48,19 @@ def refresh_cache_status(self, num_inference_steps): def forward( self, - hidden_states, - timestep, - prompt_emb, - pooled_prompt_emb, - image_emb, - guidance, - text_ids, - image_ids=None, - controlnet_double_block_output=None, - controlnet_single_block_output=None, + hidden_states: torch.Tensor, + timestep: torch.Tensor, + prompt_emb: torch.Tensor, + pooled_prompt_emb: torch.Tensor, + image_ids: torch.Tensor, + text_ids: torch.Tensor, + guidance: torch.Tensor, + image_emb: torch.Tensor | None = None, + controlnet_double_block_output: List[torch.Tensor] | None = None, + controlnet_single_block_output: List[torch.Tensor] | None = None, **kwargs, ): - h, w = hidden_states.shape[-2:] - if image_ids is None: - image_ids = self.prepare_image_ids(hidden_states) + image_seq_len = hidden_states.shape[1] controlnet_double_block_output = ( controlnet_double_block_output if controlnet_double_block_output is not None else () ) @@ -81,10 +79,10 @@ def forward( timestep, prompt_emb, pooled_prompt_emb, - image_emb, - guidance, - text_ids, image_ids, + text_ids, + guidance, + image_emb, *controlnet_double_block_output, *controlnet_single_block_output, ), @@ -101,7 +99,6 @@ def forward( rope_emb = self.pos_embedder(torch.cat((text_ids, image_ids), dim=1)) text_rope_emb = rope_emb[:, :, : text_ids.size(1)] image_rope_emb = rope_emb[:, :, text_ids.size(1) :] - hidden_states = self.patchify(hidden_states) with sequence_parallel( ( @@ -131,7 +128,7 @@ def forward( first_hidden_states_residual = hidden_states - original_hidden_states (first_hidden_states_residual,) = sequence_parallel_unshard( - (first_hidden_states_residual,), seq_dims=(1,), seq_lens=(h * w // 4,) + (first_hidden_states_residual,), seq_dims=(1,), seq_lens=(image_seq_len,) ) if self.step_count == 0 or self.step_count == (self.num_inference_steps - 1): @@ -172,9 +169,8 @@ def forward( hidden_states = self.final_norm_out(hidden_states, conditioning) hidden_states = self.final_proj_out(hidden_states) - (hidden_states,) = sequence_parallel_unshard((hidden_states,), seq_dims=(1,), seq_lens=(h * w // 4,)) + (hidden_states,) = sequence_parallel_unshard((hidden_states,), seq_dims=(1,), seq_lens=(image_seq_len,)) - hidden_states = self.unpatchify(hidden_states, h, w) (hidden_states,) = cfg_parallel_unshard((hidden_states,), use_cfg=use_cfg) return hidden_states diff --git a/diffsynth_engine/models/flux/flux_vae.py b/diffsynth_engine/models/flux/flux_vae.py index 9af50bc3..fea42711 100644 --- a/diffsynth_engine/models/flux/flux_vae.py +++ b/diffsynth_engine/models/flux/flux_vae.py @@ -25,11 +25,29 @@ def _from_civitai(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch. new_state_dict[name_] = param return new_state_dict + def _from_diffusers(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + rename_dict = config["diffusers"]["rename_dict"] + new_state_dict = {} + for name, param in state_dict.items(): + if name not in rename_dict: + continue + name_ = rename_dict[name] + if "transformer_blocks" in name_: + param = param.squeeze() + new_state_dict[name_] = param + return new_state_dict + def convert(self, state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: assert self.has_decoder or self.has_encoder, "Either decoder or encoder must be present" - if "decoder.conv_in.weight" in state_dict or "encoder.conv_in.weight" in state_dict: + if "decoder.up.0.block.0.conv1.weight" in state_dict or "encoder.down.0.block.0.conv1.weight" in state_dict: state_dict = self._from_civitai(state_dict) logger.info("use civitai format state dict") + elif ( + "decoder.up_blocks.0.resnets.0.conv1.weight" in state_dict + or "encoder.down_blocks.0.resnets.0.conv1.weight" in state_dict + ): + state_dict = self._from_diffusers(state_dict) + logger.info("use diffusers format state dict") else: logger.info("use diffsynth format state dict") return self._filter(state_dict) diff --git a/diffsynth_engine/pipelines/flux_image.py b/diffsynth_engine/pipelines/flux_image.py index e0460104..8901920c 100644 --- a/diffsynth_engine/pipelines/flux_image.py +++ b/diffsynth_engine/pipelines/flux_image.py @@ -37,6 +37,8 @@ with open(FLUX_DIT_CONFIG_FILE, "r") as f: config = json.load(f) +PREFERRED_KONTEXT_RESOLUTIONS = config["preferred_kontext_resolutions"] + class FluxLoRAConverter(LoRAStateDictConverter): def _from_kohya(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]: @@ -612,7 +614,7 @@ def encode_prompt(self, prompt, clip_skip: int = 2): return prompt_emb, add_text_embeds def prepare_extra_input(self, latents, positive_prompt_emb, guidance=1.0): - image_ids = FluxDiT.prepare_image_ids(latents) + image_ids = self.dit.prepare_image_ids(latents) guidance = torch.tensor([guidance] * latents.shape[0], device=latents.device, dtype=latents.dtype) text_ids = torch.zeros(positive_prompt_emb.shape[0], positive_prompt_emb.shape[1], 3).to( device=self.device, dtype=positive_prompt_emb.dtype @@ -639,45 +641,45 @@ def predict_noise_with_cfg( ): if cfg_scale <= 1.0: return self.predict_noise( - latents, - timestep, - positive_prompt_emb, - positive_add_text_embeds, - image_emb, - image_ids, - text_ids, - guidance, - controlnet_params, - current_step, - total_step, + latents=latents, + timestep=timestep, + prompt_emb=positive_prompt_emb, + add_text_embeds=positive_add_text_embeds, + image_emb=image_emb, + image_ids=image_ids, + text_ids=text_ids, + guidance=guidance, + controlnet_params=controlnet_params, + current_step=current_step, + total_step=total_step, ) if not batch_cfg: # cfg by predict noise one by one positive_noise_pred = self.predict_noise( - latents, - timestep, - positive_prompt_emb, - positive_add_text_embeds, - image_emb, - image_ids, - text_ids, - guidance, - controlnet_params, - current_step, - total_step, + latents=latents, + timestep=timestep, + prompt_emb=positive_prompt_emb, + add_text_embeds=positive_add_text_embeds, + image_emb=image_emb, + image_ids=image_ids, + text_ids=text_ids, + guidance=guidance, + controlnet_params=controlnet_params, + current_step=current_step, + total_step=total_step, ) negative_noise_pred = self.predict_noise( - latents, - timestep, - negative_prompt_emb, - negative_add_text_embeds, - image_emb, - image_ids, - text_ids, - guidance, - controlnet_params, - current_step, - total_step, + latents=latents, + timestep=timestep, + prompt_emb=negative_prompt_emb, + add_text_embeds=negative_add_text_embeds, + image_emb=image_emb, + image_ids=image_ids, + text_ids=text_ids, + guidance=guidance, + controlnet_params=controlnet_params, + current_step=current_step, + total_step=total_step, ) noise_pred = negative_noise_pred + cfg_scale * (positive_noise_pred - negative_noise_pred) return noise_pred @@ -692,17 +694,17 @@ def predict_noise_with_cfg( text_ids = torch.cat([text_ids, text_ids], dim=0) guidance = torch.cat([guidance, guidance], dim=0) positive_noise_pred, negative_noise_pred = self.predict_noise( - latents, - timestep, - prompt_emb, - add_text_embeds, - image_emb, - image_ids, - text_ids, - guidance, - controlnet_params, - current_step, - total_step, + latents=latents, + timestep=timestep, + prompt_emb=prompt_emb, + add_text_embeds=add_text_embeds, + image_emb=image_emb, + image_ids=image_ids, + text_ids=text_ids, + guidance=guidance, + controlnet_params=controlnet_params, + current_step=current_step, + total_step=total_step, ) noise_pred = negative_noise_pred + cfg_scale * (positive_noise_pred - negative_noise_pred) return noise_pred @@ -721,30 +723,39 @@ def predict_noise( current_step: int, total_step: int, ): - origin_latents_shape = latents.shape - if self.config.control_type != ControlType.normal: + height, width = latents.shape[2:] + latents = self.dit.patchify(latents) + image_seq_len = latents.shape[1] + + double_block_output, single_block_output = None, None + if self.config.control_type == ControlType.normal: + double_block_output, single_block_output = self.predict_multicontrolnet( + latents=latents, + timestep=timestep, + prompt_emb=prompt_emb, + add_text_embeds=add_text_embeds, + guidance=guidance, + text_ids=text_ids, + image_ids=image_ids, + controlnet_params=controlnet_params, + current_step=current_step, + total_step=total_step, + ) + elif self.config.control_type == ControlType.bfl_kontext: + for idx, controlnet_param in enumerate(controlnet_params): + control_latents = controlnet_param.image * controlnet_param.scale + control_image_ids = self.dit.prepare_image_ids(control_latents) + control_image_ids[..., 0] = idx + 1 + control_latents = self.dit.patchify(control_latents) + latents = torch.cat((latents, control_latents), dim=1) + image_ids = torch.cat((image_ids, control_image_ids), dim=1) + else: controlnet_param = controlnet_params[0] - if self.config.control_type == ControlType.bfl_kontext: - latents = torch.cat((latents, controlnet_param.image * controlnet_param.scale), dim=2) - image_ids = image_ids.repeat(1, 2, 1) - image_ids[:, image_ids.shape[1] // 2 :, 0] += 1 - else: - latents = torch.cat((latents, controlnet_param.image * controlnet_param.scale), dim=1) - latents = latents.to(self.dtype) - controlnet_params = [] + control_latents = controlnet_param.image * controlnet_param.scale + control_latents = self.dit.patchify(control_latents) + latents = torch.cat((latents, control_latents), dim=2) - double_block_output, single_block_output = self.predict_multicontrolnet( - latents=latents, - timestep=timestep, - prompt_emb=prompt_emb, - add_text_embeds=add_text_embeds, - guidance=guidance, - text_ids=text_ids, - image_ids=image_ids, - controlnet_params=controlnet_params, - current_step=current_step, - total_step=total_step, - ) + latents = latents.to(self.dtype) self.load_models_to_device(["dit"]) noise_pred = self.dit( @@ -759,8 +770,8 @@ def predict_noise( controlnet_double_block_output=double_block_output, controlnet_single_block_output=single_block_output, ) - if self.config.control_type == ControlType.bfl_kontext: - noise_pred = noise_pred[:, :, : origin_latents_shape[2], : origin_latents_shape[3]] + noise_pred = noise_pred[:, :image_seq_len] + noise_pred = self.dit.unpatchify(noise_pred, height, width) return noise_pred def prepare_latents( @@ -782,7 +793,7 @@ def prepare_latents( sigma_start, sigmas = sigmas[t_start - 1], sigmas[t_start - 1 :] timesteps = timesteps[t_start - 1 :] noise = latents - image = self.preprocess_image(input_image).to(device=self.device, dtype=self.dtype) + image = self.preprocess_image(input_image).to(device=self.device) latents = self.encode_image(image) init_latents = latents.clone() latents = self.sampler.add_noise(latents, noise, sigma_start) @@ -804,15 +815,21 @@ def prepare_latents( def prepare_masked_latent(self, image: Image.Image, mask: Image.Image | None, height: int, width: int): self.load_models_to_device(["vae_encoder"]) if mask is None: + if self.config.control_type == ControlType.bfl_kontext: + width, height = image.size + aspect_ratio = width / height + # Kontext is trained on specific resolutions, using one of them is recommended + _, width, height = min((abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS) + width, height = 16 * (width // 16), 16 * (height // 16) image = image.resize((width, height)) - image = self.preprocess_image(image).to(device=self.device, dtype=self.dtype) + image = self.preprocess_image(image).to(device=self.device) latent = self.encode_image(image) else: if self.config.control_type == ControlType.normal: image = image.resize((width, height)) mask = mask.resize((width, height)) - image = self.preprocess_image(image).to(device=self.device, dtype=self.dtype) - mask = self.preprocess_mask(mask).to(device=self.device, dtype=self.dtype) + image = self.preprocess_image(image).to(device=self.device) + mask = self.preprocess_mask(mask).to(device=self.device) masked_image = image.clone() masked_image[(mask > 0.5).repeat(1, 3, 1, 1)] = -1 latent = self.encode_image(masked_image) @@ -822,8 +839,8 @@ def prepare_masked_latent(self, image: Image.Image, mask: Image.Image | None, he elif self.config.control_type == ControlType.bfl_fill: image = image.resize((width, height)) mask = mask.resize((width, height)) - image = self.preprocess_image(image).to(device=self.device, dtype=self.dtype) - mask = self.preprocess_mask(mask).to(device=self.device, dtype=self.dtype) + image = self.preprocess_image(image).to(device=self.device) + mask = self.preprocess_mask(mask).to(device=self.device) image = image * (1 - mask) image = self.encode_image(image) mask = rearrange(mask, "b 1 (h ph) (w pw) -> b (ph pw) h w", ph=8, pw=8) @@ -862,6 +879,7 @@ def predict_multicontrolnet( if len(controlnet_params) > 0: self.load_models_to_device([]) for param in controlnet_params: + control_condition = param.model.patchify(param.image) current_scale = param.scale if not ( current_step >= param.control_start * total_step and current_step <= param.control_end * total_step @@ -873,15 +891,15 @@ def predict_multicontrolnet( empty_cache() param.model.to(self.device) double_block_output, single_block_output = param.model( - latents, - param.image, - current_scale, - timestep, - prompt_emb, - add_text_embeds, - guidance, - image_ids, - text_ids, + hidden_states=latents, + control_condition=control_condition, + control_scale=current_scale, + timestep=timestep, + prompt_emb=prompt_emb, + pooled_prompt_emb=add_text_embeds, + image_ids=image_ids, + text_ids=text_ids, + guidance=guidance, ) if self.offload_mode is not None: param.model.to("cpu") @@ -927,8 +945,10 @@ def __call__( self.dit.refresh_cache_status(num_inference_steps) if not isinstance(controlnet_params, list): controlnet_params = [controlnet_params] - if self.config.control_type != ControlType.normal: - assert controlnet_params and len(controlnet_params) == 1, "bfl_controlnet must have one controlnet" + if self.config.control_type in [ControlType.bfl_control, ControlType.bfl_fill]: + assert controlnet_params and len(controlnet_params) == 1, ( + "bfl_controlnet or bfl_fill must have one controlnet" + ) if input_image is not None: width, height = input_image.size