import os
import time
import gc
import functools
from contextlib import nullcontext

import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers

from .sparse_attn import WanAttnProcessor2_0_Preprocessor, WanAttnProcessor2_0_Trainer
import qdiff.flatquant.globalvar as globalvar
from diffusers.models.attention import Attention
from diffusers.models.transformers.transformer_wan import WanAttnProcessor2_0
from .function_utils import set_require_grad_all, get_n_set_parameters_byname, get_paras_dict_by_name, check_params_grad
from .flat_linear import FlatQuantizedLinear
from .function_utils import get_init_scale

# Module-level switch read by cali_flat_quant_wan. When True, its Catcher keeps
# the captured block-0 hidden states on CPU and the function takes its
# `if save_gpu_memory` CPU-offload code paths (intended to lower GPU memory use
# at the cost of extra host<->device copies).
save_gpu_memory = False


def set_embed_to_device_wan(model, device):
    """Move the Wan model's input embedding modules to ``device``.

    Only ``model.patch_embedding`` and then ``model.condition_embedder`` are
    moved (in place, via ``.to``); no other submodule is touched.
    """
    for embed_attr in ("patch_embedding", "condition_embedder"):
        getattr(model, embed_attr).to(device)

    
def set_linear_to_ori(layer):
    """Set ``ori_mode = True`` on every FlatQuantizedLinear inside ``layer``.

    cali_flat_quant_wan enables this while producing its reference outputs;
    presumably it makes the linear bypass quantization/transforms — confirm
    in FlatQuantizedLinear.
    """
    flat_linears = (m for m in layer.modules() if isinstance(m, FlatQuantizedLinear))
    for flat_linear in flat_linears:
        flat_linear.ori_mode = True

def set_linear_to_normal(layer):
    """Set ``ori_mode = False`` on every FlatQuantizedLinear inside ``layer``.

    This undoes set_linear_to_ori; cali_flat_quant_wan calls it before
    training the FlatQuant parameters.
    """
    flat_linears = (m for m in layer.modules() if isinstance(m, FlatQuantizedLinear))
    for flat_linear in flat_linears:
        flat_linear.ori_mode = False

def set_linear_diag_init(layer, alpha):
    """Initialise ``trans.diag_scale`` of every FlatQuantizedLinear in ``layer``.

    For each module the column-wise max of ``|linear.weight|`` (max over
    dim 0 — per input feature, assuming ``linear`` is an ``nn.Linear`` with an
    (out, in) weight) is passed to ``get_init_scale`` together with a
    constant 1e-5 tensor and ``alpha``; the result replaces the diag scale.

    Args:
        layer: module whose FlatQuantizedLinear submodules are initialised.
        alpha: balancing exponent forwarded to ``get_init_scale``.
    """
    for module in layer.modules():
        if not isinstance(module, FlatQuantizedLinear):
            continue
        weight_absmax = module.linear.weight.abs().max(dim=0)[0]
        # ones_like already allocates on the weight's own device/dtype; the
        # previous extra `.cuda()` moved it to the *default* CUDA device, which
        # mismatches the weight when calibrating on e.g. cuda:1.
        eps_stat = torch.ones_like(weight_absmax) * 1e-5
        module.trans.diag_scale.data = get_init_scale(weight_absmax, eps_stat, alpha)


def cali_flat_quant_wan(args, model, dataloader, dev, logger):
    """Calibrate FlatQuant parameters block by block for a Wan transformer.

    Steps:
      1. Run calibration batches through the model until the first block is
         reached, recording its positional inputs (hidden states, encoder
         hidden states, temb, rotary embedding) for ``args.nsamples`` samples.
      2. For every block in ``model.blocks``: compute reference outputs with
         its FlatQuantizedLinear layers in ``ori_mode`` (optionally running
         the attention-distillation pre-pass first), then train the selected
         FlatQuant parameters with AdamW to match those references.
      3. Save each block's trained parameters to
         ``<exp_dir>/flat_parameters_layer_{i}.pth``, merge them into
         ``<exp_dir>/flat_parameters.pth`` (dict keyed by block index) and
         delete the per-block files.

    Args:
        args: config namespace; reads ``deactive_amp``, ``nsamples``,
            ``diag_init``, ``diag_alpha``, ``cali_trans``, ``add_diag``,
            ``lwc``, ``lac``, ``flat_lr``, ``epochs``, ``warmup``, ``exp_dir``.
        model: Wan transformer exposing ``blocks``, ``patch_embedding`` and
            ``condition_embedder``; called as ``model(**batch)``.
        dataloader: iterable of dict batches of model keyword arguments.
        dev: device used for calibration.
        logger: logger providing ``info`` and ``warning``.

    Returns:
        None. Blocks are updated in place (trained parameters kept, original
        per-parameter dtypes restored) and the whole model ends on ``dev``.

    Raises:
        RuntimeError: if fewer than ``args.nsamples`` samples were captured.
        NotImplementedError: for an unsupported ``args.diag_init``.
    """

    class _StopForward(Exception):
        """Raised by Catcher to abort the model forward once block-0 inputs are recorded."""

    def print_memory_stats(prefix=""):
        logger.info(f"{prefix} Memory Stats:")
        logger.info(f"Allocated: {torch.cuda.memory_allocated(dev) / 1024**2:.2f} MB")
        logger.info(f"Reserved/Cached: {torch.cuda.memory_reserved(dev) / 1024**2:.2f} MB")
        logger.info(f"Max Allocated: {torch.cuda.max_memory_allocated(dev) / 1024**2:.2f} MB")
        logger.info(f"Max Reserved: {torch.cuda.max_memory_reserved(dev) / 1024**2:.2f} MB")

    def _stash(tensor):
        """Offload ``tensor`` to CPU when the module-level ``save_gpu_memory`` is on."""
        return tensor.to("cpu") if save_gpu_memory else tensor

    def _set_attn1_processor(block, processor_cls):
        """Install a fresh ``processor_cls()`` on every ``attn1`` Attention of ``block``."""
        for name, module in block.named_modules():
            if "attn1" in name and isinstance(module, Attention):
                module.set_processor(processor_cls())

    def _layer_param_path(idx):
        """Path of the temporary per-block parameter checkpoint."""
        return os.path.join(args.exp_dir, f"flat_parameters_layer_{idx}.pth")

    print_memory_stats("Initial")

    model.cpu()
    model.eval()
    model.requires_grad_(False)

    if args.deactive_amp:
        dtype = torch.float32
        traincast = nullcontext
    else:
        dtype = torch.bfloat16
        traincast = functools.partial(torch.amp.autocast, device_type="cuda", dtype=dtype)

    # ---- 1. capture the inputs of the first block ----
    layers = model.blocks
    layers[0] = layers[0].to(dev)
    set_embed_to_device_wan(model, dev)

    inps = {"hidden_states": []}
    cache = {"i": 0, "encoder_hidden_states": [], "temb": [], "rotary_emb": []}

    class Catcher(nn.Module):
        """Stand-in for block 0 that records its positional inputs, then aborts."""

        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, *block_args, **kwargs):
            # Positional order as consumed here:
            # (hidden_states, encoder_hidden_states, temb, rotary_emb).
            inps["hidden_states"].append(_stash(block_args[0]))
            cache["i"] += 1
            cache["encoder_hidden_states"].append(block_args[1])
            cache["rotary_emb"].append(block_args[3])
            cache["temb"].append(block_args[2].float())
            # Dedicated exception type so genuine ValueErrors raised by the
            # model before block 0 are not mistaken for the capture signal.
            raise _StopForward

    layers[0] = Catcher(layers[0])
    with torch.no_grad():
        for batch in dataloader:
            if cache["i"] >= args.nsamples:
                break
            input_dict = {
                key: value.to(dev) if torch.is_tensor(value) else value
                for key, value in batch.items()
            }
            try:
                model(**input_dict)
            except _StopForward:
                pass

    layers[0] = layers[0].module.cpu()
    set_embed_to_device_wan(model, "cpu")

    del dataloader
    torch.cuda.empty_cache()

    n_captured = len(inps["hidden_states"])
    if n_captured < args.nsamples:
        raise RuntimeError(
            f"captured only {n_captured} calibration samples but args.nsamples={args.nsamples}; "
            "provide more calibration batches or lower nsamples"
        )

    fp_inps = inps
    fp_outs = {"hidden_states": []}
    del inps
    torch.cuda.empty_cache()
    loss_func = torch.nn.MSELoss()

    num_train_layer = len(layers)

    # Hard-coded experiment switches.
    USE_QUANT_INPUT = False  # feed each block the outputs of the previous *trained* block
    USE_DISTILL = True  # add attention-distillation terms to the training loss

    if USE_QUANT_INPUT:
        quant_inps = fp_inps
        quant_outs = {"hidden_states": []}
        del fp_inps
        torch.cuda.empty_cache()

    def _block_inputs(hidden_states_list, j):
        """Positional inputs of a block for sample ``j``; hidden states moved to ``dev``."""
        return (
            hidden_states_list[j].to(dev).float(),
            cache["encoder_hidden_states"][j].float(),
            cache["temb"][j],
            cache["rotary_emb"][j],
        )

    # ---- 2. calibrate block by block ----
    for i in range(num_train_layer):
        logger.info(f"========= Layer {i} =========")
        print_memory_stats(f"Before Layer {i}")
        layer = layers[i].to(dev)
        # Remember original dtypes so they can be restored after fp32 calibration.
        dtype_dict = {name: param.dtype for name, param in layer.named_parameters()}
        with torch.no_grad():
            layer.float()

        # The same inputs feed the distillation pre-pass, the reference pass
        # and training; re-read every iteration since both dicts are rebound
        # at the end of the loop body.
        src_inps = quant_inps if USE_QUANT_INPUT else fp_inps

        set_linear_to_ori(layer)

        if USE_DISTILL:
            globalvar.clear_attn_distil()
            _set_attn1_processor(layer, WanAttnProcessor2_0_Preprocessor)
            with torch.no_grad():
                for j in range(args.nsamples):
                    try:
                        layer(*_block_inputs(src_inps["hidden_states"], j))
                    except Exception as e:
                        # Best-effort: a failing sample is skipped, but reported.
                        logger.warning(
                            f"layer {i}: distillation pre-pass failed for sample {j}, skipping: {e!r}"
                        )
            _set_attn1_processor(layer, WanAttnProcessor2_0)

        # Reference outputs with the FlatQuantizedLinear layers in ori_mode.
        with torch.no_grad():
            for j in range(args.nsamples):
                fp_hidden_states = layer(*_block_inputs(src_inps["hidden_states"], j))
                fp_outs["hidden_states"].append(_stash(fp_hidden_states))
                del fp_hidden_states
                if save_gpu_memory:
                    torch.cuda.empty_cache()

        set_linear_to_normal(layer)
        if args.diag_init == "sq_style":
            set_linear_diag_init(layer, args.diag_alpha)
        elif args.diag_init != "one_style":
            raise NotImplementedError(f"unsupported diag_init: {args.diag_init!r}")

        # Re-apply in case the diag init produced tensors off ``dev``.
        layer = layer.to(dev)

        set_require_grad_all(layer, False)
        # (enabled, parameter-name pattern, learning rate), in registration order.
        param_specs = [
            (args.cali_trans, "trans.linear", args.flat_lr),
            (args.add_diag, "trans.diag_scale", args.flat_lr),
            (args.lwc, "clip_factor_w", args.flat_lr * 10),
            (args.lac, "clip_factor_a", args.flat_lr * 10),
        ]
        trained_params, paras_name = [], []
        for enabled, pattern, lr in param_specs:
            if enabled:
                trained_params.append({"params": get_n_set_parameters_byname(layer, [pattern]), "lr": lr})
                paras_name.append(pattern)

        optimizer = torch.optim.AdamW(trained_params)
        scheduler_main = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.epochs * args.nsamples, eta_min=args.flat_lr * 1e-3
        )
        if args.warmup:
            scheduler_warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, total_iters=16)
            scheduler = torch.optim.lr_scheduler.ChainedScheduler([scheduler_warmup, scheduler_main])
        else:
            scheduler = scheduler_main
        # Parameter set is fixed during training; build the clipping list once.
        grad_params = [p for group in trained_params for p in group["params"]]

        if USE_DISTILL:
            _set_attn1_processor(layer, WanAttnProcessor2_0_Trainer)
        for epoch in range(args.epochs):
            if epoch == 0:
                print_memory_stats(f"Layer {i} Epoch {epoch} Start")
            mse = 0
            start_tick = time.time()
            with traincast():
                for j in range(args.nsamples):
                    inputs = _block_inputs(src_inps["hidden_states"], j)
                    target = fp_outs["hidden_states"][j].to(dev)

                    globalvar.set_current_index(j)
                    quant_hidden_states = layer(*inputs)
                    loss = loss_func(target, quant_hidden_states)
                    if USE_DISTILL:
                        attn_low_res, attn_top, _ = globalvar.get_attn_distil(j)
                        q_attn_low_res, q_attn_top = globalvar.get_current_attn()
                        loss += F.mse_loss(q_attn_low_res, attn_low_res) * 1e-4
                        loss += F.mse_loss(q_attn_top, attn_top) * 1e-4
                    mse += loss.detach().cpu()

                    optimizer.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(grad_params, max_norm=1.0)
                    optimizer.step()
                    scheduler.step()

                    if save_gpu_memory:
                        del inputs, target, quant_hidden_states, loss
                        torch.cuda.empty_cache()

            cur_lr = optimizer.state_dict()['param_groups'][0]['lr']
            logger.info(f"layer {i} lwc lac iter {epoch}, lr {cur_lr:.8f}  time {time.time() - start_tick:.6f}s, mse: {mse:.8f}" )

        if USE_DISTILL:
            _set_attn1_processor(layer, WanAttnProcessor2_0)
        # This block's reference outputs are the next block's inputs.
        fp_inps = fp_outs
        fp_outs = {"hidden_states": []}

        if USE_QUANT_INPUT:
            with torch.no_grad():
                for j in range(args.nsamples):
                    quant_hidden_states = layer(*_block_inputs(quant_inps["hidden_states"], j))
                    quant_outs["hidden_states"].append(_stash(quant_hidden_states))
                quant_inps = quant_outs
                quant_outs = {"hidden_states": []}

        layers[i] = layer.to("cpu")
        layer_params = get_paras_dict_by_name(layer, required_names=paras_name)
        torch.save(layer_params, _layer_param_path(i))
        logger.info(f"saved parameters for layer {i} at {_layer_param_path(i)}")
        del layer_params
        for name, param in layer.named_parameters():
            param.requires_grad = False
            if name in dtype_dict:
                param.data = param.to(dtype_dict[name])
        del layer
        torch.cuda.empty_cache()
        print_memory_stats(f"After Layer {i}")

    del fp_inps, fp_outs
    gc.collect()
    torch.cuda.empty_cache()

    print_memory_stats("Final")
    model.to(dev)

    # ---- 3. merge the per-block checkpoints into one file ----
    flat_parameters = {i: torch.load(_layer_param_path(i)) for i in range(num_train_layer)}
    merged_path = os.path.join(args.exp_dir, "flat_parameters.pth")
    torch.save(flat_parameters, merged_path)
    logger.info(f"saved merged parameters at {merged_path}")

    for i in range(num_train_layer):
        layer_param_path = _layer_param_path(i)
        if os.path.exists(layer_param_path):
            os.remove(layer_param_path)
            logger.info(f"removed layer parameter file: {layer_param_path}")
