diff --git a/diffsynth_engine/pipelines/flux_image.py b/diffsynth_engine/pipelines/flux_image.py index e7b2b409..15a25ee6 100644 --- a/diffsynth_engine/pipelines/flux_image.py +++ b/diffsynth_engine/pipelines/flux_image.py @@ -983,8 +983,9 @@ def __call__( elif self.ip_adapter is not None: image_emb = self.ip_adapter.encode_image(ref_image) elif self.redux is not None: - image_prompt_embeds = self.redux(ref_image) - positive_prompt_emb = torch.cat([positive_prompt_emb, image_prompt_embeds], dim=1) + ref_prompt_embeds = self.redux(ref_image) + flattened_ref_emb = ref_prompt_embeds.view(1, -1, ref_prompt_embeds.size(-1)) + positive_prompt_emb = torch.cat([positive_prompt_emb, flattened_ref_emb], dim=1) # Extra input image_ids, text_ids, guidance = self.prepare_extra_input(