import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from torch.utils.data import DataLoader, RandomSampler
from datetime import datetime
import torch.nn as nn
import torch
from pathlib import Path
import argparse
import numpy as np
from torch.optim.lr_scheduler import CosineAnnealingLR
from kurtail.dataloader import R1Dataset
from kurtail.optimizer import AdamG

from rotation.rotation_utils import get_orthogonal_matrix
from rotation.model_utils import get_model

# Root directory for KurTail artefacts; the training-data and trained-rotation
# paths built in parser_gen() are joined onto it.
default_path = './kurtail'


def kurtosis(H):
    """Return the kurtosis-style loss of the rows of ``H`` as a scalar tensor.

    A 3-D input is first flattened to ``(-1, H.shape[-1])``. Every row is then
    standardised with its own mean and (torch-default) standard deviation, and
    the loss is the mean over *all* elements of ``|z ** 4 - 3|``.

    NOTE(review): the absolute value is taken per element before averaging,
    so this is not the textbook excess kurtosis ``|mean(z ** 4) - 3|`` --
    confirm that this is the intended objective.
    """
    rows = H.reshape(-1, H.shape[-1]) if H.dim() == 3 else H

    # Per-row statistics, kept as (N, 1) so they broadcast against the rows.
    row_mean = rows.mean(dim=1, keepdim=True)
    row_std = rows.std(dim=1, keepdim=True)

    # The epsilon guards against division by zero for constant rows.
    standardized = (rows - row_mean) / (row_std + 1e-6)
    deviation = torch.abs(standardized.pow(4) - 3)
    return deviation.mean()


class R1_QR(nn.Module):
    """Learnable block-diagonal rotation built from independent 32x32 blocks.

    Each block is an ``nn.Parameter`` initialised from
    ``get_orthogonal_matrix(32, "hadamard")``. ``forward`` assembles the blocks
    into one ``(hidden_size, hidden_size)`` block-diagonal matrix and
    right-multiplies the input by it.

    Attributes:
        hidden_size: Width of the activations the rotation is applied to.
        num_groups: Number of 32x32 blocks (``hidden_size // 32``).
        block_r1: ``nn.ParameterList`` holding the trainable blocks.
        r1: Detached copy of the matrix used by the most recent ``forward``
            call (``None`` until ``forward`` has run). It is a snapshot and
            does not track later in-place parameter updates.
    """

    def __init__(self, hidden_size: int):
        """Create ``hidden_size // 32`` trainable blocks.

        Raises:
            ValueError: If ``hidden_size`` is not a multiple of 32; the
                block-diagonal matrix could not match the input width.
        """
        super().__init__()
        if hidden_size % 32 != 0:
            raise ValueError(
                f"hidden_size must be a multiple of 32, got {hidden_size}")
        self.hidden_size = hidden_size
        self.num_groups = hidden_size // 32
        # A ParameterList (rather than a plain Python list) registers every
        # block with the module, so parameters(), .to(device) and state_dict()
        # all see them.
        self.block_r1 = nn.ParameterList([
            nn.Parameter(get_orthogonal_matrix(32, "hadamard").float())
            for _ in range(self.num_groups)
        ])
        # Only used to expose the last assembled matrix (detached); it never
        # takes part in the autograd graph.
        self.r1 = None

    def forward(self, x):
        """Return ``x @ R`` where ``R`` is the block-diagonal matrix of the blocks."""
        r1 = torch.block_diag(*self.block_r1).to(x.device)  # hidden x hidden
        # Keep a detached snapshot so callers can read the matrix afterwards.
        self.r1 = r1.detach()
        return x @ r1


def train_R1(dataset, args):
    """Train the block-diagonal rotation R1 by minimising the kurtosis loss.

    Args:
        dataset: Dataset fed to a ``DataLoader``; every batch is moved to the
            training device, cast to float and reshaped to
            ``(-1, args.hidden_size)``.
        args: Namespace providing ``hidden_size``, ``lr``, ``mom``, ``cos_lr``,
            ``ep`` and ``bsz``.

    Returns:
        The detached ``(hidden_size, hidden_size)`` block-diagonal matrix
        assembled from the blocks *after* the final optimizer step.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    R1 = R1_QR(hidden_size=args.hidden_size).to(device)
    # Work on an explicit list of the block parameters: R1.parameters() would
    # be empty if the blocks are not registered with the module.
    blocks = list(R1.block_r1)

    optimizer = AdamG(blocks, lr=args.lr, momentum=args.mom, stiefel=True)
    # Optional cosine-annealing learning-rate schedule (stepped once per epoch).
    scheduler = CosineAnnealingLR(optimizer, T_max=args.ep, eta_min=0) if args.cos_lr else None

    dataloader = DataLoader(dataset,
                            batch_size=args.bsz,
                            num_workers=8,
                            prefetch_factor=3,
                            persistent_workers=True,
                            pin_memory=True)

    R1.train()
    print("---> start training R1 ")
    for epoch in range(args.ep):
        loss_log = []
        for batch_idx, batch_samples in enumerate(dataloader):
            batch_samples = batch_samples.to(device).float().reshape(-1, args.hidden_size)
            outputs = R1(batch_samples)
            # Split the hidden dimension into groups of 32 so the loss is
            # evaluated per block.
            outputs = outputs.reshape(outputs.shape[0], -1, 32)

            loss = kurtosis(outputs)

            # On a NaN/Inf loss, print diagnostics and skip the step so the
            # parameters are not corrupted.
            if not torch.isfinite(loss):
                print(f"[WARN] got invalid loss at epoch {epoch} batch {batch_idx}: {loss}")
                with torch.no_grad():
                    s = outputs
                    print(
                        f"outputs: min={s.min().item():.6f} max={s.max().item():.6f} mean={s.mean().item():.6f} std={s.std().item():.6f}")
                    # block_r1 is a collection of tensors, so flatten it into
                    # a single tensor before computing statistics.
                    br = torch.cat([p.detach().reshape(-1) for p in blocks])
                    print(
                        f"block_r1: min={br.min().item():.6f} max={br.max().item():.6f} mean={br.mean().item():.6f} std={br.std().item():.6f}")
                optimizer.zero_grad()
                continue

            loss_log.append(loss.detach().cpu())

            # Back-propagate, then clip the gradients of the block parameters
            # to guard against numerical blow-up.
            loss.backward()
            torch.nn.utils.clip_grad_norm_(blocks, max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()

        # Average loss over the batches that were actually used; NaN when
        # every batch of the epoch was skipped (or the loader was empty).
        mean_loss = torch.stack(loss_log).mean().item() if loss_log else float('nan')

        log_message = f'Epoch [{epoch+1}/{args.ep}], Loss: {mean_loss:.4f}, LR: {optimizer.param_groups[0]["lr"]:.4f}'
        if scheduler is not None:
            scheduler.step()
            log_message += f', LR: {scheduler.get_last_lr()[0]:.4f}'
        print(log_message)

    print("---> R1 training done ")

    # Rebuild the matrix from the current parameters: the copy cached by
    # forward() predates the last optimizer step and is None if no batch ran.
    with torch.no_grad():
        trained_r1 = torch.block_diag(*[p.detach() for p in blocks]).to(device)
    return trained_r1


def parser_gen(argv=None):
    """Parse command-line arguments and derive the run's data/save paths.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before.

    Returns:
        A tuple ``(args, now)`` where ``args`` is the parsed namespace extended
        with ``run_time``, ``save_folder``, ``data_path`` and ``save_path``, and
        ``now`` is the ``datetime`` captured while parsing.
    """
    parser = argparse.ArgumentParser()

    # General Arguments
    parser.add_argument('--model', type=str, default='meta-llama/Llama-2-7b-hf',
                        help='model name.')
    parser.add_argument('--calib_dataset', type=str, default='wikitext2',
                        choices=['wikitext2', 'ptb', 'c4'],
                        help='Dataset for Evaluation (default: wikitext2)',)
    parser.add_argument('--calib_sample', type=int, default=500,
                        help='Number of sample.')
    parser.add_argument('--hf_token', type=str, default=None)
    parser.add_argument('--ep', type=int, default=100,
                        help='Number of epochs for training (default: 100)')
    parser.add_argument('--bsz', type=int, default=256,
                        help='Batch-size for training (default: 256)')
    # argparse applies ``type`` only to strings, so give the default as a float.
    parser.add_argument('--lr', type=float, default=2.0,
                        help="Learning rate for training (default: 2)")
    parser.add_argument('--mom', type=float, default=0.9,
                        help='Momentum for training (default: 0.9)')
    parser.add_argument('--seed', type=int, default=0,
                        help='Random Seed for HuggingFace and PyTorch')
    parser.add_argument('--cos_lr', action=argparse.BooleanOptionalAction, default=False,
                        help='Whether to use cosine learning rate scheduler (default: False)')
    parser.add_argument('--save_model', action=argparse.BooleanOptionalAction, default=True,
                        help='Whether save learned r1 (default: True).')

    args = parser.parse_args(argv)
    # 13b models always train with a batch size of 128, overriding --bsz.
    if '13b' in args.model:
        args.bsz = 128

    now = datetime.now()
    args.run_time = now.strftime("%Y-%m-%d_%H:%M:%S")

    # Last path component of the model id, e.g. "Llama-2-7b-hf".
    model_name = args.model.split("/")[-1]
    args.save_folder = f'{model_name}-r1'

    data_path = os.path.join(
        default_path,
        f"train_data/{args.calib_dataset}_{args.calib_sample}samples/{model_name}")
    args.data_path = data_path
    print(f'---> data path: {data_path}')

    args.save_path = os.path.join(
        default_path,
        f"trained_rotation/{args.calib_dataset}_{args.calib_sample}samples/")

    return args, now


if __name__ == "__main__":
    # Script entry point: parse arguments, train R1, log the elapsed time and
    # optionally save the trained rotation.
    args, run_time = parser_gen()

    # Apply --seed (it was parsed but never used) before anything random is
    # created, so that runs are repeatable.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # The model is loaded only to read its hidden size, then released.
    model = get_model(args.model, args.hf_token)
    args.hidden_size = model.config.hidden_size
    del model

    # Load the training dataset.
    dataset = R1Dataset(args.data_path, nsamples=args.calib_sample)
    print(f'---> start time: {run_time}')
    r1 = train_R1(dataset, args)
    finish_time = datetime.now()
    # NOTE: run_time was captured during argument parsing, so this duration
    # also includes model and dataset loading.
    train_time = finish_time - run_time
    print(f"---> Training time: {train_time}")

    log_dir = os.path.join(default_path, 'logs')
    os.makedirs(log_dir, exist_ok=True)
    with open(os.path.join(log_dir, 'train_time_log.txt'), 'a', encoding='utf-8') as f:
        f.write(
            f"{args.model.split('/')[-1]} train r1 time: {train_time} {args.calib_sample}sample \n")

    if args.save_model:  # Save the trained rotation matrix.
        save_path = Path(args.save_path)
        save_path.mkdir(parents=True, exist_ok=True)
        save_name = save_path / f"{args.save_folder}_SBRQ.pt"
        torch.save({'R1': r1}, save_name)

        print(f"---> R1 has been saved in: {save_path}")
