import argparse
import json
import os
import random
import numpy as np

os.environ["TORCH_LOGS"] = "+dynamo,output_code,graph_breaks,recompiles"

import torch
import torch.nn.functional as F
from diffusers import CogVideoXPipeline, CogVideoXDDIMScheduler
from diffusers.utils import export_to_video
from torchao.quantization import (
    autoquant,
    quantize_,
    int8_weight_only,
    int8_dynamic_activation_int8_weight,
    int8_dynamic_activation_int4_weight,
    int8_dynamic_activation_int8_semi_sparse_weight,
    int4_weight_only,
    float8_dynamic_activation_float8_weight,
    float8_weight_only,
    fpx_weight_only,
)
from torchao.quantization.quant_api import PerRow
from torchao.sparsity import sparsify_
from torchao.quantization.smoothquant import swap_linear_with_smooth_fq_linear, smooth_fq_linear_to_inference

from utils import cleanup_tmp_directory, benchmark_fn, pretty_print_results, print_memory, reset_memory

from qdiff.s2quant.utils import s2quant_model, set_ignore_quantize
from qdiff.s2quant.utils import quantize_linear as falt_quantize_linear
from qdiff.s2quant.args_utils import get_config
# Let float32 matrix multiplications use reduced-precision kernels (the "high" setting, e.g. TF32).
# This trades a small amount of precision for speed on NVIDIA GPUs with Ampere architecture (e.g., A100, RTX 30 series) or newer.
torch.set_float32_matmul_precision("high")


# Maps a dtype/quantization name to a callable that converts a module.
# The "fp16"/"bf16" entries return the result of ``module.to(...)``; the
# ``quantize_``/``sparsify_`` entries return whatever those calls return;
# "smoothquant" is an identity placeholder.
CONVERT_DTYPE = {
    "fp16": lambda module: module.to(dtype=torch.float16),
    "bf16": lambda module: module.to(dtype=torch.bfloat16),
    "fp8wo": lambda module: quantize_(module, float8_weight_only()),
    "fp8dq": lambda module: quantize_(module, float8_dynamic_activation_float8_weight()),
    "fp8dqrow": lambda module: quantize_(module, float8_dynamic_activation_float8_weight(granularity=PerRow())),
    "fp6_e3m2": lambda module: quantize_(module, fpx_weight_only(3, 2)),
    "fp5_e2m2": lambda module: quantize_(module, fpx_weight_only(2, 2)),
    "fp4_e2m1": lambda module: quantize_(module, fpx_weight_only(2, 1)),
    "int8wo": lambda module: quantize_(module, int8_weight_only()),
    "int8dq": lambda module: quantize_(module, int8_dynamic_activation_int8_weight()),
    "int4dq": lambda module: quantize_(module, int8_dynamic_activation_int4_weight()),
    "int4wo": lambda module: quantize_(module, int4_weight_only()),
    "autoquant": lambda module: autoquant(module, error_on_unseen=False),
    "sparsify": lambda module: sparsify_(module, int8_dynamic_activation_int8_semi_sparse_weight()),
    "smoothquant": lambda module: module,  # the actual conversion happens later, during calibration
    # ``quantize_linear`` is imported under the alias ``falt_quantize_linear``;
    # the bare name is not bound in this module and would raise NameError.
    # NOTE(review): both entries make the same call — confirm whether the
    # w4a8 / w4a6 variants are meant to pass different bit widths.
    "w4a8_smoothquant": lambda module: falt_quantize_linear(module, device=module.device),
    "w4a6_smoothquant": lambda module: falt_quantize_linear(module, device=module.device),
}


def set_seed(seed=3047):
    """Seed every random number generator used by this script.

    Seeds Python's ``random`` module, NumPy, and PyTorch (the default
    generator and all CUDA devices), then puts cuDNN into deterministic
    mode with benchmarking disabled.

    Args:
        seed: Integer seed handed to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def _run_calibration_pipe(pipe, prompt, guidance_scale, num_inference_steps):
    """Run one fixed-seed generation for ``prompt`` under ``torch.no_grad()``."""
    with torch.no_grad():
        return pipe(
            prompt=prompt,
            guidance_scale=guidance_scale,
            use_dynamic_cfg=True,
            num_inference_steps=num_inference_steps,
            generator=torch.Generator().manual_seed(3047),
        )


def _collect_last_block_outputs(pipe, prompt, guidance_scale, num_inference_steps):
    """Return the detached outputs of the last transformer block, one pair per call."""
    block_outputs = []
    last_block = pipe.transformer.transformer_blocks[-1]
    original_forward = last_block.forward

    def recording_forward(*args, **kwargs):
        output = original_forward(*args, **kwargs)
        block_outputs.append((output[0].detach(), output[1].detach()))
        return output

    last_block.forward = recording_forward
    try:
        _run_calibration_pipe(pipe, prompt, guidance_scale, num_inference_steps)
    finally:
        # Always undo the hook, even if the pipeline raised.
        last_block.forward = original_forward
    return block_outputs


def _step_saliency_and_hessian(curr, prev):
    """Score one step: change versus the previous step, and an autocorrelation statistic."""
    curr_hidden, curr_encoder_hidden = curr
    curr_hidden_norm = F.normalize(curr_hidden, p=2, dim=-1)
    curr_encoder_hidden_norm = F.normalize(curr_encoder_hidden, p=2, dim=-1)

    if prev is None:
        # The first step has no predecessor, so its saliency is defined as 0.
        hidden_salient = torch.tensor(0.0, device=curr_hidden.device)
        encoder_salient = torch.tensor(0.0, device=curr_hidden.device)
    else:
        pre_hidden, pre_encoder_hidden = prev
        pre_hidden_norm = F.normalize(pre_hidden, p=2, dim=-1)
        pre_encoder_hidden_norm = F.normalize(pre_encoder_hidden, p=2, dim=-1)
        hidden_salient = torch.mean((curr_hidden_norm - pre_hidden_norm) ** 2)
        encoder_salient = torch.mean((curr_encoder_hidden_norm - pre_encoder_hidden_norm) ** 2)

    # NOTE(review): reshape(dim, -1) reinterprets the memory layout rather than
    # transposing the feature axis to the front; kept as-is so the selected
    # samples do not change — confirm whether a transpose was intended.
    _, _, dim = curr_hidden_norm.shape
    curr_hidden_flat = curr_hidden_norm.reshape(dim, -1)
    curr_encoder_flat = curr_encoder_hidden_norm.reshape(dim, -1)

    hidden_corr = curr_hidden_flat @ curr_hidden_flat.t()
    encoder_corr = curr_encoder_flat @ curr_encoder_flat.t()

    hidden_hessian = torch.mean(F.normalize(hidden_corr, p=2, dim=1))
    encoder_hessian = torch.mean(F.normalize(encoder_corr, p=2, dim=1))

    return hidden_salient + encoder_salient, hidden_hessian + encoder_hessian


def _min_max_normalize(values):
    """Min-max normalize a 1-D tensor; a zero range is clamped to avoid dividing by zero."""
    low, high = torch.min(values), torch.max(values)
    return (values - low) / torch.clamp(high - low, min=1e-12)


def _capture_selected_inputs(pipe, prompts, selected_indices, guidance_scale, num_inference_steps):
    """Re-run all prompts and copy to CPU the transformer kwargs of the selected calls."""
    wanted = set(selected_indices)
    captured = []
    call_index = 0
    original_forward = pipe.transformer.forward

    def capturing_forward(*args, **kwargs):
        nonlocal call_index
        if call_index in wanted:
            # Only keyword arguments are recorded; positional arguments are ignored.
            captured.append(
                {key: value.detach().cpu() if torch.is_tensor(value) else value for key, value in kwargs.items()}
            )
        call_index += 1
        return original_forward(*args, **kwargs)

    pipe.transformer.forward = capturing_forward
    try:
        for prompt in prompts:
            _run_calibration_pipe(pipe, prompt, guidance_scale, num_inference_steps)
    finally:
        # Always undo the hook, even if the pipeline raised.
        pipe.transformer.forward = original_forward
    return captured


def get_calibration_data(pipe, num_samples=5, cache_path=None, top_k=10):
    """Select and capture transformer inputs to use as calibration samples.

    The function makes two passes over a fixed list of prompts:

    1. For every prompt it records the output of the last transformer block at
       each transformer call and scores the call with a saliency term (change
       relative to the previous call) multiplied by an autocorrelation term,
       both min-max normalized over all prompts and steps.
    2. It draws random step indices until ``top_k`` calls are chosen, each time
       taking the highest-scoring remaining prompt for the drawn step, then
       re-runs the prompts and stores the keyword arguments of the chosen
       transformer calls (tensors moved to CPU).

    The code indexes recorded outputs by denoising step, i.e. it assumes one
    transformer call per step.

    Args:
        pipe: Pipeline object; must be callable and expose
            ``transformer.forward`` and ``transformer.transformer_blocks``.
        num_samples: Unused; kept for backward compatibility.
        cache_path: Optional file path. If it exists, its content is loaded
            with ``torch.load`` and returned; otherwise the result is saved
            there.
        top_k: Number of transformer calls to keep. Capped at the number of
            scored calls so the selection loop always terminates.

    Returns:
        A shuffled list of dicts holding the captured transformer keyword
        arguments, or the cached object when ``cache_path`` exists.
    """
    if cache_path and os.path.exists(cache_path):
        print(f"从缓存加载校准数据: {cache_path}")
        return torch.load(cache_path)

    # The whitespace after each line continuation is part of the prompt text
    # and is preserved so previously generated calibration data stays reproducible.
    calibration_prompts = [
        "A panda playing guitar in a bamboo forest",
        "A spaceship flying through colorful nebula",
        "A medieval castle on a cloudy day",
        "A cat wearing sunglasses on a beach",
        "A robot painting on an easel in a studio",
        "A cheerful panda wearing a traditional Chinese vest, sitting on a bamboo stool while skillfully \
        playing an acoustic guitar in a misty bamboo forest at sunrise, with golden light filtering through the leaves",
        "A sleek, metallic spaceship with glowing blue engines soaring through a vibrant nebula \
        filled with swirling purple, pink and turquoise cosmic clouds, distant stars twinkling in the background",
        "A majestic stone medieval castle perched on a rocky cliff, with towering spires and flying buttresses, \
        surrounded by swirling storm clouds and occasional lightning strikes, ancient flags waving in the wind",
        "A sophisticated Siamese cat lounging on a striped beach chair, wearing heart-shaped designer sunglasses \
        and a tiny straw hat, with crystal clear turquoise waves and white sandy beach in the background",
        "A futuristic robot with copper and chrome finish, delicately holding a paintbrush in its articulated fingers, \
        creating an impressionist masterpiece on a wooden easel in a sun-lit artist's studio filled with art supplies"
    ]

    guidance_scale = 6
    num_inference_steps = 50

    # Pass 1: score every (prompt, step) pair. Flat index = prompt_idx * steps + step.
    salient_list = []
    hessian_list = []
    for prompt in calibration_prompts:
        block_outputs = _collect_last_block_outputs(pipe, prompt, guidance_scale, num_inference_steps)
        for step in range(num_inference_steps):
            prev = block_outputs[step - 1] if step > 0 else None
            salient, hessian = _step_saliency_and_hessian(block_outputs[step], prev)
            salient_list.append(salient)
            hessian_list.append(hessian)

        # Release this prompt's activations before running the next one.
        del block_outputs
        torch.cuda.empty_cache()

    norm_salient = _min_max_normalize(torch.stack(salient_list))
    norm_hessian = _min_max_normalize(torch.stack(hessian_list))
    scores = (-torch.log(1 - norm_salient * norm_hessian + 1e-8)).tolist()

    # Bucket candidates by denoising step so sampling is spread over steps.
    candidates_by_step = {step: [] for step in range(num_inference_steps)}
    for flat_index, score in enumerate(scores):
        candidates_by_step[flat_index % num_inference_steps].append((flat_index, score))

    target = min(top_k, len(scores))
    if target < top_k:
        print(f"top_k={top_k} exceeds the {len(scores)} available samples; selecting {target}.")

    selected_indices = []
    while len(selected_indices) < target:
        step = np.random.choice(num_inference_steps)
        bucket = candidates_by_step[step]
        if bucket:
            # Take the highest-scoring remaining sample of the drawn step.
            best_sample = max(bucket, key=lambda item: item[1])
            selected_indices.append(best_sample[0])
            bucket.remove(best_sample)
    selected_indices.sort()

    # Pass 2: replay the prompts and keep only the inputs of the selected calls.
    selected_inputs = _capture_selected_inputs(
        pipe, calibration_prompts, selected_indices, guidance_scale, num_inference_steps
    )
    random.shuffle(selected_inputs)

    if cache_path:
        print(f"保存校准数据到: {cache_path}")
        cache_dir = os.path.dirname(cache_path)
        if cache_dir:  # a bare filename has no directory to create
            os.makedirs(cache_dir, exist_ok=True)
        torch.save(selected_inputs, cache_path)

    return selected_inputs

def load_pipeline(model_id, dtype, device, quantize_vae, compile, fuse_qkv, resume_s2=False, use_gptq=False, resume_gptq=False, exp_name=None):
    """Load the CogVideoX pipeline and, for s2quant dtypes, quantize its transformer.

    Args:
        model_id: Hub id or local path passed to ``CogVideoXPipeline.from_pretrained``.
        dtype: Quantization mode. Only ``"s2quant_w4a4"`` and ``"s2quant_w4a6"``
            trigger calibration and quantization here; any other value leaves
            the bfloat16 pipeline untouched.
        device: Device the pipeline is moved to.
        quantize_vae: Accepted for interface compatibility; not used in this function.
        compile: If true, ``torch.compile`` the transformer.
        fuse_qkv: If true, fuse the QKV projections before anything else.
        resume_s2, use_gptq, resume_gptq: Forwarded to ``s2quant_model``.
        exp_name: Experiment name forwarded to the config and to ``s2quant_model``.

    Returns:
        Tuple ``(pipe, exp_dir)`` where ``exp_dir`` is ``config.exp_dir``.
    """
    # 1. Load the pipeline in bfloat16 and swap in a DDIM scheduler.
    pipe = CogVideoXPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
    pipe = pipe.to(device)
    pipe.scheduler = CogVideoXDDIMScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
    # pipe.set_progress_bar_config(disable=True)
    pipe.vae.enable_tiling()
    pipe.vae.enable_slicing()
    if fuse_qkv:
        pipe.fuse_qkv_projections()

    # 2. The experiment config is built for every dtype because its exp_dir is returned.
    #    Weights are always 4-bit; activations are 6-bit only for "s2quant_w4a6".
    wbit = 4
    abit = 6 if dtype == "s2quant_w4a6" else 4
    config = get_config()
    config.update_from_args(wbit=wbit, abit=abit, model_id=model_id, exp_name=exp_name)

    # 3. s2quant modes need calibration data before the transformer is quantized.
    if dtype in ("s2quant_w4a4", "s2quant_w4a6"):
        calib_cache = f"{config.exp_dir}/{model_id.replace('/', '_')}_calib_data.pt"
        calib_data = get_calibration_data(pipe, cache_path=calib_cache, top_k=40)
        # The pipeline is moved to the CPU while s2quant_model runs, then moved back.
        pipe.to("cpu")
        s2quant_model(
            pipe.transformer,
            calib_data,
            wbit=wbit,
            abit=abit,
            resume_s2=resume_s2,
            use_gptq=use_gptq,
            resume_gptq=resume_gptq,
            model_id=model_id,
            exp_name=exp_name,
        )
        pipe.to(device)

    if compile:
        pipe.transformer.to(memory_format=torch.channels_last)
        pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
        # VAE cannot be compiled due to: https://gist.github.com/a-r-r-o-w/5183d75e452a368fd17448fcc810bd3f#file-test_cogvideox_torch_compile-py-L30

    return pipe, config.exp_dir


def run_inference(pipe, prompt):
    """Call ``pipe`` once for ``prompt`` with the script's fixed sampling settings.

    Args:
        pipe: Callable pipeline accepting the keyword arguments built below.
        prompt: Prompt value forwarded unchanged to the pipeline.

    Returns:
        Whatever ``pipe`` returns.
    """
    call_kwargs = dict(
        prompt=prompt,
        guidance_scale=6,
        use_dynamic_cfg=True,
        num_inference_steps=50,
        # Fixed seed; see https://arxiv.org/abs/2109.08203
        generator=torch.Generator().manual_seed(3047),
    )
    with torch.no_grad():
        return pipe(**call_kwargs)


def _video_filename(prompt, suffix="-0.mp4", max_bytes=255):
    """Build ``<prompt><suffix>``, made safe to use as a single path component.

    Path separators are replaced with ``_`` and the prompt part is truncated so
    the UTF-8 encoded name fits in ``max_bytes``. Prompts that are already safe
    keep their exact ``{prompt}-0.mp4`` name.
    """
    stem = prompt.replace(os.sep, "_")
    if os.altsep:
        stem = stem.replace(os.altsep, "_")
    budget = max_bytes - len(suffix.encode("utf-8"))
    encoded = stem.encode("utf-8")
    if len(encoded) > budget:
        # errors="ignore" drops a multi-byte character cut in half by the slice.
        stem = encoded[:budget].decode("utf-8", errors="ignore")
    return stem + suffix


def main(model_id, dtype, device, quantize_vae, compile, fuse_qkv, resume_s2=False, use_gptq=False, resume_gptq=False, exp_name=None, prompts_file=None):
    """Load the pipeline, then generate and export one video per prompt.

    Args:
        model_id, dtype, device, quantize_vae, compile, fuse_qkv, resume_s2,
        use_gptq, resume_gptq, exp_name: Forwarded to ``load_pipeline``.
        prompts_file: Optional text file with one prompt per line (blank lines
            are skipped). Videos are written to a sub-folder of the experiment
            directory named after the file. If it is missing or not given, a
            single default prompt is used and the video is written to the
            experiment directory itself.
    """
    set_seed(3047)
    reset_memory(device)

    # 1. Load pipeline; the second value is the experiment output directory.
    pipe, exp_dir = load_pipeline(model_id, dtype, device, quantize_vae, compile, fuse_qkv, resume_s2, use_gptq, resume_gptq, exp_name)

    print_memory(device)
    torch.cuda.empty_cache()

    # 2. Resolve prompts and the output folder.
    if prompts_file and os.path.exists(prompts_file):
        with open(prompts_file, "r", encoding="utf-8") as f:
            prompts = [text for text in (line.strip() for line in f) if text]
        # Sub-folder named after the prompts file, without its extension.
        prompt_name = os.path.splitext(os.path.basename(prompts_file))[0]
        prompts_folder = os.path.join(exp_dir, prompt_name)
    else:
        if prompts_file:
            print(f"Prompts file not found: {prompts_file}; falling back to the default prompt.")
        prompts = ["A cat wearing sunglasses on a beach."]  # default prompt
        prompts_folder = exp_dir
    os.makedirs(prompts_folder, exist_ok=True)

    # 3. Generate and export one video per prompt.
    for prompt in prompts:
        video = run_inference(pipe, [prompt])
        video_path = os.path.join(prompts_folder, _video_filename(prompt))
        export_to_video(video.frames[0], video_path, fps=8)


def get_args(argv=None):
    """Parse the command-line options of this script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before.

    Returns:
        The ``argparse.Namespace`` with the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_id",
        type=str,
        default="THUDM/CogVideoX-5b",
        # choices=["THUDM/CogVideoX-2b", "THUDM/CogVideoX-5b"],
        help="Hub model or path to local model for which the benchmark is to be run.",
    )
    parser.add_argument(
        "--exp_name",
        type=str,
        default="exp_test",
        help="Experiment name.",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="bf16",
        choices=[
            "fp16",
            "bf16",
            "fp8wo",
            "fp8dq",
            "fp8dqrow",
            "fp6_e3m2",
            "fp5_e2m2",
            "fp4_e2m1",
            "int8wo",
            "int8dq",
            "int4dq",
            "int4wo",
            "autoquant",
            "sparsify",
            # Both s2quant modes handled by load_pipeline must be selectable;
            # "s2quant_w4a6" used to be listed twice and "s2quant_w4a4" was missing.
            "s2quant_w4a4",
            "s2quant_w4a6",
        ],
        help="Inference or Quantization type.",
    )
    parser.add_argument("--device", type=str, default="cuda", help="Device to run inference on.")
    parser.add_argument(
        "--resume_s2",
        action="store_true",
        default=False,
        help="Resume from checkpoint.",
    )
    parser.add_argument(
        "--use_gptq",
        action="store_true",
        default=False,
        help="Use GPTQ.",
    )
    parser.add_argument(
        "--resume_gptq",
        action="store_true",
        default=False,
        help="Resume gptq from checkpoint.",
    )
    parser.add_argument(
        "--quantize_vae",
        action="store_true",
        default=False,
        help="Whether or not to quantize the CogVideoX VAE. Can lead to worse decoding results in some quantization cases.",
    )
    parser.add_argument(
        "--compile",
        action="store_true",
        default=False,
        help="Whether or not to torch.compile the models. For our experiments with CogVideoX, we only compile the transformer.",
    )
    parser.add_argument(
        "--fuse_qkv",
        action="store_true",
        default=False,
        help="Whether or not to fuse the QKV projection layers into one larger layer.",
    )
    parser.add_argument(
        "--prompts_file",
        type=str,
        default=None,
        help="Path to a text file containing prompts (one per line).",
    )
    return parser.parse_args(argv)


if __name__ == "__main__":
    # Script entry point: parse the CLI options, run the generation, then
    # call the temporary-directory cleanup helper.
    cli = get_args()
    main(
        model_id=cli.model_id,
        dtype=cli.dtype,
        device=cli.device,
        quantize_vae=cli.quantize_vae,
        compile=cli.compile,
        fuse_qkv=cli.fuse_qkv,
        resume_s2=cli.resume_s2,
        use_gptq=cli.use_gptq,
        resume_gptq=cli.resume_gptq,
        exp_name=cli.exp_name,
        prompts_file=cli.prompts_file,
    )
    cleanup_tmp_directory()