import os
import json
import torch

from copy import deepcopy
from glob import glob

from latent_diffusion.diffusion.ddpm import LatentDiffusion
from latent_diffusion.diffusion.ddim import DDIMSampler
from utils import toDevice

class LatentDiffusionWraper():
    """Wrapper around a ``LatentDiffusion`` model and its ``DDIMSampler``.

    Construction builds the diffusion model from a JSON model config, loads the
    pretrained weights that config names, puts the model in eval mode, calls
    ``convert_to_fp16()`` and moves it to ``config['device']``.  A deep copy of
    the UNet (``self.diffusion.model``) taken at the end of construction is kept
    as ``self.original_unet``.  ``train()`` and ``generate_with_original_unet()``
    temporarily swap it in to compare against the fine-tuned UNet.

    The class name keeps its historical spelling ("Wraper") because callers and
    log messages use it.
    """

    def __init__(self, name: str, config: dict, logger) -> None:
        """Build the model, sampler and fine-tuning bookkeeping.

        Args:
            name: Human-readable name; spaces become underscores in checkpoint
                file names.
            config: Must provide ``'device'``, ``'model_config'`` (path to a
                JSON file) and ``'finetune_layers'`` (exact parameter names to
                collect for fine-tuning).
            logger: Any object with an ``info(str)`` method.

        The JSON model config must provide ``'model'`` (kwargs for
        ``LatentDiffusion``), ``'pretrained_model'`` (checkpoint path) and
        ``'ddim'`` (kwargs for ``DDIMSampler``).
        """
        self.name = name
        self.device = config['device']
        self.logger = logger

        # Context manager closes the config file (a bare open() leaked it).
        with open(config["model_config"], "r") as f:
            model_config = json.load(f)
        self.diffusion = LatentDiffusion(**model_config["model"])
        self.init(model_config)
        self.diffusion.eval()
        self.diffusion.convert_to_fp16()
        self.diffusion = self.diffusion.to(self.device)
        self.data_type = self.diffusion.data_type
        self.print_model()

        self.sampler = DDIMSampler(model=self.diffusion, **model_config["ddim"])

        self.finetuned_names, self.finetuned_params = self.get_finetune_layers(config["finetune_layers"])

        # Every parameter of the conditioning transformer is made trainable.
        for param in self.diffusion.cond_stage_model.transformer.parameters():
            param.requires_grad = True

        # Snapshot of the UNet as it is right after loading; used as the
        # reference model by train() and generate_with_original_unet().
        self.original_unet = deepcopy(self.diffusion.model)

    def init(self, config: dict):
        """Load weights from ``config['pretrained_model']`` into ``self.diffusion``.

        The checkpoint is loaded onto the CPU and applied with ``strict=False``.
        Missing and unexpected keys are logged, not treated as errors.

        Raises:
            KeyError: if ``config`` has no ``'pretrained_model'`` entry.
            FileNotFoundError: if that path is not an existing file.
        """
        # Explicit check instead of ``assert`` so it still runs under ``python -O``.
        if "pretrained_model" not in config:
            raise KeyError("LatentDiffusionWraper: config has no 'pretrained_model' entry")
        ckpt_path = config['pretrained_model']
        if not os.path.isfile(ckpt_path):
            self.logger.info("LatentDiffusionWraper: Load pretrained model fail! ({})".format(ckpt_path))
            raise FileNotFoundError(ckpt_path)
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(ckpt_path, map_location=torch.device('cpu'))
        missing, unexpected = self.diffusion.load_state_dict(state_dict=state_dict, strict=False)
        self.logger.info("LatentDiffusionWraper: Load pretrained model success!")
        self.logger.info("LatentDiffusionWraper: Load the checkpoint {}".format(ckpt_path))
        self.logger.info("LatentDiffusionWraper: Miss Module {}".format(missing))
        self.logger.info("LatentDiffusionWraper: Unexpect Module {}".format(unexpected))

    def save_checkpoint(self, mp_trainer, epoch, path):
        """Save ``self.diffusion.state_dict()`` to ``<path>/checkpoints/<name>_epoch_<epoch>.pth``.

        ``mp_trainer`` is accepted for interface compatibility but is not used.
        Spaces in ``self.name`` are replaced by underscores in the file name.
        """
        ckpt_dir = os.path.join(path, "checkpoints")
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(ckpt_dir, exist_ok=True)
        filename = os.path.join(ckpt_dir, self.name.replace(" ", "_") + "_epoch_{}.pth".format(epoch))
        torch.save(self.diffusion.state_dict(), filename)

    def print_model(self):
        """Log the name and size of every parameter of the diffusion model."""
        for name, parameters in self.diffusion.named_parameters():
            self.logger.info("LatentDiffusionWraper: " + name + " : {}".format(parameters.size()))

    def set_finetune_layers(self, layer_name_list: list):
        """Set ``requires_grad`` on every diffusion parameter.

        A parameter becomes trainable only if its full name appears exactly in
        ``layer_name_list``; all others are frozen.  An empty (or ``None``)
        list freezes every parameter.
        """
        wanted = set(layer_name_list or ())
        if not wanted:
            for parameters in self.diffusion.parameters():
                parameters.requires_grad = False
            self.logger.info("LatentDiffusionWraper: No parameters are finetuned!")
            return
        for name, parameters in self.diffusion.named_parameters():
            flag = name in wanted
            parameters.requires_grad = flag
            if flag:
                self.logger.info("LatentDiffusionWraper: " + name + " requires grad {}!".format(flag))

    def get_finetune_layers(self, layer_name_list: list):
        """Collect the diffusion parameters whose full names appear in ``layer_name_list``.

        ``requires_grad`` is not modified.

        Returns:
            ``(names, params)``: two lists in ``named_parameters()`` order.
            Both are empty when ``layer_name_list`` is empty or ``None``.
        """
        names, ret = [], []
        wanted = set(layer_name_list or ())
        if not wanted:
            self.logger.info("LatentDiffusionWraper: No parameters are finetuned!")
            return names, ret
        for name, parameters in self.diffusion.named_parameters():
            if name in wanted:
                names.append(name)
                ret.append(parameters)
                self.logger.info("LatentDiffusionWraper: " + name + " requires grad {}!".format(True))
        return names, ret

    def parameters(self):
        """Return ``self.diffusion.parameters()``."""
        return self.diffusion.parameters()

    def train(self, cond, pival_cond, prompt_manager, batch_size=2, timestep_range=(0, 1000), guidance_range=(1, 15), **kwargs):  # for unet learning
        """Evaluate the reference and the current UNet at one random timestep.

        A timestep ``t_sample`` is drawn uniformly from
        ``[timestep_range[0], timestep_range[1])``.  A guidance scale is drawn
        uniformly from ``[guidance_range[0], guidance_range[1])``.

        With gradients disabled and ``self.original_unet`` swapped in:
        ``self.sampler.sample2`` produces a latent ``x_t`` under the positive
        conditioning.  The reference UNet is then applied to ``x_t`` under four
        conditionings (zero, positive, negative, ``pival_cond``).

        The current UNet is then restored (also when an exception is raised)
        and applied to the same inputs with gradients enabled.

        Args:
            cond: Conditioning input; repeated via ``cond * batch_size``
                (presumably a list of prompts — confirm with callers).
            pival_cond: Conditioning for the ``"pival"`` outputs (encoded
                without a prompt manager).
            prompt_manager: Passed to ``get_learned_conditioning``.  If it has
                a ``sampled_pool`` attribute, a pool index may be drawn.
            batch_size: Number of latents sampled.
            timestep_range: ``(low, high)`` bounds for the timestep draw.
            guidance_range: ``(low, high)`` bounds for the guidance scale.
            **kwargs: Forwarded only to the negative-conditioning call.

        Returns:
            dict containing:

            - current-UNet outputs ``"pos"``, ``"neg"``, ``"uc"``, ``"new_pival"``;
            - detached reference-UNet outputs ``"pival"``, ``"ori_uc"``,
              ``"ori_neg"``.
        """
        cond = cond * batch_size

        t_sample = torch.randint(
                int(timestep_range[0]), int(timestep_range[1]),
                (1,)
            ).item()
        # Round t_sample down to a multiple of the DDIM stride (presumably so
        # the partial DDIM run ends on a step of its schedule).
        ddim_stride = self.diffusion.num_timesteps // self.sampler.S
        t_max_ddpm = (t_sample // ddim_stride) * ddim_stride
        t_nxt = torch.full((1,), t_sample, device=self.device, dtype=torch.long)

        guidance_sample = torch.rand((1,)).item() * (guidance_range[1] - guidance_range[0]) + guidance_range[0]

        with torch.no_grad():
            trained_unet = self.diffusion.model
            self.diffusion.model = self.original_unet
            try:
                # Choose the pool sample: -1 with probability 0.7, or whenever
                # the pool has no "*" entry; otherwise a random index into it.
                if hasattr(prompt_manager, "sampled_pool"):
                    if torch.rand((1,)) < 0.7 or "*" not in prompt_manager.sampled_pool.keys():
                        sample = -1
                    else:
                        sample = torch.randint(0, prompt_manager.sampled_pool["*"].shape[0], (1,))
                else:
                    sample = -1

                # text guidance
                c_zero = self.diffusion.get_learned_conditioning(cond, prompt_manager, zero=True)
                c_pos = self.diffusion.get_learned_conditioning(cond, prompt_manager, pos_condition=True, sample=sample)
                c_neg = self.diffusion.get_learned_conditioning(cond, prompt_manager, pos_condition=False, sample=sample, **kwargs)
                c_pival = self.diffusion.get_learned_conditioning(pival_cond, None, pos_condition=True)

                uc = self.diffusion.get_learned_conditioning([""] * batch_size)

                x_t = self.sampler.sample2(
                    batch_size,
                    conditioning=c_pos,
                    unconditional_conditioning=uc,
                    unconditional_guidance_scale=guidance_sample,
                    t_end=t_max_ddpm,
                    verbose=False,
                    )

                # One batched call evaluates all four conditionings.
                x_in4 = torch.cat([x_t] * 4)
                c_in4 = torch.cat([c_zero, c_pos, c_neg, c_pival])
                x_t_1_uc_ori, x_t_1_ori_pos, x_t_1_ori_neg, x_t_1_ori_pival = self.diffusion.apply_model(x_in4, t_nxt, c_in4).chunk(4)
            finally:
                # Always put the trained UNet back, even if sampling failed.
                self.diffusion.model = trained_unet

        x_t_1_uc, x_t_1_pos, x_t_1_neg, x_t_1_pival = self.diffusion.apply_model(x_in4, t_nxt, c_in4).chunk(4)

        return {"pos": x_t_1_pos, "neg": x_t_1_neg,
                "pival": x_t_1_ori_pival.detach(),
                "ori_uc": x_t_1_uc_ori.detach(), "uc": x_t_1_uc,
                "ori_neg": x_t_1_ori_neg.detach(),
                "new_pival": x_t_1_pival,
                }

    def eval(self, batch, timestep_range=(0, 1000), prompt_manager=None):
        """Run one forward pass of ``self.diffusion`` on ``batch`` at random timesteps.

        ``batch`` must provide ``"img"`` and ``"txt"``.  One timestep per image
        is drawn uniformly from ``[timestep_range[0], timestep_range[1])``.

        Returns:
            dict with the latents ``"z"``, the model outputs ``"z_pred"``,
            ``"noise"``, ``"noise_pred"``, ``"loss_dict"`` and ``"loss"``.
        """
        imgs = batch["img"]
        txt = batch["txt"]
        inputs = {"image": imgs, "txt": txt}
        toDevice(self.device, inputs)

        t = torch.randint(
                int(timestep_range[0]), int(timestep_range[1]),
                (imgs.shape[0],),
                device=self.device
            ).long()

        # diffusion shared_step
        self.diffusion.convert_batch_to_dtype(inputs)
        z, c = self.diffusion.get_input_grad_txt(inputs, self.diffusion.first_stage_key, prompt_manager=prompt_manager)
        z, c = z.to(self.diffusion.data_type), c.to(self.diffusion.data_type)

        z_0, loss, loss_dict, noise, model_output = self.diffusion(z, c, t=t)

        return {"z": z, "z_pred": z_0, "noise": noise, "noise_pred": model_output, "loss_dict": loss_dict, "loss": loss}

    @torch.no_grad()
    def generate(self, y, prompt_manager=None, uc_guidance=7.5, verbose=True, return_z=False):
        """Sample with classifier-free guidance and decode to image space.

        Args:
            y: Prompts, one per sample.
            prompt_manager: Passed to the conditional ``get_learned_conditioning`` call.
            uc_guidance: Stored on the sampler as ``unconditional_guidance_scale``.
            verbose: Forwarded to the sampler.
            return_z: If True, also return the sampled latents.

        Returns:
            The decoded samples, or ``(z0, samples)`` if ``return_z`` is True.
        """
        uc = self.diffusion.get_learned_conditioning(len(y) * [""])
        c = self.diffusion.get_learned_conditioning(y, prompt_manager)

        self.sampler.unconditional_guidance_scale = uc_guidance

        z0, _ = self.sampler.sample(
                batch_size=len(y),
                conditioning=c,
                unconditional_conditioning=uc,
                verbose=verbose,
            )

        # The original author notes the decoded range as (-1, 1).
        sample = self.diffusion.decode_first_stage(z0.float())

        if return_z:
            return z0, sample
        return sample

    def init_from_log(self, log_path=None, file_path=None):
        """Load weights from ``file_path``, or from the newest epoch checkpoint under ``log_path``.

        "Newest" means the highest integer epoch in
        ``<log_path>/checkpoints/<name>_epoch_<N>.pth``.

        Raises:
            FileNotFoundError: if no matching checkpoint exists, or
                ``file_path`` is not a file.
        """
        if file_path:
            self.init({"pretrained_model": file_path})
            return

        pattern = os.path.join(log_path, "checkpoints", self.name.replace(" ", "_") + "_epoch_*.pth")
        candidates = glob(pattern)
        if not candidates:
            self.logger.info("{} checkpoints file not found error!!!!".format(self.name))
            raise FileNotFoundError(pattern)
        # Sort by the numeric epoch so that epoch 10 ranks after epoch 9.
        latest = sorted(candidates, key=lambda p: int(p.split("_epoch_")[-1].split(".pth")[0]))[-1]
        self.init({"pretrained_model": latest})

    @torch.no_grad()
    def encode_x_sample(self, x):
        """Encode ``x`` (cast to fp16, moved to ``self.device``) via ``get_first_stage_encoding``."""
        x = x.half().to(self.device)
        encoder_posterior = self.diffusion.encode_first_stage(x)
        return self.diffusion.get_first_stage_encoding(encoder_posterior).detach()

    @torch.no_grad()
    def encode_x_mean(self, x):
        """Return the ``.mean`` of the first-stage encoding of ``x`` (cast to fp16)."""
        x = x.half().to(self.device)
        return self.diffusion.encode_first_stage(x).mean

    def grad_projection(self, loss: torch.Tensor, loss_base: torch.Tensor, optimizer: torch.optim.Optimizer):
        """Combine the gradients of ``loss`` and ``loss_base`` for the fine-tuned parameters.

        The combined gradient for each fine-tuned parameter is
        ``g_loss - (g_loss . u) u + g_base``, where ``u`` is the unit vector
        along ``g_base``.  In words: the component of ``g_loss`` orthogonal to
        ``g_base``, plus ``g_base`` itself.  Parameters with no base gradient,
        or one whose absolute sum is not above 1e-10, keep their ``loss``
        gradient.

        Both backward passes use ``retain_graph=True``.  This assumes
        ``optimizer`` owns the fine-tuned parameters, so that ``zero_grad()``
        clears the ``loss_base`` gradients between the two passes.
        """
        loss_base.backward(retain_graph=True)
        # Clone: if zero_grad() zeroes in place (set_to_none=False), a plain
        # reference would be wiped and then overwritten by the loss gradient.
        grad_bases = {
            name: None if v.grad is None else v.grad.detach().clone()
            for name, v in zip(self.finetuned_names, self.finetuned_params)
        }
        optimizer.zero_grad()

        loss.backward(retain_graph=True)
        for name, v in zip(self.finetuned_names, self.finetuned_params):
            base = grad_bases[name]
            if base is None or not bool(base.abs().sum() > 1e-10):
                continue
            flat_base = base.reshape(-1)
            unit_base = flat_base / flat_base.norm()
            if v.grad is None:
                # loss does not reach this parameter: its gradient is zero,
                # so the projection reduces to the base gradient.
                v.grad = base.clone()
                continue
            shape = v.grad.shape
            grad = v.grad.reshape(-1)
            grad = grad - (grad @ unit_base) * unit_base
            v.grad = grad.reshape(shape) + base

    def to_train(self, verbose=True):
        """Put every module with a ``lora_A`` or ``lora_B`` attribute into train mode.

        With ``verbose`` each module is logged; otherwise one summary line is logged.
        """
        for n, m in self.diffusion.named_modules():
            if hasattr(m, "lora_A") or hasattr(m, "lora_B"):
                m.train()
                if verbose:
                    self.logger.info("LatentDiffusionWraper: Module {} is to train!".format(n))
        if not verbose:
            self.logger.info("LatentDiffusionWraper: Lora modules are to train!")

    def to_eval(self):
        """Put the whole diffusion model into eval mode."""
        self.diffusion.eval()
        self.logger.info("LatentDiffusionWraper: All modules are to eval!")

    @torch.no_grad()
    def ema_update(self, ratio=0.8):
        """Blend the fine-tuned parameters with an exponential moving average.

        The first call only snapshots the current parameters.  Each later call
        sets ``ema = ratio * ema + (1 - ratio) * param`` and then copies ``ema``
        back into the parameter.
        """
        if not hasattr(self, "ema_params"):
            self.ema_params = {
                name: params.clone().detach()
                for name, params in zip(self.finetuned_names, self.finetuned_params)
            }
            return
        for name, params in zip(self.finetuned_names, self.finetuned_params):
            self.ema_params[name].data = ratio * self.ema_params[name].data + (1 - ratio) * params.data
            params.data = self.ema_params[name].data.clone().detach()

    @torch.no_grad()
    def lora_merge_reset(self):
        """Fold each ``loralib.Linear``'s ``B @ A`` delta into its weight and reset A/B.

        After the reset, ``lora_A`` is zero and ``lora_B`` is normal-initialised,
        so ``B @ A`` is zero.

        Modules already marked ``merged`` (loralib merges on ``eval()``) already
        contain the delta in their weight, so it is not added a second time.
        Modules with ``r == 0`` have no LoRA matrices and are skipped.
        """
        from loralib import Linear as LoRALinear

        def T(m, w):
            return w.transpose(0, 1) if m.fan_in_fan_out else w

        for m in self.diffusion.modules():
            if not isinstance(m, LoRALinear) or getattr(m, "r", 1) <= 0:
                continue
            if not getattr(m, "merged", False):
                m.weight.data += T(m, m.lora_B @ m.lora_A) * m.scaling
            torch.nn.init.zeros_(m.lora_A)
            torch.nn.init.normal_(m.lora_B)

    @torch.no_grad()
    def generate_with_original_unet(self, prompt, prompt_manager=None, verbose=False):
        """Run ``generate`` with ``self.original_unet`` temporarily swapped in.

        The fine-tuned UNet is restored afterwards, even if ``generate`` raises.
        ``verbose`` is forwarded to ``generate``.
        """
        new_unet = self.diffusion.model
        self.diffusion.model = self.original_unet
        try:
            return self.generate(prompt, prompt_manager=prompt_manager, verbose=verbose)
        finally:
            self.diffusion.model = new_unet
    