The following script works with the Qwen/Qwen-Image, Qwen/Qwen-Image-Edit, and black-forest-labs/FLUX.1-Krea-dev models; choose which one to load with the `--pipeline` option.
```python
import torch
import argparse
import os
import numpy as np
import datetime
import random
from diffusers import QwenImageEditPipeline, DiffusionPipeline, FluxPipeline
import gradio as gr
from optimum.quanto import freeze, qint8, quantize
# Command-line options: which model pipeline to load and how the Gradio
# server is exposed.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--pipeline",
    type=int,
    default=0,
    # Reject unsupported values at parse time; only 0, 1 and 2 have a
    # matching pipeline further down in this script.
    choices=[0, 1, 2],
    help="0 - QwenImageEditPipeline, 1 - DiffusionPipeline, 2 - FluxPipeline",
)
parser.add_argument(
    "--server_name",
    type=str,
    default="127.0.0.1",
    help="IP address, change to 0.0.0.0 for local network access.",
)
parser.add_argument("--server_port", type=int, default=7892, help="Port used.")
parser.add_argument("--share", action="store_true", help="Whether to enable gradio sharing.")
parser.add_argument("--mcp_server", action="store_true", help="Whether to enable mcp server.")
parser.add_argument('--vram', type=str, default='high', choices=['low', 'high'], help='Vram mode.')
# The literal string "None" (not the None object) means "no LoRA".
parser.add_argument('--lora', type=str, default="None", help='Path of lora model.')
args = parser.parse_args()
# Select the compute device and the precision used when loading the model.
if torch.cuda.is_available():
    device = "cuda"
    # GPUs with compute capability major version >= 8 get bfloat16;
    # older GPUs fall back to float16.
    major_capability = torch.cuda.get_device_capability()[0]
    dtype = torch.bfloat16 if major_capability >= 8 else torch.float16
else:
    # Without CUDA everything runs on the CPU in full precision.
    device = "cpu"
    dtype = torch.float32
# Generated images are written below this directory; create it up front.
os.makedirs("outputs", exist_ok=True)
# Upper bound (inclusive) for randomly drawn seeds: the int32 maximum.
MAX_SEED = np.iinfo(np.int32).max
# Supported pipelines, keyed by the --pipeline value. Each entry is
# (page title, tab label, Hugging Face model id, pipeline class).
_PIPELINE_SPECS = {
    0: (
        "Image Editing",
        "Qwen-Image-Edit Image Editing",
        "Qwen/Qwen-Image-Edit",
        QwenImageEditPipeline,
    ),
    1: (
        "Text-to-Image",
        "Qwen-Image Text-to-Image",
        "Qwen/Qwen-Image",
        DiffusionPipeline,
    ),
    2: (
        "Text-to-Image",
        "FLUX.1-Krea-dev Text-to-Image",
        "black-forest-labs/FLUX.1-Krea-dev",
        FluxPipeline,
    ),
}

if args.pipeline not in _PIPELINE_SPECS:
    # Fail fast with a clear message instead of hitting an undefined
    # title_text or a None pipe further down.
    raise ValueError(
        f"Unsupported --pipeline value {args.pipeline}; "
        f"expected one of {sorted(_PIPELINE_SPECS)}."
    )

title_text, pipe_text, model_id, _pipeline_cls = _PIPELINE_SPECS[args.pipeline]
print(f"Loading {model_id}")
pipe = _pipeline_cls.from_pretrained(model_id, torch_dtype=dtype)
if args.pipeline == 0:
    # Only the image-editing pipeline configures its progress bar here.
    pipe.set_progress_bar_config(disable=None)

# The string "None" is the CLI default and means no LoRA was requested.
if args.lora != "None":
    pipe.load_lora_weights(args.lora)
    print(f"Loaded {args.lora}")

if args.vram == "high":
    # High-VRAM mode: keep the whole pipeline on the selected device.
    pipe.vae.enable_tiling()
    pipe.to(device)
else:
    # Low-VRAM mode: quantize the transformer to int8, freeze it, and turn
    # on model CPU offload instead of moving the pipeline to the device.
    quantize(pipe.transformer, qint8)
    freeze(pipe.transformer)
    pipe.vae.enable_tiling()
    pipe.enable_model_cpu_offload()
def generate(
    prompt,
    negative_prompt,
    num_inference_steps,
    true_cfg_scale,
    seed_param,
    image=None,
    width=None,
    height=None,
):
    """Run the loaded pipeline once and save the resulting image.

    Args:
        prompt: Text prompt passed to the pipeline.
        negative_prompt: Negative prompt passed to the pipeline.
        num_inference_steps: Number of sampling steps; coerced to ``int``.
        true_cfg_scale: Value forwarded as the pipeline's ``true_cfg_scale``.
        seed_param: Seed to use. A negative value or ``None`` draws a random
            seed in ``[0, MAX_SEED]``.
        image: Input image; used only when ``args.pipeline == 0``.
        width: Output width; used only when ``args.pipeline != 0``.
        height: Output height; used only when ``args.pipeline != 0``.

    Returns:
        A ``(output_path, seed)`` tuple: the path of the saved PNG and the
        seed that was actually used.

    Raises:
        gr.Error: If the image-editing pipeline is active and no input image
            was supplied.
    """
    if args.pipeline == 0 and image is None:
        # Report the problem in the UI instead of failing inside the
        # pipeline with an unrelated-looking error.
        raise gr.Error("Please provide an input image to edit.")

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

    # UI number widgets may deliver floats (or None for an empty field);
    # torch's manual_seed needs an int.
    requested_seed = -1 if seed_param is None else int(seed_param)
    seed = random.randint(0, MAX_SEED) if requested_seed < 0 else requested_seed

    call_kwargs = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "num_inference_steps": int(num_inference_steps),
        "true_cfg_scale": true_cfg_scale,
        "generator": torch.Generator().manual_seed(seed),
    }
    if args.pipeline == 0:
        # Image editing: the input image is passed and no explicit size.
        call_kwargs["image"] = image
    else:
        # Text-to-image: pass the requested size, leaving None untouched so
        # the pipeline can apply its own default.
        call_kwargs["width"] = None if width is None else int(width)
        call_kwargs["height"] = None if height is None else int(height)

    result = pipe(**call_kwargs).images[0]

    output_path = f"outputs/{seed}_{timestamp}.png"
    result.save(output_path)
    return output_path, seed
# Gradio page: inputs in the left column, result and seed in the right one.
with gr.Blocks(theme=gr.themes.Base()) as demo:
    # The HTML below is kept flush-left so the Markdown text is unchanged.
    gr.Markdown(
        f"""
<div>
<h2 style="font-size: 30px;text-align: center;">{title_text}</h2>
</div>
"""
    )
    with gr.TabItem(pipe_text):
        with gr.Row():
            with gr.Column():
                # Shown only for the image-editing pipeline, where an input
                # image is needed, so the label no longer says "optional".
                image_input = gr.Image(
                    label="Input image",
                    type="pil",
                    visible=args.pipeline == 0,
                )
                prompt = gr.Textbox(
                    label="Prompt",
                    value="ultra clear, 4K, film-like composition, ",
                )
                negative_prompt = gr.Textbox(label="Negative prompt", value="")
                # Size sliders are shown only for the text-to-image pipelines.
                width = gr.Slider(
                    label="Width (recommends: 1328x1328, 1664x928, 1472x1140)",
                    minimum=256,
                    maximum=2656,
                    step=32,
                    value=1328,
                    visible=args.pipeline > 0,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=2656,
                    step=32,
                    value=1328,
                    visible=args.pipeline > 0,
                )
                num_inference_steps = gr.Slider(
                    label="Sampling steps",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=50,
                )
                true_cfg_scale = gr.Slider(
                    label="true cfg scale",
                    minimum=1,
                    maximum=10,
                    step=0.1,
                    value=4.0,
                )
                # precision=0 makes the field deliver an integer seed.
                seed_param = gr.Number(
                    label="Seed (please enter a positive integer, -1 for random)",
                    value=-1,
                    precision=0,
                )
                generate_button = gr.Button("🎬 Start generation", variant='primary')
            with gr.Column():
                image_output = gr.Image(label="Generated image")
                seed_output = gr.Textbox(label="Seed")
    # Clicking the button or pressing Enter in either prompt box runs
    # generate(); the input order matches its positional parameters.
    gr.on(
        triggers=[generate_button.click, prompt.submit, negative_prompt.submit],
        fn=generate,
        inputs=[
            prompt,
            negative_prompt,
            num_inference_steps,
            true_cfg_scale,
            seed_param,
            image_input,
            width,
            height,
        ],
        outputs=[image_output, seed_output],
    )
# Start the Gradio server only when the file is executed as a script.
if __name__ == "__main__":
    # Server settings come from the command line; the browser is always opened.
    launch_options = dict(
        server_name=args.server_name,
        server_port=args.server_port,
        share=args.share,
        mcp_server=args.mcp_server,
        inbrowser=True,
    )
    demo.launch(**launch_options)
```