import os

import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR

from .optimizer import SGDG
import geoopt

from qdit.quant import quantize_tensor_channel_group

def softmax_max_abs(X, beta=20):
    """Smooth, differentiable approximation of the column-wise max of |X|.

    Every column of |X| is averaged with weights ``softmax(beta * |X|)``
    taken along dim 0 (the token dim), so the larger ``beta`` is, the closer
    the result gets to the true per-channel maximum absolute value.

    Args:
        X (torch.Tensor): activations, expected shape [N, C].
        beta (float): softmax sharpness.

    Returns:
        torch.Tensor: soft per-channel max magnitudes, shape [C].
    """
    magnitude = torch.abs(X)
    token_weights = F.softmax(magnitude * beta, dim=0)
    weighted = token_weights * magnitude
    return weighted.sum(dim=0)


def pairwise_kl(p_list):
    """Average KL divergence over all ordered pairs of distributions.

    Args:
        p_list (list[torch.Tensor]): T 1-D probability tensors of equal
            length C. They must be strictly positive because ``log`` is taken.

    Returns:
        ``sum_{i != j} KL(p_i || p_j) / (T * (T - 1))``.

        Each KL term is divided by C. ``F.kl_div(..., reduction='batchmean')``
        on a 1-D input divides the summed divergence by ``input.size(0)``, so
        multiply by C to recover the summed KL.

        With fewer than two distributions there are no pairs, and 0 is
        returned instead of dividing by zero. The 0 is a zero tensor for one
        distribution and ``0.0`` for an empty list.
    """
    T = len(p_list)
    if T < 2:
        # No pairs to compare; the original formula would compute 0 / 0.
        return p_list[0].new_zeros(()) if p_list else 0.0

    loss = 0.0
    for i, p in enumerate(p_list):
        log_p = p.log()  # reused for every partner q
        for j, q in enumerate(p_list):
            if i == j:
                continue
            # F.kl_div(log p, q) evaluates KL(q || p). Summing over all
            # ordered pairs makes the direction irrelevant for the total.
            loss += F.kl_div(log_p, q, reduction='batchmean')
    return loss / (T * (T - 1))


def mean_kl(p_list, median_timestep_index):
    """Average KL divergence from a reference distribution to all others.

    Args:
        p_list (list[torch.Tensor]): T 1-D probability tensors of equal
            length C. They must be strictly positive because ``log`` is taken.
        median_timestep_index (int): index in ``p_list`` of the reference
            distribution ``p_ref``.

    Returns:
        torch.Tensor: ``sum_{i != ref} KL(p_ref || p_i) / (T - 1)``.

        Note the direction. ``F.kl_div(p_i.log(), p_ref)`` evaluates
        ``sum p_ref * (log p_ref - log p_i)``, which is KL(p_ref || p_i).

        Each term is also divided by C, because ``reduction='batchmean'`` on
        a 1-D input divides by ``input.size(0)``. Multiply by C to get the
        summed KL.

        With a single distribution there is nothing to compare, and a zero
        tensor is returned instead of dividing by zero.

    Raises:
        IndexError: if ``median_timestep_index`` is out of range for ``p_list``.
    """
    T = len(p_list)
    p_ref = p_list[median_timestep_index]
    if T < 2:
        # Only the reference itself is present; the original computed 0 / 0.
        return p_ref.new_zeros(())

    loss = 0.0
    for i, p in enumerate(p_list):
        if i == median_timestep_index:
            continue
        loss += F.kl_div(p.log(), p_ref, reduction='batchmean')
    return loss / (T - 1)


def loss_function_rec(X, Y, R, W, b, act_quant, args):
    """Reconstruction MSE of the rotated, quantised linear layer.

    The activations and the weight are both right-multiplied by ``R``. The
    rotated activations go through ``act_quant`` and the rotated weight goes
    through ``quantize_tensor_channel_group``, configured from ``args``.
    The resulting linear output is compared with ``Y``.

    Args:
        X (torch.Tensor): layer input, [B, N, C].
        Y (torch.Tensor): reference layer output to reconstruct.
        R (torch.nn.Parameter): rotation, [C, C].
        W (torch.Tensor): layer weight, passed to ``F.linear`` after rotation
            and quantisation.
        b: layer bias, passed to ``F.linear`` (presumably may be None;
            confirm with callers).
        act_quant (callable): activation quantiser.
        args: namespace providing ``wbits``, ``exponential``, ``w_sym``,
            ``weight_channel_group``, ``w_clip_ratio``, ``tiling``,
            ``quant_type`` and ``quant_method``.

    Returns:
        torch.Tensor: scalar mean-squared error.
    """
    rotated_x = X @ R
    rotated_w = W @ R

    quant_x = act_quant(rotated_x)
    weight_quant_kwargs = dict(
        n_bits=args.wbits,
        exponential=args.exponential,
        sym=args.w_sym,
        group_size=0,
        channel_group=args.weight_channel_group,
        clip_ratio=args.w_clip_ratio,
        tiling=args.tiling,
        quant_type=args.quant_type,
        quant_method=args.quant_method,
    )
    quant_w = quantize_tensor_channel_group(rotated_w, **weight_quant_kwargs)

    prediction = F.linear(quant_x, quant_w, b)
    return F.mse_loss(prediction, Y)

def loss_function_kl(X, R, beta, topk, calib_batch_size, n_timesteps_group, median_timestep_index):
    """KL loss keeping each sample's rotated top-k channel profile stable over timesteps.

    The rows of ``X`` are first gathered into a [B, T, N, C] tensor: one row
    per sample, one column per timestep group. Then, for every sample ``b``:

    1. Each timestep's activations are rotated (``X[b, t] @ R``).
    2. The rotated activations are reduced to soft per-channel max
       magnitudes with ``softmax_max_abs``.
    3. The ``topk`` strongest channels of the reference timestep
       (``median_timestep_index``) are selected, and each timestep's
       magnitudes on those channels are turned into a distribution with
       softmax.
    4. ``mean_kl`` compares every timestep's distribution with the reference.

    ``mean_kl`` averages over the channels (``reduction='batchmean'``), so
    its result is multiplied by ``topk``. The result is then averaged over
    the samples.

    Args:
        X (torch.Tensor): cached layer inputs, with trailing dims [N, C] and
            samples along dim 0. Assumed layout (TODO confirm against the
            caching code):
              * the first half of dim 0 holds the "normal" samples and the
                second half the "null" samples;
              * each half is ``n_timesteps_group`` consecutive groups of
                ``calib_batch_size // 2`` samples.
        R (torch.nn.Parameter): rotation, [C, C].
        beta (float): softmax sharpness used by ``softmax_max_abs``.
        topk (int): number of channels compared across timesteps.
        calib_batch_size (int): samples per group across both halves; half
            of it is taken from each half of ``X``.
        n_timesteps_group (int): number of timestep groups (T).
        median_timestep_index (int): reference timestep within a group.

    Returns:
        torch.Tensor: scalar loss.
    """
    half_batch = calib_batch_size // 2
    # Start index of every timestep group inside one half of X. Parenthesised
    # so the stride equals the per-group sample count used below
    # (range(half_batch)). The old ``i * calib_batch_size // 2`` parsed as
    # ``(i * calib_batch_size) // 2``, which only agreed for even sizes.
    group_starts = [g * half_batch for g in range(n_timesteps_group)]
    null_offset = X.shape[0] // 2
    normal_index = [[start + i for start in group_starts] for i in range(half_batch)]
    null_index = [[start + i + null_offset for start in group_starts] for i in range(half_batch)]
    X = torch.stack([X[index] for index in normal_index + null_index])  # [B, T, N, C]

    B, T = X.shape[:2]
    total_loss = 0.0

    for b in range(B):
        # Rotate each timestep once. The reference entry serves both for
        # picking the channels and as mean_kl's reference distribution.
        rotated = [X[b, t] @ R for t in range(T)]  # each [N, C]
        ref_magnitude = softmax_max_abs(rotated[median_timestep_index], beta=beta)  # [C]
        _, topk_indices = torch.topk(ref_magnitude, topk)

        p_list = [
            F.softmax(softmax_max_abs(x_rot, beta=beta)[topk_indices], dim=-1)  # [topk]
            for x_rot in rotated
        ]
        # mean_kl returns a per-channel average; scale it back to a sum.
        total_loss += mean_kl(p_list, median_timestep_index) * topk

    return total_loss / B


def optimize_rotation_matrix(R, W, b, act_quant, cached_path, n_batches, args, prefix, median_timestep_index, ins_list, outs_list):
    """Optimise the rotation ``R`` in place on cached inputs/outputs of one layer.

    Each step minimises ``loss_function_rec + args.alpha * loss_function_kl``
    using ``SGDG(..., stiefel=True)``. The learning-rate multiplier decays
    linearly from 1 to 0 over ``args.epochs * n_batches`` steps. After every
    epoch the mean losses are printed, prefixed with ``prefix``.

    Args:
        R (torch.Tensor): rotation matrix [C, C], updated in place.
        W, b: layer weight and bias, forwarded to ``loss_function_rec``.
        act_quant (callable): activation quantiser, forwarded to
            ``loss_function_rec``.
        cached_path: unused; batches are taken from ``ins_list`` and
            ``outs_list``.
        n_batches (int): batches per epoch. ``ins_list`` and ``outs_list``
            must have at least this many entries.
        args: config namespace. Reads ``learning_rate``, ``cuda``,
            ``epochs``, ``alpha``, ``beta``, ``topk``, ``calib_batch_size``
            and ``n_timesteps_group``, plus the quantisation fields read by
            ``loss_function_rec``.
        prefix (str): prepended to each epoch's log line.
        median_timestep_index (int): reference timestep for
            ``loss_function_kl``.
        ins_list, outs_list: per-batch cached layer inputs and reference
            outputs.

    If there are no steps to run (``args.epochs * n_batches <= 0``), R is
    left unchanged. In every case ``R.requires_grad`` is False on exit,
    including when an exception escapes the training loop.
    """
    total_steps = args.epochs * n_batches
    if total_steps <= 0:
        # Nothing to optimise. The linear-decay lambda below would divide by
        # zero (LambdaLR evaluates it at construction time).
        R.requires_grad_(False)
        return

    R.requires_grad_(True)
    try:
        optimizer = SGDG([R], lr=args.learning_rate, stiefel=True, device=torch.device(f'cuda:{args.cuda}'))
        # Linear decay of the LR multiplier: 1 at step 0, down to 0 after the last step.
        scheduler = LambdaLR(optimizer, lambda step: 1 - step / total_steps)

        for epoch in range(args.epochs):
            rec_loss = 0.0
            kl_loss = 0.0
            total_loss = 0.0
            for batch_index in range(n_batches):
                cached_inps = ins_list[batch_index]
                cached_outs = outs_list[batch_index]
                loss_rec = loss_function_rec(cached_inps, cached_outs, R, W, b, act_quant, args)
                loss_kl = loss_function_kl(cached_inps, R, args.beta, args.topk, args.calib_batch_size, args.n_timesteps_group, median_timestep_index)
                loss_total = loss_rec + args.alpha * loss_kl

                optimizer.zero_grad()
                loss_total.backward()
                optimizer.step()
                scheduler.step()

                # .item() already copies the scalar to the host; no .cpu() needed.
                rec_loss += loss_rec.item()
                kl_loss += loss_kl.item()
                total_loss += loss_total.item()

            rec_loss /= n_batches
            kl_loss /= n_batches
            total_loss /= n_batches
            print(prefix + f'Epoch [{epoch+1}/{args.epochs}], Rec Loss {rec_loss: .6f} KL Loss {kl_loss: .6f} Total Loss {total_loss: .6f}')
    finally:
        # Always hand R back frozen, even if a step raised.
        R.requires_grad_(False)
        