import torch
import os
import torch.distributed as dist
import numpy as np
import imageio
import json

from plugins import GroupNormPlugin, ConvLayerPlugin, AttentionPlugin, LayerNormPlugin, Conv3DPligin, ModulePlugin, UNetPlugin


def export_to_video(video_frames, output_video_path, fps=12):
    """Encode ``video_frames`` as an MP4 file at ``output_video_path``.

    Args:
        video_frames: Iterable whose items are the frames as NumPy arrays (a
            stacked ndarray works too, since iterating it yields frames).
            Values are treated as intensities in [0, 1]: they are clipped to
            that range, scaled by 255 and truncated to ``uint8``.
        output_video_path: Destination file path.
        fps: Frame rate passed to the video writer.

    Raises:
        TypeError: If any frame is not a ``numpy.ndarray``.
    """
    # Materialize once so a one-shot iterator is not consumed by the type
    # check and then written out as an empty video.
    frames = list(video_frames)
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if not all(isinstance(frame, np.ndarray) for frame in frames):
        raise TypeError("All video frames must be NumPy arrays.")

    with imageio.get_writer(output_video_path, fps=fps, format='mp4') as writer:
        for frame in frames:
            # Clip before the uint8 cast; out-of-range values would otherwise
            # wrap around (e.g. 1.01 * 255 -> 257 -> 1) and show up as artifacts.
            writer.append_data((np.clip(frame, 0.0, 1.0) * 255).astype(np.uint8))

def save_generation(video_frames, configs, base_path, file_name=None):
    """Persist a generated video together with the configuration that produced it.

    Writes ``<base_path>/<file_name>.json`` (the ``configs`` dict) and
    ``<base_path>/<file_name>.mp4`` (the frames, via ``export_to_video``).

    Args:
        video_frames: Frames forwarded to ``export_to_video``.
        configs: JSON-serializable dict. ``configs["pipe_configs"]`` must hold
            ``export_fps``, plus ``num_frames``, ``steps`` and ``fps`` when
            ``file_name`` is not given.
        base_path: Output directory; created (with parents) if missing.
        file_name: Base name without extension. When falsy, a name of the form
            ``<6-digit index>_<num_frames>_<steps>_<fps>`` is generated, where
            the index is one more than the largest numeric prefix (text before
            the first ``_``) among the entries already in ``base_path``.

    Returns:
        The base file name used for both output files.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(base_path, exist_ok=True)
    p_config = configs["pipe_configs"]
    if not file_name:
        # Entries without a numeric prefix (stray files, sub-directories) are
        # skipped instead of crashing int().
        prefixes = (entry.split('_')[0] for entry in os.listdir(base_path))
        existing = [int(prefix) for prefix in prefixes if prefix.isdecimal()]
        next_index = max(existing, default=0) + 1
        key_info = '_'.join(str(p_config[key]) for key in ("num_frames", "steps", "fps"))
        file_name = f'{str(next_index).zfill(6)}_{key_info}'

    with open(os.path.join(base_path, f'{file_name}.json'), 'w', encoding='utf-8') as f:
        json.dump(configs, f, indent=4)

    export_to_video(
        video_frames,
        os.path.join(base_path, f'{file_name}.mp4'),
        fps=p_config["export_fps"],
    )

    return file_name



class GlobalState:
    """Thin wrapper around a dict used as shared key/value state."""

    def __init__(self, state=None) -> None:
        """Create the store around ``state`` (a new empty dict when omitted)."""
        self.init_state(state)

    def init_state(self, state=None):
        """Replace the backing dict.

        ``state`` is stored by reference (not copied), so the caller's dict and
        this store stay in sync. ``None`` installs a fresh empty dict; using
        ``None`` rather than a ``{}`` default ensures stores created without
        arguments do not all share one dict.
        """
        self.state = {} if state is None else state

    def set(self, key, value):
        """Store ``value`` under ``key``."""
        self.state[key] = value

    def get(self, key, default=None):
        """Return the value stored under ``key``, or ``default`` when absent."""
        return self.state.get(key, default)



class DistController(object):
    """Per-process handle on the NCCL process group.

    Construction joins the default process group, builds one two-rank group
    for each pair of neighbouring ranks, and resolves this rank's CUDA device
    from ``config['devices']``.
    """

    def __init__(self, rank, world_size, config) -> None:
        super().__init__()
        self.rank = rank
        self.world_size = world_size
        self.config = config
        self.is_master = (self.rank == 0)
        self.init_dist()
        self.init_group()
        device_index = config['devices'][dist.get_rank()]
        self.device = torch.device(f"cuda:{device_index}")

    def init_dist(self):
        """Point rendezvous at localhost and join the NCCL process group.

        The port comes from ``config["master_port"]``, falling back to 29500
        when it is missing or falsy.
        """
        print(f"Rank {self.rank} is running.")
        port = self.config.get("master_port") or "29500"
        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = str(port)
        dist.init_process_group("nccl", rank=self.rank, world_size=self.world_size)

    def init_group(self):
        """Create ``self.adj_groups``: one group per adjacent rank pair (i, i + 1).

        Every rank executes the same sequence of ``new_group`` calls, in the
        same order.
        """
        groups = []
        for left in range(self.world_size - 1):
            groups.append(dist.new_group([left, left + 1]))
        self.adj_groups = groups

class DistWrapper(object):
    """Mount plugins on ``pipe.unet`` and run per-rank video generation.

    Each rank runs the pipeline with seed ``config["seed"] + rank``; rank 0
    gathers every rank's frames, concatenates them along dim 0 and writes
    them with ``save_generation``.
    """

    # A UNet submodule receives a plugin only if its qualified name contains
    # one of these substrings (presumably the temporal layers -- confirm).
    _TEMPORAL_NAME_MARKERS = ('temp_', 'transformer_in')

    def __init__(self, pipe, dist_controller: "DistController", config) -> None:
        """
        Args:
            pipe: Pipeline exposing ``.unet``; calling it returns an object
                with ``.frames``.
            dist_controller: Supplies ``rank``, ``world_size``, ``is_master``
                and ``device``.
            config: Dict; ``config["seed"]`` is the base random seed.
        """
        super().__init__()
        self.pipe = pipe
        self.dist_controller = dist_controller
        self.config = config
        # Handed to every plugin; exposes the controller (and later the
        # plugin configs) to them.
        self.global_state = GlobalState({
            "dist_controller": dist_controller
        })
        self.plugin_mount()

    # ------------------------------------------------------------------
    # Plugin switching / configuration
    # ------------------------------------------------------------------

    def _iter_plugins(self, plugin_name):
        """Yield every plugin registered under ``plugin_name`` (nothing if unknown).

        Most entries of ``self.plugins`` are ``{plugin_id: plugin}`` dicts, but
        ``'unet'`` maps directly to a single plugin object, so both shapes are
        handled here.
        """
        if plugin_name not in self.plugins:
            return
        entry = self.plugins[plugin_name]
        if isinstance(entry, dict):
            yield from entry.values()
        else:
            yield entry

    def switch_plugin(self, plugin_name, enable):
        """Call ``set_enable(enable)`` on every plugin under ``plugin_name``; unknown names are ignored."""
        for plugin in self._iter_plugins(plugin_name):
            plugin.set_enable(enable)

    def config_plugin(self, plugin_name, config):
        """Call ``update_config(config)`` on every plugin under ``plugin_name``; unknown names are ignored."""
        for plugin in self._iter_plugins(plugin_name):
            plugin.update_config(config)

    # ------------------------------------------------------------------
    # Plugin mounting
    # ------------------------------------------------------------------

    def plugin_mount(self):
        """Rebuild ``self.plugins`` from scratch: unet, group norm, attention and Conv3d plugins."""
        self.plugins = {}
        self.unet_plugin_mount()
        self.group_norm_plugin_mount()
        self.attn_plugin_mount()
        self.conv_3d_plugin_mount()
        # conv_plugin_mount() and layer_norm_plugin_mount() are intentionally
        # not called: Conv3d and Conv layer plugins can only be used one at a time.

    def _collect_temporal_modules(self, class_name):
        """Return UNet submodules whose class is ``class_name`` and whose name has a temporal marker.

        Modules are returned in ``named_modules()`` order.
        """
        return [
            module
            for name, module in self.pipe.unet.named_modules()
            if any(marker in name for marker in self._TEMPORAL_NAME_MARKERS)
            and module.__class__.__name__ == class_name
        ]

    def _mount_module_plugins(self, key, class_name, plugin_cls, label):
        """Wrap each matching module in ``plugin_cls`` and register it under ``self.plugins[key]``.

        Plugin ids are ``(key, i)`` tuples, ``i`` being the module's position in
        discovery order. ``label`` is only used in the progress message.
        """
        self.plugins[key] = {}
        modules = self._collect_temporal_modules(class_name)
        print(f'Found {len(modules)} {label}')
        for i, module in enumerate(modules):
            plugin_id = key, i
            self.plugins[key][plugin_id] = plugin_cls(module, plugin_id, self.global_state)

    def group_norm_plugin_mount(self):
        """Mount a ``GroupNormPlugin`` on every matching ``GroupNorm``."""
        self._mount_module_plugins('group_norm', 'GroupNorm', GroupNormPlugin, 'group norms')

    def layer_norm_plugin_mount(self):
        """Mount a ``LayerNormPlugin`` on every matching ``LayerNorm``."""
        self._mount_module_plugins('layer_norm', 'LayerNorm', LayerNormPlugin, 'layer norms')

    def conv_plugin_mount(self):
        """Mount a ``ConvLayerPlugin`` on every matching ``TemporalConvLayer``."""
        self._mount_module_plugins('conv_layer', 'TemporalConvLayer', ConvLayerPlugin, 'convs')

    def conv_3d_plugin_mount(self):
        """Mount a ``Conv3DPligin`` on every matching ``Conv3d``."""
        self._mount_module_plugins('conv_3d', 'Conv3d', Conv3DPligin, 'conv3d_s')

    def attn_plugin_mount(self):
        """Mount an ``AttentionPlugin`` on every matching ``Attention``."""
        self._mount_module_plugins('attn', 'Attention', AttentionPlugin, 'attns')

    def unet_plugin_mount(self):
        """Register a single ``UNetPlugin`` (id ``('unet', 0)``) wrapping the whole UNet."""
        self.plugins['unet'] = UNetPlugin(
            self.pipe.unet,
            ('unet', 0),
            self.global_state
        )

    # ------------------------------------------------------------------
    # Inference
    # ------------------------------------------------------------------

    def _gather_and_save(self, video_frames, configs):
        """Gather every rank's frames on rank 0, concatenate along dim 0 and save them.

        Args:
            video_frames: float16 tensor on this rank's device; all ranks must
                pass the same shape (a ``dist.gather`` requirement).
            configs: Metadata forwarded to ``save_generation``; its
                ``"pipe_configs"`` entry supplies ``base_path`` and ``file_name``.
        """
        controller = self.dist_controller
        print(f"Rank {controller.rank} finished inference. Result: {video_frames.shape}")
        all_frames = [
            torch.zeros_like(video_frames, dtype=torch.float16) for _ in range(controller.world_size)
        ] if controller.is_master else None
        dist.gather(video_frames, all_frames, dst=0)
        if controller.is_master:
            all_frames = torch.cat(all_frames, dim=0).cpu().numpy()
            pipe_configs = configs["pipe_configs"]
            save_generation(
                all_frames,
                configs,
                pipe_configs["base_path"],
                pipe_configs["file_name"]
            )

    def inference(
        self,
        prompts="A beagle wearning diving goggles  swimming in the ocean while the camera is moving, coral reefs in the background",
        pipe_configs=None,
        plugin_configs=None,
        additional_info=None,
    ):
        """Run text-to-video generation on every rank and save the gathered video on rank 0.

        Args:
            prompts: Prompt forwarded to the pipeline.
            pipe_configs: Pipeline/output settings (``steps``, ``guidance_scale``,
                ``fps``, ``num_frames``, ``height``, ``width``, ``export_fps``,
                ``base_path``, ``file_name``). ``None`` selects the defaults below.
            plugin_configs: Per-plugin settings published to plugins through
                ``global_state["plugin_configs"]``. ``None`` selects the defaults below.
            additional_info: Extra metadata stored in the saved JSON (``None`` -> ``{}``).

        Defaults are built fresh on every call so that nothing leaks between
        calls through a shared default dict if any consumer mutates it.
        """
        if pipe_configs is None:
            pipe_configs = {
                "steps": 50,
                "guidance_scale": 12,
                "fps": 60,
                "num_frames": 24 * 1,
                "height": 320,
                "width": 512,
                "export_fps": 12,
                "base_path": "./work/output",
                "file_name": None
            }
        if plugin_configs is None:
            plugin_configs = {
                "attn": {
                    "padding": 24,
                    "top_k": 24,
                    "top_k_chunk_size": 24,
                    "attn_scale": 1.,
                    "token_num_scale": True,
                },
                "conv_3d": {
                    "padding": 1,
                },
                "conv_layer": {},
            }
        if additional_info is None:
            additional_info = {}

        # NOTE(review): plugins are re-mounted on every call (inference_dynamic
        # does not do this); whether re-wrapping already-wrapped modules is
        # safe depends on the plugins module -- confirm.
        self.plugin_mount()
        # Offset the seed by rank so each rank draws different noise.
        generator = torch.Generator("cuda").manual_seed(self.config["seed"] + self.dist_controller.rank)
        # Plugins receive their settings via the shared state rather than config_plugin().
        self.global_state.set("plugin_configs", plugin_configs)

        video_frames = self.pipe(
            prompts,
            num_inference_steps=pipe_configs["steps"],
            guidance_scale=pipe_configs["guidance_scale"],
            height=pipe_configs['height'],
            width=pipe_configs['width'],
            num_frames=pipe_configs['num_frames'],
            fps=pipe_configs['fps'],
            generator=generator
        ).frames[0]

        video_frames = torch.tensor(video_frames, dtype=torch.float16, device=self.dist_controller.device)
        self._gather_and_save(video_frames, {
            "prompt": prompts,
            "pipe_configs": pipe_configs,
            "plugin_configs": plugin_configs,
            "additional_info": additional_info
        })

    def inference_dynamic(
        self,
        prompts="A beagle wearning diving goggles  swimming in the ocean while the camera is moving, coral reefs in the background",
        pipe_configs=None,
        plugin_configs=None,
        frame_stride=24,
        frames_dir="./extracted_frames",
        image_guidance_scale=1.5,
    ):
        """Generate one clip per rank between two extracted key frames and save the gathered video.

        Rank ``r`` conditions on ``<frames_dir>/frame_<r*frame_stride>.jpg`` as
        the start image and ``<frames_dir>/frame_<(r+1)*frame_stride>.jpg`` as
        the end image (indices zero-padded to 4 digits).

        Args:
            prompts: Prompt forwarded to the pipeline.
            pipe_configs: Pipeline/output settings; ``None`` selects the defaults below.
            plugin_configs: Per-plugin settings published through
                ``global_state["plugin_configs"]``; ``None`` selects the defaults below.
            frame_stride: Distance in source-frame indices between consecutive ranks.
            frames_dir: Directory holding the extracted ``frame_XXXX.jpg`` files.
            image_guidance_scale: Forwarded to the pipeline.
        """
        if pipe_configs is None:
            pipe_configs = {
                "steps": 30,
                "guidance_scale": 2,
                "fps": 60,
                "num_frames": 16 * 1,
                "height": 320,
                "width": 512,
                "export_fps": 12,
                "base_path": "./work/output",
                "file_name": None
            }
        if plugin_configs is None:
            plugin_configs = {
                "attn": {
                    "padding": 8,
                    "top_k": 8,
                    "top_k_chunk_size": 8,
                    "attn_scale": 1.,
                    "token_num_scale": True,
                },
                "conv_3d": {
                    "padding": 1,
                },
                "conv_layer": {},
            }

        rank = self.dist_controller.rank
        generator = torch.Generator("cuda").manual_seed(self.config["seed"] + rank)
        self.global_state.set("plugin_configs", plugin_configs)

        from PIL import Image

        sid = str(rank * frame_stride).zfill(4)
        eid = str((rank + 1) * frame_stride).zfill(4)
        print(f"Rank {rank} is running. Start frame: {sid}, End frame: {eid}")

        # Context managers release the underlying file handles once the
        # pipeline has consumed the images.
        with Image.open(f"{frames_dir}/frame_{sid}.jpg") as image:
            with Image.open(f"{frames_dir}/frame_{eid}.jpg") as end_image:
                print(f"Rank {rank} loaded images. Start frame: {image}, End frame: {end_image}")

                video_frames = self.pipe(
                    prompts,
                    num_inference_steps=pipe_configs["steps"],
                    guidance_scale=pipe_configs["guidance_scale"],
                    height=pipe_configs['height'],
                    width=pipe_configs['width'],
                    num_frames=pipe_configs['num_frames'],
                    generator=generator,
                    image=image,
                    end_image=end_image,
                    image_guidance_scale=image_guidance_scale,
                ).frames[0]

        video_frames = torch.tensor(video_frames, dtype=torch.float16, device=self.dist_controller.device)
        # Drop the first generated frame (presumably a near-duplicate of the
        # start image, which is the previous rank's end image -- confirm).
        video_frames = video_frames[1:]

        self._gather_and_save(video_frames, {
            "prompt": prompts,
            "pipe_configs": pipe_configs,
            "plugin_configs": plugin_configs,
        })

            