import logging
import random
import os

import torch
import torch.nn.functional as F
from itertools import product

from qdit.qLinearLayer import QLinearLayer

logger = logging.getLogger(__name__)


class SaveActivationHook:
    """
    Forward hook that stores a compact summary of a layer's input.

    For every forward call with a 3-D input, the hook gathers the rows listed
    in each entry of ``samples_index_list``, takes the absolute maximum over
    the token axis and then over the gathered samples, and appends the
    resulting 2-D tensor to ``self.outputs``. Inputs that are not 3-D
    contribute an empty list as a placeholder.
    """

    def __init__(self, samples_index_list):
        # One list of first-dimension row indices per sample.
        self.samples_index_list = samples_index_list
        # One entry is appended per forward call.
        self.outputs = []
        # Filled in by whoever registers the hook, so it can be removed later.
        self.hook_handle = None

    def __call__(self, module, module_in, module_out):
        """
        Record the per-channel absolute maximum of the first positional input.

        Only the first element of ``module_in`` is inspected; ``module`` and
        ``module_out`` are ignored. Returns ``None``.
        """
        layer_input = module_in[0]

        if layer_input.dim() != 3:
            # Non 3-D inputs are not summarised; keep an empty placeholder.
            self.outputs.append([])
            return

        # Gather the rows of every sample: one [T, N, C] slab per sample,
        # stacked into [B, T, N, C] (T presumably indexes timesteps -- confirm
        # against the caller that builds samples_index_list).
        per_sample = [layer_input[rows] for rows in self.samples_index_list]
        stacked = torch.stack(per_sample)

        # Absolute maximum over the token axis -> [B, T, C].
        token_peak = torch.max(torch.abs(stacked), dim=2).values

        # Maximum over the sample axis -> [T, C].
        sample_peak = torch.max(token_peak, dim=0).values

        self.outputs.append(sample_peak)

    def clear(self):
        """Drop everything recorded so far."""
        self.outputs = []


def apply_func_to_submodules(module, class_type, function, parent_name="", return_d=None, **kwargs):
    """
    Walk the module tree depth-first and call ``function`` on matching submodules.

    Args:
        module (torch.nn.Module): Root of the (sub)tree to traverse.
        class_type (type): Submodules that are instances of this type are passed
            to ``function``.
        function (callable): Called as ``function(submodule, **kwargs)``.
        parent_name (str): Dotted name of ``module`` within the overall model;
            used to build each submodule's full name during recursion.
        return_d (dict | None): When not ``None``, the result of every call is
            stored under the submodule's dotted full name.
        **kwargs: Forwarded to ``function``. If the caller supplies any of the
            keys ``name``, ``full_name`` or ``parent_module``, their values are
            overwritten with the current child's name, dotted full name and
            direct parent before each call.

    Returns:
        ``return_d`` when it was provided, otherwise ``None``.
    """
    collecting = return_d is not None

    for child_name, child in module.named_children():
        qualified_name = child_name if not parent_name else f"{parent_name}.{child_name}"

        # Refresh only the context slots the caller opted into.
        context = {
            'name': child_name,
            'full_name': qualified_name,
            'parent_module': module,
        }
        for key, value in context.items():
            if key in kwargs:
                kwargs[key] = value

        if isinstance(child, class_type):
            result = function(child, **kwargs)
            if collecting:
                return_d[qualified_name] = result

        # Descend regardless of whether the child itself matched.
        apply_func_to_submodules(child, class_type, function, qualified_name, return_d, **kwargs)

    return return_d if collecting else None
    

def add_hook_to_module_(module, hook_cls, samples_index_list):
    """
    Instantiate ``hook_cls`` and register it as a forward hook on ``module``.

    The handle returned by ``register_forward_hook`` is stored on the hook
    instance as ``hook_handle`` so the caller can remove the hook later.

    Returns:
        The newly created hook instance.
    """
    hook_instance = hook_cls(samples_index_list)
    handle = module.register_forward_hook(hook_instance)
    hook_instance.hook_handle = handle
    return hook_instance


def get_train_samples(calib_n, sample_data):
    """
    Build flat calibration tensors from per-step sampling records.

    For every sampling step stored in ``sample_data`` the first ``calib_n``
    entries of ``"xs"``, ``"ts"`` and ``"y"`` are kept, and the per-step
    slices are concatenated along dimension 0.

    Args:
        calib_n (int): Number of leading samples to keep from each step.
        sample_data (dict): Holds the indexable sequences ``"xs"``, ``"ts"``
            and ``"y"``; the number of steps is taken from ``len(sample_data["ts"])``.

    Returns:
        tuple: ``(xs, ts, ys)`` concatenated tensors.
    """
    # Every recorded sampling step is used (stride 1).
    nsteps = len(sample_data["ts"])
    step_ids = range(nsteps)

    logger.info(f'Selected {len(step_ids)} steps from {nsteps} sampling steps')

    def gather(key):
        # First calib_n entries of each step, joined along the batch axis.
        return torch.cat([sample_data[key][step][:calib_n] for step in step_ids], dim=0)

    xs = gather("xs")
    ts = gather("ts")
    ys = gather("y")

    return xs, ts, ys


# def group_layer_activations(X, n_group):
#     """
#     Args:
#         X: Tensor of shape [T, C]
#         n_group: int, number of groups to form

#     Returns:
#         partition: list of integers indicating the exclusive end index of each group
#     """
#     T, C = X.shape
#     assert n_group <= T, "n_group must be <= T"

#     # Step 1: Softmax normalization over channel dimension
#     X = F.softmax(X, dim=1)  # shape: [T, C]
#     # X = X / (X.norm(p=2, dim=1, keepdim=True) + 1e-8)  # L2-normalize each row to unit norm
#     # tau = 0.2  # temperature: smaller -> sharper distribution, larger -> flatter
#     # X = F.softmax(X / tau, dim=1)

#     # Step 2: Each time step is a group initially
#     groups = [[t] for t in range(T)]

#     # Helper function: compute pairwise symmetric KL divergence between two groups
#     def pairwise_symmetric_kl(g1, g2):
#         p_set = X[g1]  # shape: [len(g1), C]
#         q_set = X[g2]  # shape: [len(g2), C]
#         total_kl = 0.0
#         count = 0
#         for i, j in product(range(len(p_set)), range(len(q_set))):
#             p = p_set[i]
#             q = q_set[j]
#             kl1 = F.kl_div(p.log(), q, reduction='sum')
#             kl2 = F.kl_div(q.log(), p, reduction='sum')
#             total_kl += (kl1 + kl2) / 2
#             count += 1
#         return total_kl / count if count > 0 else torch.tensor(0.0, device=X.device)

#     # Step 3: Iteratively merge the pair with minimum pairwise symmetric KL
#     while len(groups) > n_group:
#         min_kl = float('inf')
#         merge_idx = -1

#         for i in range(len(groups) - 1):
#             kl = pairwise_symmetric_kl(groups[i], groups[i+1])
#             if kl < min_kl:
#                 min_kl = kl
#                 merge_idx = i

#         # Merge selected adjacent groups
#         groups[merge_idx] = groups[merge_idx] + groups[merge_idx + 1]
#         del groups[merge_idx + 1]

#     # Step 4: Construct partition list (exclusive end indices of groups except last)
#     # partition = [max(g) for g in groups[:-1]]
#     # return partition

#     return groups


def group_layer_activations(X, n_group, max_group_size=6):
    """
    Greedily merge adjacent timesteps into ``n_group`` contiguous groups.

    Every row of ``X`` is softmax-normalised over dimension 1. Starting from
    one group per row, the adjacent pair of groups with the smallest average
    pairwise symmetric KL divergence is merged repeatedly:

    1. While some adjacent pair fits within ``max_group_size`` (and has a KL
       below infinity), only such pairs are candidates.
    2. Once no such pair is left, the size limit is ignored for all remaining
       merges until ``n_group`` groups remain.

    If no candidate has a KL strictly below infinity (all values are inf or
    NaN, which happens when the softmax underflows to exact zeros), the
    adjacent pair with the smallest combined size is merged instead, so the
    returned groups always stay contiguous and ordered.

    Args:
        X: Tensor of shape [T, C].
        n_group: int, number of groups to form (1 <= n_group <= T).
        max_group_size: int, maximum allowed size of each group (soft constraint).

    Returns:
        groups: list of lists, each inner list contains the row indices
        (timestep positions) belonging to that group, in increasing order.

    Raises:
        AssertionError: if ``n_group`` is outside [1, T] or ``max_group_size < 1``.
    """
    T, _ = X.shape
    assert n_group <= T, "n_group must be <= T"
    assert n_group >= 1, "n_group must be >= 1"
    assert max_group_size >= 1, "max_group_size must be >= 1"

    # Step 1: softmax normalisation over the channel dimension.
    probs = F.softmax(X, dim=1)

    # Step 2: start with each time step as its own group.
    groups = [[t] for t in range(T)]

    # The divergence between two rows never changes, so compute it only once.
    pair_cache = {}

    def pair_symmetric_kl(a, b):
        key = (a, b)
        if key not in pair_cache:
            p = probs[a]
            q = probs[b]
            kl1 = F.kl_div(p.log(), q, reduction='sum')
            kl2 = F.kl_div(q.log(), p, reduction='sum')
            pair_cache[key] = (kl1 + kl2) / 2
        return pair_cache[key]

    def group_symmetric_kl(g1, g2):
        # Average symmetric KL over every (row in g1, row in g2) combination.
        total_kl = 0.0
        for a, b in product(g1, g2):
            total_kl = total_kl + pair_symmetric_kl(a, b)
        return total_kl / (len(g1) * len(g2))

    def best_adjacent_pair(size_limit):
        # Index of the left group of the best adjacent pair, or -1 if none qualifies.
        best_idx = -1
        best_kl = float('inf')
        for i in range(len(groups) - 1):
            if size_limit is not None and len(groups[i]) + len(groups[i + 1]) > size_limit:
                continue
            kl = group_symmetric_kl(groups[i], groups[i + 1])
            if kl < best_kl:
                best_kl = kl
                best_idx = i
        return best_idx

    respect_size_limit = True
    while len(groups) > n_group:
        merge_idx = best_adjacent_pair(max_group_size if respect_size_limit else None)

        if merge_idx == -1:
            if respect_size_limit:
                # No admissible pair left: drop the size limit for the rest of the run.
                respect_size_limit = False
                continue
            # Every divergence is inf/NaN. Indexing with -1 would merge the last
            # group with the first one; merge the smallest adjacent pair instead.
            merge_idx = min(
                range(len(groups) - 1),
                key=lambda i: len(groups[i]) + len(groups[i + 1]),
            )

        groups[merge_idx] = groups[merge_idx] + groups[merge_idx + 1]
        del groups[merge_idx + 1]

    return groups


def group_timesteps(cali_xs, cali_ts, cali_ys, timesteps, model, args):
    """
    Collect per-timestep activation statistics and group timesteps per layer.

    Steps:
      1. Build ``args.calib_n // args.calib_batch_size`` calibration batches.
         Each batch takes, for every entry of ``timesteps``, one slice of
         ``args.calib_batch_size`` samples whose ``cali_ts`` equals that
         timestep, then moves samples with label 1000 behind all others.
      2. Register a ``SaveActivationHook`` on every ``QLinearLayer`` and run
         ``model.forward_with_cfg`` on each batch under ``torch.no_grad()``.
      3. Reduce the recorded statistics over batches (element-wise max) and
         call ``group_layer_activations`` for the ``attn.qkv``, ``attn.proj``
         and ``mlp.fc1`` layers of every block in ``model.blocks``.

    NOTE(review): the sample index layout assumes every per-timestep slice
    holds ``calib_batch_size // 2`` samples with label != 1000 followed in the
    reordered batch by the same number with label == 1000 -- confirm against
    how the calibration data is generated.

    Args:
        cali_xs, cali_ts, cali_ys: Calibration inputs, timesteps and labels,
            indexable along dimension 0 in lock-step.
        timesteps: Sized iterable of timestep values to calibrate on.
        model: Model exposing ``blocks`` and ``forward_with_cfg``.
        args: Namespace with ``calib_n``, ``calib_batch_size``, ``cfg_scale``
            and ``n_timesteps_group``.

    Returns:
        dict: layer name (e.g. ``'blocks.0.attn.qkv'``) -> list of groups as
        returned by ``group_layer_activations``.

    Raises:
        ValueError: if ``calib_n // calib_batch_size`` is zero.
        KeyError: if an expected layer recorded no usable activations.
    """
    batch_size = args.calib_batch_size
    n_iterations = args.calib_n // batch_size
    if n_iterations < 1:
        raise ValueError("calib_n must be at least calib_batch_size to build a calibration batch")

    batches = []
    for i in range(n_iterations):
        inds = []
        for t in timesteps:
            idx = torch.where(cali_ts == t)[0][i * batch_size:(i + 1) * batch_size]
            inds.extend(idx.tolist())

        xs_batch = cali_xs[inds]
        ts_batch = cali_ts[inds]
        ys_batch = cali_ys[inds]

        # Rearrange: samples with a regular label first, label 1000 last.
        normal_index = torch.where(ys_batch != 1000)[0]
        null_index = torch.where(ys_batch == 1000)[0]
        order = torch.cat([normal_index, null_index], 0)

        batches.append((xs_batch[order], ts_batch[order], ys_batch[order]))

    # Row of the first regular-label sample of every timestep inside a batch.
    # The count follows len(timesteps) instead of a hard-coded 100.
    n_timesteps = len(timesteps)
    first_sample_index = [k * batch_size // 2 for k in range(n_timesteps)]
    null_offset = batches[-1][0].shape[0] // 2

    normal_samples_index_list = []
    null_samples_index_list = []
    for s in range(batch_size // 2):
        normal_samples_index_list.append([j + s for j in first_sample_index])
        null_samples_index_list.append([j + s + null_offset for j in first_sample_index])
    samples_index_list = normal_samples_index_list + null_samples_index_list

    kwargs = {
        'hook_cls': SaveActivationHook,
        'samples_index_list': samples_index_list
    }

    hook_d = apply_func_to_submodules(model,
                            class_type=QLinearLayer,  # add hook to all objects of this cls
                            function=add_hook_to_module_,
                            return_d={},
                            **kwargs
                            )

    try:
        for count, (xs_batch, ts_batch, ys_batch) in enumerate(batches, start=1):
            print(f'Calibrating Batch {count} ...')
            with torch.no_grad():
                _ = model.forward_with_cfg(xs_batch, ts_batch, ys_batch, args.cfg_scale)
    finally:
        # Always detach the hooks, even when the forward pass raises.
        for hook in hook_d.values():
            hook.hook_handle.remove()

    save_d = {}
    for k, v in hook_d.items():
        # Skip layers that never fired or only saw non 3-D inputs
        # (the hook stores an empty list for those).
        if not v.outputs or isinstance(v.outputs[0], list):
            continue
        if 'fc2' in k:
            continue

        stacked = torch.stack(v.outputs, dim=0)  # [n_batches, T, C]
        save_d[k] = stacked.max(dim=0)[0]        # [T, C]

    layer_suffixes = (('attn.qkv', 'qkv'), ('attn.proj', 'proj'), ('mlp.fc1', 'fc1'))
    group_partitions_dict = {}
    for block_index in range(len(model.blocks)):
        for suffix, label in layer_suffixes:
            print(f"Grouping X of Block {block_index}'s {label} ...")
            layer_name = f'blocks.{block_index}.{suffix}'
            partition = group_layer_activations(save_d[layer_name], args.n_timesteps_group)
            group_partitions_dict[layer_name] = partition
            print('Groups: ', partition)

    return group_partitions_dict


def select_timesteps(group_partitions_dict, timesteps):
    """
    Randomly pick one timestep from every group of every grouped layer.

    Block indices are read from the keys of ``group_partitions_dict`` (keys of
    the form ``'blocks.<index>.<layer>'``) instead of assuming 28 blocks. For
    each block, the ``attn.qkv``, ``attn.proj`` and ``mlp.fc1`` layers are
    processed in that order, drawing one index per group with ``random.choice``.

    Args:
        group_partitions_dict (dict): layer name -> list of groups, where each
            group is a non-empty list of indices into ``timesteps``.
        timesteps: Indexable sequence of timestep values.

    Returns:
        dict: layer name -> list with one selected timestep per group.

    Raises:
        KeyError: if a block appearing in the keys lacks one of the three layers.
    """
    import re

    block_pattern = re.compile(r'blocks\.(\d+)\.')
    matches = (block_pattern.match(key) for key in group_partitions_dict)
    # Ascending numeric order, which is 0..27 for a 28-block dict.
    block_indices = sorted({int(m.group(1)) for m in matches if m})

    layer_suffixes = (('attn.qkv', 'qkv'), ('attn.proj', 'proj'), ('mlp.fc1', 'fc1'))

    layers_select_timestep = {}
    for block_index in block_indices:
        for suffix, label in layer_suffixes:
            print(f"Selecting Timesteps of Block {block_index}'s {label} ...")
            layer_name = f'blocks.{block_index}.{suffix}'
            timesteps_list = [
                timesteps[random.choice(group)]
                for group in group_partitions_dict[layer_name]
            ]
            layers_select_timestep[layer_name] = timesteps_list
            print('Timesteps: ', timesteps_list)

    return layers_select_timestep


def select_timesteps_layer(v, group_partitions_dict, timesteps):
    """
    Randomly pick one timestep from every group of a single layer.

    Args:
        v: Key of the layer in ``group_partitions_dict``.
        group_partitions_dict (dict): layer key -> list of groups, where each
            group is a list of indices into ``timesteps``.
        timesteps: Indexable sequence of timestep values.

    Returns:
        list: one selected timestep per group, in group order.
    """
    # One random draw per group, in group order.
    timesteps_list = [
        timesteps[random.choice(members)]
        for members in group_partitions_dict[v]
    ]
    print('Timesteps: ', timesteps_list)

    return timesteps_list


def get_layer_calib_data(cali_xs, cali_ts, cali_ys, timesteps_list, args, median_timestep=515):
    """
    Build CPU calibration batches restricted to the given timesteps.

    ``args.calib_n // args.calib_batch_size`` batches are produced. Batch ``i``
    takes, for every timestep in ``timesteps_list``, the ``i``-th slice of
    ``args.calib_batch_size`` samples whose ``cali_ts`` equals that timestep.
    Inside each batch, samples whose label is 1000 are moved behind all the
    others (presumably the null / unconditional class -- confirm with caller).

    Args:
        cali_xs, cali_ts, cali_ys: Calibration inputs, timesteps and labels,
            indexable along dimension 0 in lock-step.
        timesteps_list: Timestep values to draw samples for.
        args: Namespace with ``calib_n`` and ``calib_batch_size``.
        median_timestep: Reference value; the position of the entry of
            ``timesteps_list`` closest to it is returned.

    Returns:
        tuple: ``(xs_batches, ts_batches, ys_batches, median_timestep_index)``
        where the first three are lists with one tensor per batch.
    """
    batch_size = args.calib_batch_size
    xs_batches, ts_batches, ys_batches = [], [], []

    for iteration in range(args.calib_n // batch_size):
        start = iteration * batch_size
        stop = (iteration + 1) * batch_size

        picked = []
        for step in timesteps_list:
            matching = torch.where(cali_ts == step)[0]
            picked.extend(matching[start:stop].tolist())

        # Move the selected samples to the CPU before reordering them.
        xs = cali_xs[picked].cpu()
        ts = cali_ts[picked].cpu()
        ys = cali_ys[picked].cpu()

        regular = torch.where(ys != 1000)[0].cpu()
        null = torch.where(ys == 1000)[0].cpu()

        # Regular labels first, label 1000 last.
        xs_batches.append(torch.cat((xs[regular], xs[null]), dim=0))
        ts_batches.append(torch.cat((ts[regular], ts[null]), dim=0))
        ys_batches.append(torch.cat((ys[regular], ys[null]), dim=0))

    # Position of the selected timestep closest to the reference value.
    distance = (torch.tensor(timesteps_list) - median_timestep).abs()
    median_timestep_index = distance.argmin().item()

    return xs_batches, ts_batches, ys_batches, median_timestep_index


def get_layer_calib_data_dynamic(cali_xs, cali_ts, cali_ys, args, v, group_partitions_dict, timesteps, median_timestep=515):
    """
    Build CPU calibration batches, re-sampling the timesteps for every batch.

    Works like ``get_layer_calib_data`` except that each of the
    ``args.calib_n // args.calib_batch_size`` batches draws a fresh timestep
    selection for layer ``v`` through ``select_timesteps_layer``. Inside each
    batch, samples whose label is 1000 are moved behind all the others.

    Args:
        cali_xs, cali_ts, cali_ys: Calibration inputs, timesteps and labels,
            indexable along dimension 0 in lock-step.
        args: Namespace with ``calib_n`` and ``calib_batch_size``.
        v: Layer key in ``group_partitions_dict``.
        group_partitions_dict (dict): layer key -> list of timestep-index groups.
        timesteps: Indexable sequence of timestep values.
        median_timestep: Reference value; for each batch, the position of the
            selected timestep closest to it is recorded.

    Returns:
        tuple: ``(xs_batches, ts_batches, ys_batches, median_timestep_index_list)``,
        four lists with one entry per batch.
    """
    batch_size = args.calib_batch_size
    xs_batches, ts_batches, ys_batches = [], [], []
    median_timestep_index_list = []

    for iteration in range(args.calib_n // batch_size):
        start = iteration * batch_size
        stop = (iteration + 1) * batch_size

        # New random timestep selection for this batch.
        chosen_steps = select_timesteps_layer(v, group_partitions_dict, timesteps)

        picked = []
        for step in chosen_steps:
            matching = torch.where(cali_ts == step)[0]
            picked.extend(matching[start:stop].tolist())

        # Move the selected samples to the CPU before reordering them.
        xs = cali_xs[picked].cpu()
        ts = cali_ts[picked].cpu()
        ys = cali_ys[picked].cpu()

        regular = torch.where(ys != 1000)[0].cpu()
        null = torch.where(ys == 1000)[0].cpu()

        # Regular labels first, label 1000 last.
        xs_batches.append(torch.cat((xs[regular], xs[null]), dim=0))
        ts_batches.append(torch.cat((ts[regular], ts[null]), dim=0))
        ys_batches.append(torch.cat((ys[regular], ys[null]), dim=0))

        # Position of the selected timestep closest to the reference value.
        distance = (torch.tensor(chosen_steps) - median_timestep).abs()
        median_timestep_index_list.append(distance.argmin().item())

    return xs_batches, ts_batches, ys_batches, median_timestep_index_list


def save_rotation_matrix(model, save_dir):
    """
    Save the rotation matrices of every transformer block to disk.

    For each block in ``model.blocks`` the attributes
    ``attn.qkv_rotation_matrix``, ``attn.proj_rotation_matrix``,
    ``mlp.fc1_rotation_matrix`` and ``mlp.fc2_rotation_matrix`` are collected
    under the keys ``'blocks.<i>.attn.qkv'``, ``'blocks.<i>.attn.proj'``,
    ``'blocks.<i>.mlp.fc1'`` and ``'blocks.<i>.mlp.fc2'``. The dictionary is
    written with ``torch.save`` to ``<save_dir>/rotation_matrices.pt``.

    The number of blocks is taken from ``model.blocks`` rather than being
    fixed at 28, so models of other depths are handled as well.

    Args:
        model: Object exposing an iterable ``blocks`` attribute.
        save_dir (str): Existing directory to write the file into.
    """
    rotation_matrices = {}

    for block_index, block in enumerate(model.blocks):
        prefix = f'blocks.{block_index}'
        rotation_matrices[f'{prefix}.attn.qkv'] = block.attn.qkv_rotation_matrix
        rotation_matrices[f'{prefix}.attn.proj'] = block.attn.proj_rotation_matrix
        rotation_matrices[f'{prefix}.mlp.fc1'] = block.mlp.fc1_rotation_matrix
        rotation_matrices[f'{prefix}.mlp.fc2'] = block.mlp.fc2_rotation_matrix

    torch.save(rotation_matrices, os.path.join(save_dir, 'rotation_matrices.pt'))


def save_scale_matrix(model, save_dir):
    """
    Save the scale masks of every transformer block to disk.

    For each block in ``model.blocks`` the attributes ``attn.qkv_scale_mask``,
    ``attn.proj_scale_mask``, ``mlp.fc1_scale_mask`` and ``mlp.fc2_scale_mask``
    are collected under the keys ``'blocks.<i>.attn.qkv'``,
    ``'blocks.<i>.attn.proj'``, ``'blocks.<i>.mlp.fc1'`` and
    ``'blocks.<i>.mlp.fc2'``. The dictionary is written with ``torch.save`` to
    ``<save_dir>/scaling_matrices.pt``.

    The number of blocks is taken from ``model.blocks`` rather than being
    fixed at 28, so models of other depths are handled as well.

    Args:
        model: Object exposing an iterable ``blocks`` attribute.
        save_dir (str): Existing directory to write the file into.
    """
    scale_matrices = {}

    for block_index, block in enumerate(model.blocks):
        prefix = f'blocks.{block_index}'
        scale_matrices[f'{prefix}.attn.qkv'] = block.attn.qkv_scale_mask
        scale_matrices[f'{prefix}.attn.proj'] = block.attn.proj_scale_mask
        scale_matrices[f'{prefix}.mlp.fc1'] = block.mlp.fc1_scale_mask
        scale_matrices[f'{prefix}.mlp.fc2'] = block.mlp.fc2_scale_mask

    torch.save(scale_matrices, os.path.join(save_dir, 'scaling_matrices.pt'))