import os

import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR

from .optimizer import SGDG
import geoopt

from qdit.quant import quantize_tensor_channel_group

def softmax_max_abs(X, beta=20):
    """Differentiable, softmax-weighted approximation of the column-wise max of ``|X|``.

    Within each column, the absolute values are weighted by a softmax taken
    over dim 0 (the token dimension), sharpened by ``beta``. A larger ``beta``
    brings the result closer to the true per-column maximum absolute value.
    ``beta=0`` gives the plain per-column mean of ``|X|``.

    Args:
        X (torch.Tensor): input of shape [N, C].
        beta (float): softmax sharpness.

    Returns:
        torch.Tensor: shape [C].
    """
    magnitudes = torch.abs(X)
    attention = torch.softmax(magnitudes * beta, dim=0)  # normalised along the token dim
    return torch.sum(attention * magnitudes, dim=0)


def pairwise_kl(p_list):
    """Average KL divergence over all ordered pairs of distributions.

    Direction: ``F.kl_div(input, target)`` takes ``input`` as
    log-probabilities and computes ``KL(target || exp(input))``. The term for
    the ordered pair (i, j) below is therefore ``KL(p_j || p_i)``. Every
    ordered pair with i != j is visited, so the total equals
    ``sum_{i != j} KL(p_i || p_j)``.

    Scaling: ``reduction='batchmean'`` on a 1-D tensor divides the summed KL
    by its length C. Each term is thus a per-element mean, not a sum over C.

    Args:
        p_list: list of T probability tensors, each of shape [C].

    Returns:
        torch.Tensor: 0-dim tensor equal to
        ``sum_{i != j} KL(p_j || p_i) / C / (T * (T - 1))``.
        When T == 1 (no pairs exist), a zero tensor is returned.

    Raises:
        ValueError: if ``p_list`` is empty.
    """
    T = len(p_list)
    if T == 0:
        raise ValueError("pairwise_kl requires at least one distribution")
    if T == 1:
        # No pairs: divergence is zero (previously a ZeroDivisionError).
        return p_list[0].new_zeros(())

    loss = 0.0
    for i, p in enumerate(p_list):
        log_p = p.log()  # hoisted: reused for every j
        for j, q in enumerate(p_list):
            if i == j:
                continue
            loss += F.kl_div(log_p, q, reduction='batchmean')  # KL(q || p) / C
    return loss / (T * (T - 1))


def mean_kl(p_list):
    """Average KL divergence between the middle distribution and every other one.

    The reference distribution is ``p_list[T // 2]``. ``F.kl_div(input, target)``
    takes ``input`` as log-probabilities and computes
    ``KL(target || exp(input))``. Each term below is therefore
    ``KL(p_mid || p_i)``, with the reference as the first argument.

    ``reduction='batchmean'`` on a 1-D tensor divides the summed KL by its
    length C, so each term is a per-element mean rather than a sum over C.

    Args:
        p_list: list of T probability tensors, each of shape [C].

    Returns:
        torch.Tensor: 0-dim tensor equal to
        ``sum_{i != T//2} KL(p_mid || p_i) / C / (T - 1)``.
        When T == 1 (nothing to compare against), a zero tensor is returned.

    Raises:
        ValueError: if ``p_list`` is empty.
    """
    T = len(p_list)
    if T == 0:
        raise ValueError("mean_kl requires at least one distribution")
    medium_T = T // 2
    p_medium_T = p_list[medium_T]
    if T == 1:
        # Only the reference itself: divergence is zero (previously a ZeroDivisionError).
        return p_medium_T.new_zeros(())

    loss = 0.0
    for i, p in enumerate(p_list):
        if i == medium_T:
            continue
        loss += F.kl_div(p.log(), p_medium_T, reduction='batchmean')  # KL(p_mid || p) / C
    return loss / (T - 1)


def loss_function_rec(X, Y, R, W, b, act_quant, args):
    """Reconstruction loss of a quantized linear layer after rotating by ``R``.

    Steps:
    1. Rotate the activations (``X @ R``) and the weight (``W @ R``).
    2. Quantize the activations with ``act_quant``.
    3. Quantize the weight with ``quantize_tensor_channel_group``.
    4. Compare the quantized linear output against the reference ``Y`` with MSE.

    ``F.linear`` computes ``x @ w.T``. For an orthogonal ``R``, the unquantized
    product ``(X R)(W R)^T`` therefore equals ``X W^T``, so only quantization
    error contributes to the loss.

    Args:
        X (torch.Tensor): input activations, [B, N, C].
        Y (torch.Tensor): reference output to reconstruct.
        R (torch.nn.Parameter): rotation matrix, [C, C].
        W (torch.Tensor): linear weight whose last dim (C) is rotated by ``R``.
        b (torch.Tensor | None): bias passed to ``F.linear``.
        act_quant (callable): quantizer applied to the rotated activations.
        args (dict): weight-quant config; reads ``'n_bits'`` and ``'sym'``.

    Returns:
        torch.Tensor: scalar MSE between the quantized output and ``Y``.
    """
    rotated_x = X @ R
    rotated_w = W @ R

    quant_x = act_quant(rotated_x)
    # group_size=0 / tiling=0 presumably mean per-channel with no grouping or
    # tiling -- TODO confirm against qdit.quant.
    quant_w = quantize_tensor_channel_group(
        rotated_w,
        n_bits=args['n_bits'],
        sym=args['sym'],
        group_size=0,
        tiling=0,
    )

    prediction = F.linear(quant_x, quant_w, b)
    return F.mse_loss(prediction, Y)

def loss_function_kl(X, R, beta, topk, samples_index_list):
    """Temporal-consistency loss on the rotated per-channel magnitude profile.

    Grouping and scoring:
    - ``X`` is regrouped with ``samples_index_list``. Each entry selects T rows
      of ``X``, and the selections are stacked into a [B, T, N, C] tensor.
    - For each group b, every slice ``X[b, t]`` is rotated by ``R`` and
      summarised per channel with ``softmax_max_abs``.
    - The ``topk`` channels are chosen from the middle slice (t = T // 2, the
      same reference that ``mean_kl`` uses).
    - Restricted to those channels, each slice's profile becomes a distribution
      via softmax, and ``mean_kl`` compares each one with the middle slice's.

    Args:
        X (torch.Tensor): cached activations. ``X[index]`` for every entry of
            ``samples_index_list`` must have shape [T, N, C].
        R (torch.nn.Parameter): rotation matrix, [C, C].
        beta (float): sharpness passed to ``softmax_max_abs``.
        topk (int): number of channels kept; must not exceed C (``torch.topk``).
        samples_index_list: list of B index objects, one per group. Presumably
            each entry picks one sample's rows across all T timesteps, with
            conditional and null/unconditional samples as separate groups.
            TODO: confirm with the caller.

    Returns:
        torch.Tensor: mean over the B groups of ``mean_kl(p_list) * topk``.
        ``mean_kl`` uses ``reduction='batchmean'`` on [topk] tensors, which
        divides by topk. Multiplying by topk undoes that, so each term is a KL
        summed over the selected channels.
    """
    X = torch.stack([X[index] for index in samples_index_list])  # [B, T, N, C]
    B, T = X.shape[0], X.shape[1]
    medium_T = T // 2

    total_loss = 0.0
    for b in range(B):
        # Profile of every rotated slice, computed once. The middle one serves
        # both for top-k selection and as the reference inside mean_kl. The
        # original computed it twice.
        scores = [softmax_max_abs(X[b, t] @ R, beta=beta) for t in range(T)]  # T x [C]
        _, topk_indices = torch.topk(scores[medium_T], topk)
        p_list = [F.softmax(s[topk_indices], dim=-1) for s in scores]  # T x [topk]
        total_loss += mean_kl(p_list) * topk

    return total_loss / B


def optimize_rotation_matrix(R, W, b, act_quant, n_batches, samples_index_list, args, quant_config_weight, prefix, ins_list, outs_list):
    """Optimize the rotation ``R`` in place with ``SGDG`` (``stiefel=True``).

    Training setup:
    - Runs ``args.epochs`` passes over the first ``n_batches`` entries of
      ``ins_list`` / ``outs_list``.
    - Each step minimises
      ``loss_function_rec + args.alpha * loss_function_kl``.
    - The learning-rate multiplier decays linearly from 1 to 0 over all steps.
    - Per-epoch average losses are printed, prefixed by ``prefix``.
    - ``R.requires_grad`` is enabled for the duration and always reset to
      False afterwards, even if an exception is raised.

    Args:
        R (torch.nn.Parameter): rotation matrix, updated in place.
        W: linear weight (see ``loss_function_rec``).
        b: linear bias (see ``loss_function_rec``).
        act_quant (callable): activation quantizer (see ``loss_function_rec``).
        n_batches (int): number of cached batches per epoch; must be > 0.
        samples_index_list: grouping indices for ``loss_function_kl``.
        args: reads ``learning_rate``, ``epochs``, ``beta``, ``topk``, ``alpha``.
        quant_config_weight (dict): passed to ``loss_function_rec`` as its
            weight-quant config.
        prefix (str): prepended to every log line.
        ins_list: cached inputs, indexed by batch.
        outs_list: cached reference outputs, indexed by batch.

    Raises:
        ValueError: if ``n_batches <= 0``; the per-epoch averages would
            otherwise divide by zero.
    """
    if n_batches <= 0:
        raise ValueError(f"n_batches must be positive, got {n_batches}")

    R.requires_grad_(True)
    try:
        # NOTE(review): the device is hard-coded to CUDA; confirm R lives there.
        optimizer = SGDG([R], lr=args.learning_rate, stiefel=True, device=torch.device('cuda'))

        total_steps = args.epochs * n_batches

        def lr_lambda(step):
            # Linear decay of the LR multiplier from 1 to 0 over all steps.
            # LambdaLR evaluates this once at construction, so max(..., 1)
            # keeps it defined when epochs == 0 (previously ZeroDivisionError).
            return 1 - step / max(total_steps, 1)

        scheduler = LambdaLR(optimizer, lr_lambda)

        for epoch in range(args.epochs):
            rec_loss = 0.0
            kl_loss = 0.0
            total_loss = 0.0
            for batch_index in range(n_batches):
                cached_inps = ins_list[batch_index]
                cached_outs = outs_list[batch_index]

                loss_rec = loss_function_rec(cached_inps, cached_outs, R, W, b, act_quant, quant_config_weight)
                loss_kl = loss_function_kl(cached_inps, R, args.beta, args.topk, samples_index_list)
                loss_total = loss_rec + args.alpha * loss_kl

                optimizer.zero_grad()
                loss_total.backward()
                optimizer.step()
                scheduler.step()

                # .item() already copies the scalar to the host; no .cpu() needed.
                rec_loss += loss_rec.item()
                kl_loss += loss_kl.item()
                total_loss += loss_total.item()

            rec_loss /= n_batches
            kl_loss /= n_batches
            total_loss /= n_batches
            print(prefix + f' Epoch [{epoch+1}/{args.epochs}], Rec Loss {rec_loss: .6f} KL Loss {kl_loss: .6f} Total Loss {total_loss: .6f}')
    finally:
        # Freeze R again even if optimisation was interrupted by an exception.
        R.requires_grad_(False)
        