import torch
from diffusers import AnimateDiffPipeline, DDIMScheduler, MotionAdapter
from diffusers.utils import export_to_gif
import torch.distributed as dist
import torch.multiprocessing as mp
import time
import os
from diffusers.utils import load_image, export_to_video



# Maps every supported config["strategy"] value to the module that exports the
# matching ParallelDiffusionWorker implementation.
_STRATEGY_MODULES = {
    "n3s2": "animatediff.MoTiAsync_4",
    "n2s1": "animatediff.TimeAsync_2",
    "n3s1": "animatediff.ModelAsync_3",
    "n2s2": "animatediff.TimeAsync_3",
}


def _seed_everything(seed):
    """Seed the CPU RNG and the RNG of every CUDA device with *seed*."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def _generate_frames(pipeline, num_inference_steps):
    """Run the fixed benchmark prompt once and return the first result's frames."""
    return pipeline(
        prompt="Brilliant fireworks, high quality",
        negative_prompt="bad quality, worse quality, low resolution",
        num_frames=16,
        guidance_scale=7.5,
        num_inference_steps=num_inference_steps,
    ).frames[0]


def run_inference(rank, world_size, config):
    """Run one warm-up and one timed generation on this rank, then save a GIF.

    Args:
        rank: Index of this process within the process group; also used to
            pick this process's entry from ``config["devices"]``.
        world_size: Total number of processes in the process group.
        config: Settings dict. This function reads the keys ``"strategy"``,
            ``"devices"``, ``"seed"``, ``"step"`` and ``"warm_up"``; the whole
            dict is also handed to the worker constructor.

    Raises:
        ValueError: If ``config["strategy"]`` is not a supported strategy.
    """
    print(f"Rank {rank} is running.")

    # Reject an unsupported strategy before any distributed state is created;
    # otherwise the worker class would simply be undefined further down.
    strategy = config["strategy"]
    if strategy not in _STRATEGY_MODULES:
        raise ValueError(
            f"Unknown strategy {strategy!r}; "
            f"expected one of {sorted(_STRATEGY_MODULES)}"
        )

    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29501'

    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    try:
        # Imported lazily so only the module of the selected strategy is loaded.
        import importlib

        worker_cls = importlib.import_module(
            _STRATEGY_MODULES[strategy]
        ).ParallelDiffusionWorker

        pipeline = worker_cls(
            config,
        )
        torch.cuda.set_device(
            config["devices"][rank]
        )  # this is necessary to make sure the correct device is set

        # Warm-up pass: the result is discarded and the run is not timed.
        _seed_everything(config["seed"])
        pipeline.reset_state()
        _generate_frames(pipeline, config["step"])

        # Timed pass: state is reset (with the configured warm_up value) and
        # the RNGs are re-seeded before the clock starts.
        pipeline.reset_state(warm_up=config["warm_up"])
        _seed_everything(config["seed"])
        start = time.perf_counter()
        frames = _generate_frames(pipeline, config["step"])
        print(f"Rank {rank} Time taken: {time.perf_counter()-start:.2f} seconds.")

        export_to_gif(frames, "animation_async_rank{}.gif".format(rank))
    finally:
        # Always release the process group, even when generation fails.
        dist.destroy_process_group()



if __name__ == "__main__":
    # Settings shared by every worker process; one process is started per
    # entry in "devices".
    config = {
        "model_name": "emilianJR/epiCRealism",
        "dtype": torch.float16,
        "strategy": "n2s1",
        "devices": ["cuda:0", "cuda:1"],
        "seed": 20,
        "step": 50,
        "time_shift": False,
        "warm_up": 3,
    }

    size = len(config["devices"])

    # With force=True this call cannot raise RuntimeError for an
    # already-configured start method, so no exception handler is needed.
    mp.set_start_method('spawn', force=True)
    print("spawned")

    # torch.multiprocessing.spawn calls run_inference(rank, size, config) for
    # rank in range(size). Unlike a manual start/join loop, it terminates the
    # remaining workers and raises in the parent when any worker fails, so a
    # crashed rank can neither go unnoticed nor leave its peers hanging.
    mp.spawn(run_inference, args=(size, config), nprocs=size, join=True)