import os
import time
import gc
import functools
from contextlib import nullcontext

import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers

from .attn_utils import EnhanceCogVideoXAttnProcessor2_0
import qdiff.s2quant.globalvar as globalvar
from diffusers.models.attention import Attention
from diffusers.models.attention_processor import CogVideoXAttnProcessor2_0
from .function_utils import set_require_grad_all, get_n_set_parameters_byname, get_paras_dict_by_name, check_params_grad
from .quant_utils import set_quantizer_state
from .quant_linear import QuantizedLinear
from .function_utils import get_init_scale

# Module-level switch read by cali_s2_quant: when True, captured block inputs and
# full-precision reference outputs are stored on the CPU and moved to the GPU one
# sample at a time (with torch.cuda.empty_cache() calls in between).
save_gpu_memory = False

def set_embed_to_device(model, device):
    """Move the embedding-side submodules of ``model`` onto ``device``.

    The submodules touched are ``patch_embed``, ``time_embedding``,
    ``embedding_dropout`` and ``time_proj``, in that order.
    """
    embed_attrs = ("patch_embed", "time_embedding", "embedding_dropout", "time_proj")
    for attr in embed_attrs:
        getattr(model, attr).to(device)
    
def set_linear_to_ori(layer):
    """Set ``ori_mode = True`` on every ``QuantizedLinear`` found in ``layer``."""
    for _, submodule in layer.named_modules():
        if not isinstance(submodule, QuantizedLinear):
            continue
        submodule.ori_mode = True

def set_linear_to_normal(layer):
    """Set ``ori_mode = False`` on every ``QuantizedLinear`` found in ``layer``."""
    quantized_linears = (
        submodule
        for _, submodule in layer.named_modules()
        if isinstance(submodule, QuantizedLinear)
    )
    for quantized in quantized_linears:
        quantized.ori_mode = False

def set_linear_diag_init(layer, alpha):
    """Re-initialise ``trans.diag_scale`` of every ``QuantizedLinear`` in ``layer``.

    For each quantized linear, the new scale is
    ``get_init_scale(weight_absmax, const_stat, alpha)`` where ``weight_absmax``
    is ``linear.weight.abs().max(dim=0)[0]`` and ``const_stat`` is a tensor of
    the same shape filled with 1e-5.

    Args:
        layer: module whose ``named_modules()`` are scanned.
        alpha: forwarded unchanged as the third argument of ``get_init_scale``.
    """
    for _, module in layer.named_modules():
        if not isinstance(module, QuantizedLinear):
            continue
        # Computed once and reused for both arguments below.
        weight_absmax = module.linear.weight.abs().max(dim=0)[0]
        # ones_like already lives on weight_absmax's device with its dtype. The
        # previous hard-coded .cuda() moved it to the *current* CUDA device, which
        # fails on CPU and mismatches when the weight sits on another GPU.
        # NOTE(review): second argument is presumably an activation statistic that
        # is replaced by a constant here -- confirm against get_init_scale.
        const_stat = torch.ones_like(weight_absmax) * 1e-5
        module.trans.diag_scale.data = get_init_scale(weight_absmax, const_stat, alpha)

def cali_s2_quant(args, model, dataloader, dev, logger):
    """Calibrate the quantization parameters of ``model.transformer_blocks`` block by block.

    Workflow:
      1. Run the model on batches from ``dataloader`` and record the keyword
         inputs of the first transformer block (at most ``args.nsamples``).
      2. For every block: compute full-precision reference outputs with the
         quantized linears in ``ori_mode``, then train the selected parameter
         groups so the block's output matches the reference (MSE loss).
         The reference outputs become the inputs of the next block.
      3. Save the trained parameters of each block to
         ``args.exp_dir/s2_parameters_layer_{i}.pth``, merge them into
         ``args.exp_dir/s2_parameters.pth`` and delete the per-layer files.

    Args:
        args: namespace providing ``deactive_amp``, ``nsamples``, ``diag_init``,
            ``diag_alpha``, ``cali_trans``, ``add_diag``, ``lwc``, ``lac``,
            ``lr``, ``warmup``, ``epochs`` and ``exp_dir``.
        model: model exposing ``transformer_blocks`` plus the embedding
            submodules moved by ``set_embed_to_device``.
        dataloader: iterable of dict-like batches passed to ``model(**batch)``.
        dev: CUDA device used for calibration.
        logger: logger used for progress and memory reports.

    Returns:
        None. The model is left on ``dev`` with all parameters frozen.

    Raises:
        RuntimeError: if no calibration sample could be captured.
        NotImplementedError: if ``args.diag_init`` is neither ``"sq_style"``
            nor ``"one_style"``.
    """

    def print_memory_stats(prefix=""):
        """Log current and peak CUDA memory usage of ``dev`` in MB."""
        logger.info(f"{prefix} Memory Stats:")
        logger.info(f"Allocated: {torch.cuda.memory_allocated(dev) / 1024**2:.2f} MB")
        logger.info(f"Reserved/Cached: {torch.cuda.memory_reserved(dev) / 1024**2:.2f} MB")
        logger.info(f"Max Allocated: {torch.cuda.max_memory_allocated(dev) / 1024**2:.2f} MB")
        logger.info(f"Max Reserved: {torch.cuda.max_memory_reserved(dev) / 1024**2:.2f} MB")

    class _InputsCaptured(Exception):
        """Raised by ``Catcher`` to abort the forward pass once inputs are recorded."""

    def to_dev_float(tensor):
        """Return ``tensor`` on ``dev`` as float32 (no copy when it already is)."""
        return tensor.to(dev).float()

    def layer_param_path(index):
        """Path of the per-layer parameter file for block ``index``."""
        return os.path.join(args.exp_dir, f"s2_parameters_layer_{index}.pth")

    print_memory_stats("Initial")

    model.cpu()
    model.eval()

    # Freeze everything; the trainable subsets are re-enabled per layer below.
    for param in model.parameters():
        param.requires_grad = False

    # Mixed precision is used for the training pass unless explicitly disabled.
    if args.deactive_amp:
        dtype = torch.float32
        traincast = nullcontext
    else:
        dtype = torch.bfloat16
        traincast = functools.partial(torch.amp.autocast, device_type="cuda", dtype=dtype)

    layers = model.transformer_blocks
    layers[0] = layers[0].to(dev)
    set_embed_to_device(model, dev)

    # Inputs of the first block, one list entry per captured batch.
    inps = {
        "hidden_states": [],
        "encoder_hidden_states": [],
    }
    cache = {"i": 0,
             "temb": []}

    class Catcher(nn.Module):
        """Stand-in for the first block: records its keyword inputs, then aborts."""

        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, **kwargs):
            hidden_states = kwargs["hidden_states"]
            encoder_hidden_states = kwargs["encoder_hidden_states"]
            if save_gpu_memory:
                hidden_states = hidden_states.to("cpu")
                encoder_hidden_states = encoder_hidden_states.to("cpu")
            inps["hidden_states"].append(hidden_states)
            inps["encoder_hidden_states"].append(encoder_hidden_states)
            cache["i"] += 1
            # Only the most recent rotary embedding is kept and reused for all samples.
            cache["image_rotary_emb"] = kwargs["image_rotary_emb"]
            cache["temb"].append(kwargs["temb"].float())
            # A dedicated exception type keeps genuine ValueErrors raised by the
            # model from being mistaken for this early exit.
            raise _InputsCaptured

    layers[0] = Catcher(layers[0])
    try:
        with torch.no_grad():
            for batch in dataloader:
                if cache["i"] >= args.nsamples:
                    break
                input_dict = {
                    key: value.to(dev) if torch.is_tensor(value) else value
                    for key, value in batch.items()
                }
                try:
                    model(**input_dict)
                except _InputsCaptured:
                    pass
    finally:
        # Always restore the real first block, even if capturing failed.
        layers[0] = layers[0].module

    # Move the embedding layers and the first block back to the CPU.
    layers[0] = layers[0].cpu()
    set_embed_to_device(model, "cpu")
    del dataloader
    torch.cuda.empty_cache()

    num_captured = cache["i"]
    if num_captured == 0:
        raise RuntimeError(
            "cali_s2_quant: no calibration sample was captured "
            f"(args.nsamples={args.nsamples}); check the dataloader."
        )
    nsamples = min(args.nsamples, num_captured)
    if nsamples < args.nsamples:
        logger.warning(
            f"only {nsamples} calibration samples captured, "
            f"fewer than the requested {args.nsamples}; continuing with {nsamples}"
        )
    image_rotary_emb = cache["image_rotary_emb"]

    # The first block gets the captured inputs; afterwards each block gets the
    # full-precision outputs of the previous one.
    fp_inps = inps
    fp_outs = {
        "hidden_states": [],
        "encoder_hidden_states": [],
    }
    # Drop the extra reference so the first-block inputs can be freed once
    # fp_inps is rebound after layer 0.
    del inps
    torch.cuda.empty_cache()
    loss_func = torch.nn.MSELoss()
    num_train_layer = len(layers)

    USE_ATTN_WEIGHT = False

    def weighted_mse_loss(pred, target, weights):
        """Return the mean of the squared error scaled element-wise by ``weights``."""
        return (weights * (pred - target) ** 2).mean()

    for i in range(num_train_layer):
        logger.info(f"========= Layer {i} =========")
        print_memory_stats(f"Before Layer {i}")
        layer = layers[i].to(dev)
        # Remember the original dtypes so they can be restored after training.
        dtype_dict = {name: param.dtype for name, param in layer.named_parameters()}
        with torch.no_grad():
            layer.float()

        # Reference pass: quantized linears run in ori_mode.
        set_linear_to_ori(layer)

        if USE_ATTN_WEIGHT:
            # Collect attention maps with the enhanced processor, then restore
            # the default processor.
            globalvar.clear_attn_maps()
            for name, module in layer.named_modules():
                if "attn" in name and isinstance(module, Attention):
                    module.set_processor(EnhanceCogVideoXAttnProcessor2_0())

            with torch.no_grad():
                for j in range(nsamples):
                    try:
                        _, _ = layer(
                            hidden_states=fp_inps["hidden_states"][j].float(),
                            encoder_hidden_states=fp_inps["encoder_hidden_states"][j].float(),
                            temb=cache["temb"][j],
                            image_rotary_emb=image_rotary_emb,
                        )
                    except Exception as e:
                        # Best effort: keep going, but do not hide the failure.
                        logger.warning(
                            f"layer {i} sample {j}: attention-map collection failed: {e!r}"
                        )
                        continue

            for name, module in layer.named_modules():
                if "attn" in name and isinstance(module, Attention):
                    module.set_processor(CogVideoXAttnProcessor2_0())

        with torch.no_grad():
            for j in range(nsamples):
                fp_hidden_states, fp_encoder_hidden_states = layer(
                    hidden_states=to_dev_float(fp_inps["hidden_states"][j]),
                    encoder_hidden_states=to_dev_float(fp_inps["encoder_hidden_states"][j]),
                    temb=cache["temb"][j],
                    image_rotary_emb=image_rotary_emb,
                )
                if save_gpu_memory:
                    fp_outs["hidden_states"].append(fp_hidden_states.to("cpu"))
                    fp_outs["encoder_hidden_states"].append(fp_encoder_hidden_states.to("cpu"))
                    del fp_hidden_states, fp_encoder_hidden_states
                    torch.cuda.empty_cache()
                else:
                    fp_outs["hidden_states"].append(fp_hidden_states)
                    fp_outs["encoder_hidden_states"].append(fp_encoder_hidden_states)

        set_linear_to_normal(layer)
        if args.diag_init == "sq_style":
            set_linear_diag_init(layer, args.diag_alpha)
        elif args.diag_init == "one_style":
            pass
        else:
            raise NotImplementedError

        # Re-sync the device after diag_scale may have been re-assigned.
        layer = layer.to(dev)
        set_require_grad_all(layer, False)
        trained_params, paras_name = [], []
        if args.cali_trans:
            trained_params.append({"params": get_n_set_parameters_byname(layer, ["trans.linear", ]), "lr": args.lr})
            paras_name.append("trans.linear")
        if args.add_diag:
            trained_params.append({"params": get_n_set_parameters_byname(layer, ["trans.diag_scale", ]), "lr": args.lr})
            paras_name.append("trans.diag_scale")
        if args.lwc:
            trained_params.append({"params": get_n_set_parameters_byname(layer, ["clip_factor_w", ]), "lr": args.lr * 10})
            paras_name.append("clip_factor_w")
        if args.lac:
            trained_params.append({"params": get_n_set_parameters_byname(layer, ["clip_factor_a", ]), "lr": args.lr * 10})
            paras_name.append("clip_factor_a")

        optimizer = torch.optim.AdamW(trained_params)
        # One scheduler step is taken per sample, hence T_max = epochs * nsamples.
        scheduler_main = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=args.epochs * nsamples, eta_min=args.lr * 1e-3
        )
        if args.warmup:
            scheduler_warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, total_iters=16)
            scheduler = torch.optim.lr_scheduler.ChainedScheduler([scheduler_warmup, scheduler_main])
        else:
            scheduler = scheduler_main

        for epoch in range(args.epochs):
            if epoch == 0:
                print_memory_stats(f"Layer {i} Epoch {epoch} Start")
            mse = 0
            start_tick = time.time()
            with traincast():
                for j in range(nsamples):
                    # Inputs are cast to float32 exactly as in the reference pass
                    # above, so both passes see the same inputs and the float32
                    # layer also works when AMP is disabled. The .to(dev) calls
                    # are no-ops unless save_gpu_memory keeps the data on the CPU.
                    cur_hidden_states = to_dev_float(fp_inps["hidden_states"][j])
                    cur_encoder_hidden_states = to_dev_float(fp_inps["encoder_hidden_states"][j])
                    cur_fp_hidden_states = fp_outs["hidden_states"][j].to(dev)
                    cur_fp_encoder_hidden_states = fp_outs["encoder_hidden_states"][j].to(dev)

                    quant_hidden_states, quant_encoder_hidden_states = layer(
                        hidden_states=cur_hidden_states,
                        encoder_hidden_states=cur_encoder_hidden_states,
                        temb=cache["temb"][j],
                        image_rotary_emb=image_rotary_emb,
                    )

                    # The attention-weighted loss is only used on the
                    # non-offloading path, as before.
                    if USE_ATTN_WEIGHT and not save_gpu_memory:
                        hidden_loss = weighted_mse_loss(
                            cur_fp_hidden_states, quant_hidden_states, globalvar.get_attn_maps(j)
                        )
                    else:
                        hidden_loss = loss_func(cur_fp_hidden_states, quant_hidden_states)
                    loss = 0.5 * hidden_loss \
                        + 0.5 * loss_func(cur_fp_encoder_hidden_states, quant_encoder_hidden_states)
                    mse += loss.detach().cpu()

                    optimizer.zero_grad()
                    loss.backward()

                    optimizer.step()
                    scheduler.step()

                    if save_gpu_memory:
                        # Drop the per-sample GPU copies and release cached memory.
                        del cur_hidden_states, cur_encoder_hidden_states
                        del cur_fp_hidden_states, cur_fp_encoder_hidden_states
                        del quant_hidden_states, quant_encoder_hidden_states
                        torch.cuda.empty_cache()

            cur_lr = optimizer.param_groups[0]['lr']
            logger.info(f"layer {i} lwc lac iter {epoch}, lr {cur_lr:.8f}  time {time.time() - start_tick:.6f}s, mse: {mse:.8f}" )

        # The reference outputs of this block are the inputs of the next one.
        fp_inps = fp_outs
        fp_outs = {
            "hidden_states": [],
            "encoder_hidden_states": [],
        }
        layers[i] = layer.to("cpu")
        layer_params = get_paras_dict_by_name(layer, required_names=paras_name)
        torch.save(layer_params, layer_param_path(i))
        logger.info(f"saved parameters for layer {i} at {layer_param_path(i)}")
        del layer_params  # release the copy right away
        # Freeze the block again and restore the dtypes it had before training.
        for name, param in layer.named_parameters():
            param.requires_grad = False
            if name in dtype_dict:
                param.data = param.to(dtype_dict[name])
        del layer
        torch.cuda.empty_cache()
        print_memory_stats(f"After Layer {i}")

    del fp_inps, fp_outs
    gc.collect()
    torch.cuda.empty_cache()

    print_memory_stats("Final")
    model.to(dev)

    # Merge the per-layer parameter files into a single checkpoint.
    s2_parameters = {}
    for i in range(num_train_layer):
        s2_parameters[i] = torch.load(layer_param_path(i))
    torch.save(s2_parameters, os.path.join(args.exp_dir, f"s2_parameters.pth"))
    logger.info(f"saved merged parameters at {os.path.join(args.exp_dir, 's2_parameters.pth')}")

    # Remove the per-layer files now that they are merged.
    for i in range(num_train_layer):
        path = layer_param_path(i)
        if os.path.exists(path):
            os.remove(path)
            logger.info(f"removed layer parameter file: {path}")

    return

