import torch
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import time
import random
import os

from src.videoCrafter import TextToVideoVideoCrafterPipeline, UNet3DVideoCrafterConditionModel
from diffusers.schedulers import DPMSolverMultistepScheduler

from tools import export_to_video, mount_attn
from dist_tools import DistWrapper, DistController


def init_pipeline(config):
    """Build the VideoCrafter text-to-video pipeline for the current rank.

    Loads the base pipeline from ``cerspense/zeroscope_v2_576w``, swaps in the
    VideoCrafter UNet from ``./base_512_v2/diffusers``, replaces the scheduler
    with an SDE-DPM-Solver++ scheduler using Karras sigmas, enables model CPU
    offload onto this rank's GPU and enables VAE slicing.

    ``dist.get_rank()`` is used to pick the device, so the default process
    group must already be initialised (presumably done by ``DistController``
    before this is called — confirm).

    Args:
        config: Run configuration. ``config["devices"]`` is the list of GPU
            ids, indexed by rank modulo its length. ``config["dtype"]``
            optionally sets the weight dtype (defaults to ``torch.float16``).

    Returns:
        The configured pipeline.
    """
    dtype = config.get("dtype", torch.float16)
    devices = config["devices"]

    pipe = TextToVideoVideoCrafterPipeline.from_pretrained("cerspense/zeroscope_v2_576w", torch_dtype=dtype)
    pipe.unet = UNet3DVideoCrafterConditionModel.from_pretrained('./base_512_v2/diffusers', torch_dtype=dtype)
    # The DPMSolverMultistepScheduler option is `use_karras_sigmas`. The former
    # `use_karras=True` is not a config field; diffusers drops unknown config
    # keys with a warning, so Karras sigmas were never actually enabled.
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(
        pipe.scheduler.config,
        use_karras_sigmas=True,
        algorithm_type="sde-dpmsolver++",
    )
    pipe.enable_model_cpu_offload(
        gpu_id=devices[dist.get_rank() % len(devices)],
    )
    pipe.enable_vae_slicing()
    return pipe

def run_inference(rank, world_size, data, config):
    """Generate one video on this rank of a distributed run.

    Builds the pipeline, wraps it in ``DistWrapper`` and walks ``data`` until
    it finds the first VBench item that passes the filters below. For that
    item it renders one video with a per-rank experiment prompt, then stops.

    Args:
        rank: Index of this worker process (forwarded to ``DistController``).
        world_size: Total number of worker processes.
        data: Sequence of ``(index, item)`` pairs. Each ``item`` is read for
            its ``'prompt_en'`` and ``'dimension'`` keys and is also passed
            to the pipeline as ``additional_info["vbench_data"]``.
        config: Run configuration; ``config["base_path"]`` is the output
            directory. The whole dict is forwarded to ``DistController``,
            ``init_pipeline`` and ``DistWrapper``.
    """
    dist_controller = DistController(rank, world_size, config)
    pipe = init_pipeline(config)
    dist_pipe = DistWrapper(pipe, dist_controller, config)

    # Sampling / export settings. "file_name" is filled in per item below,
    # but only when base_path does not contain 'work'.
    pipe_configs = {
        "steps": 30,
        "guidance_scale": 12,
        "fps": 24,
        "num_frames": 24 * 1,
        "height": 320,
        "width": 512,
        "export_fps": 12,
        "base_path": config["base_path"],
        "file_name": None,
    }
    # Plugin settings forwarded to DistWrapper.inference.
    plugin_configs = {
        "attn": {
            "padding": 8,
            "top_k": 16,
            "top_k_chunk_size": 24,
            "attn_scale": 1.2,
            "token_num_scale": False,
            "dynamic_scale": True,
            "dynamic_attn_tokens": False,
        },
        "conv_3d": {
            "padding": 1,
        },
        "conv_layer": {},
    }

    # Items tagged with none of these VBench dimensions are skipped.
    supported = {
        "temporal_flickering",
        "human_action",
        "subject_consistency",
        "dynamic_degree",
        "motion_smoothness",
        "appearance_style",
        "temporal_style",
        "overall_consistency",
        "imaging_quality",
        "scene",
        "background_consistency",
        "aesthetic_quality",  # out of memory
    }

    # Experiment prompts: ranks are spread evenly across this list, so each
    # rank renders one of them.
    object_name = "A loyal golden retriever"
    prompts = [
        f"{object_name}  is waiting his owner back home, in the morning",
        f"{object_name}  is waiting his owner back home, in the rainy day",
        f"{object_name}  is waiting his owner back home, in the snowy day",
        f"{object_name}  is waiting his owner back home, in the night with street light",
    ]
    style_prefix = "[Black and white Manga style;]"

    for i, item in data:
        prompt = item['prompt_en']
        dimensions = set(item['dimension'])
        if 'work' not in pipe_configs["base_path"]:
            pipe_configs["file_name"] = str(i + 1).zfill(4)
        print(f"Rank {rank} processing {i + 1}th item\nprompt: {prompt}\nfile_name: {pipe_configs['file_name']}")

        if not prompt:
            continue
        if supported.isdisjoint(dimensions):
            continue
        if os.path.exists(f"{pipe_configs['base_path']}/{pipe_configs['file_name']}.mp4"):
            continue

        # NOTE(review): the VBench prompt is only used for the filters above;
        # the prompt actually rendered comes from the experiment list.
        # Exact integer arithmetic: the former int(rank / world * n) can
        # truncate one too low through float rounding (int(1/49*49) == 0).
        index = dist.get_rank() * len(prompts) // dist.get_world_size()
        prompt = prompts[index]
        print(f"Rank {rank} processing {i + 1}th item\nprompt: {prompt}\nfile_name: {pipe_configs['file_name']}")

        prompt = style_prefix + prompt

        start = time.time()
        dist_pipe.inference(
            prompt,
            pipe_configs,
            plugin_configs,
            additional_info={
                "vbench_data": item
            }
        )
        print(f"Rank {rank} finished. Time: {time.time() - start}")
        # Only the first eligible item is rendered (presumably intentional
        # for this experiment — remove to process the whole dataset).
        break



def main(config):
    """Spawn one ``run_inference`` worker per configured device and wait for them.

    Loads the VBench prompt list and makes sure ``config["base_path"]`` exists.
    It then shuffles the ``(index, item)`` pairs once in the parent, so every
    worker receives the same order. Finally it starts one process per entry of
    ``config["devices"]`` and joins them all.

    Args:
        config: Run configuration; must contain ``"devices"`` and
            ``"base_path"``. It is passed unchanged to every worker.
    """
    import json

    size = len(config["devices"])

    # Explicit encoding: the platform default is not guaranteed to be UTF-8.
    with open('./vbench/VBench_full_info.json', encoding='utf-8') as f:
        data = json.load(f)

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(config["base_path"], exist_ok=True)

    # Pair each item with its original position (workers use it for file
    # names), then shuffle once so all workers share the same order.
    data = list(enumerate(data))
    random.shuffle(data)

    processes = []
    for rank in range(size):
        p = mp.Process(target=run_inference, args=(rank, size, data, config))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()
        # Previously a crashed worker went unnoticed; surface it.
        if p.exitcode != 0:
            print(f"Warning: worker {p.name} exited with code {p.exitcode}")

# Template run configuration. The configs built in the ``__main__`` section
# start from it (``**example_config``) and override individual fields.
example_config = dict(
    dtype=torch.float16,
    devices=[6, 7],
    seed=123,
    master_port=29516,
    base_path="./work/nnn_exp",
)


if __name__ == "__main__":
    # Workers are started with 'spawn' (presumably because they initialise
    # CUDA, which is not fork-safe).
    mp.set_start_method("spawn")
    # Random offset for the rendezvous port, presumably so concurrent runs
    # on the same host do not clash.
    port_bias = random.randint(0, 100)
    configs = [
        {
            **example_config,
            # Each key may appear only once in a dict literal; repeated
            # "devices" entries were silently overwritten by the last one.
            # Other layouts used before: [0, 1], [2, 3], [4, 5], [6, 7],
            # [0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7].
            "devices": [0, 1, 2, 3, 4, 5, 6, 7],
            "master_port": 62400 + port_bias,
            "seed": 113,
        },
    ]
    for config in configs:
        main(config)