
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from torch.utils.data import DataLoader, RandomSampler
from datetime import datetime
import torch.nn as nn
import torch
from pathlib import Path
import argparse
import numpy as np
from torch.optim.lr_scheduler import CosineAnnealingLR
from kurtail.dataloader import R1Dataset
from kurtail.optimizer import AdamG

from rotation.rotation_utils import get_orthogonal_matrix
from rotation.model_utils import get_model

# Root directory (relative to the current working directory) used by
# parser_gen to build both the training-data path (train_data/...) and the
# output path for the learned rotation (trained_rotation/...).
default_path = './kurtail'


def kurtosis(H):
    """Return the mean absolute deviation of per-vector kurtosis from 1.8.

    Each vector along the last dimension of ``H`` is standardized with its
    own mean and standard deviation (torch's default, Bessel-corrected
    estimator). Its kurtosis is the mean of the fourth power of those
    standardized values. The loss is the average over all vectors of
    ``|kurtosis - 1.8|``. 1.8 is the kurtosis of a uniform distribution, so
    minimizing this pushes every vector toward a uniform-like shape.

    Args:
        H: tensor whose last dimension holds the vectors. Any leading
            dimensions (including none, for a single 1-D vector) are
            flattened into the batch.

    Returns:
        0-d tensor holding the averaged loss (differentiable w.r.t. ``H``).
    """
    # Flatten every leading dim so each row is one vector. This is a no-op
    # for 2-D input and also lets 1-D and >=3-D inputs work.
    H = H.reshape(-1, H.shape[-1])

    means = H.mean(dim=1, keepdim=True)  # (N, 1)
    stds = H.std(dim=1, keepdim=True)  # (N, 1)
    # The epsilon avoids division by zero for constant vectors.
    z = (H - means) / (stds + 1e-6)

    # Kurtosis is a per-vector statistic. Average z**4 over each vector
    # first, and only then measure the distance to the uniform target 1.8.
    # Taking |z**4 - 1.8| element-wise before averaging would yield a loss
    # that a uniform distribution does not minimize.
    per_vector_kurtosis = (z ** 4).mean(dim=1)  # (N,)
    return (per_vector_kurtosis - 1.8).abs().mean()


class R1_QR(nn.Module):
    """Module holding a single learnable square matrix ``r1``.

    ``r1`` has shape ``(hidden_size, hidden_size)`` and starts as the
    identity. ``forward`` multiplies its input by ``r1`` on the right.
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        # Identity start; callers may overwrite ``r1.data`` with another init.
        self.r1 = nn.Parameter(torch.eye(hidden_size))

    def forward(self, x):
        """Return ``x @ r1`` (batched matmul semantics of ``torch.matmul``)."""
        return x @ self.r1


def train_R1(dataset, args):
    """Train the R1 matrix by minimizing ``kurtosis`` of the rotated samples.

    R1 is initialized with ``get_orthogonal_matrix(hidden_size, init_mode)``.
    It is optimized with ``AdamG`` (``stiefel=True``) for ``args.ep`` epochs,
    optionally with a cosine-annealed learning rate that decays to 0.

    Args:
        dataset: dataset consumed by ``DataLoader``. Every batch is reshaped
            to ``(-1, args.hidden_size)`` before the forward pass.
        args: namespace providing ``hidden_size``, ``init_mode``, ``lr``,
            ``mom``, ``cos_lr``, ``ep`` and ``bsz``.

    Returns:
        The learned matrix ``R1.r1.data`` (on the training device).

    Raises:
        ValueError: if the data loader yields no batches.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    use_cuda = device == 'cuda'

    R1 = R1_QR(hidden_size=args.hidden_size).to(device)
    R1.r1.data = get_orthogonal_matrix(
        args.hidden_size, args.init_mode, device).float()

    optimizer = AdamG(R1.parameters(), lr=args.lr, momentum=args.mom, stiefel=True)
    # Cosine annealing over the full run, ending at lr = 0.
    scheduler = (CosineAnnealingLR(optimizer, T_max=args.ep, eta_min=0)
                 if args.cos_lr else None)

    dataloader = DataLoader(dataset,
                            batch_size=args.bsz,
                            num_workers=8,
                            prefetch_factor=3,
                            persistent_workers=True,
                            # Pinned memory only speeds up host->GPU copies.
                            pin_memory=use_cuda)

    R1.train()
    print("---> start training R1 ")
    for epoch in range(args.ep):
        epoch_losses = []
        for batch_samples in dataloader:
            batch_samples = batch_samples.to(device, non_blocking=use_cuda)
            batch_samples = batch_samples.float().reshape(-1, args.hidden_size)
            outputs = R1(batch_samples)

            loss = kurtosis(outputs)
            # Keep the loss on-device; moving it to the CPU here would force
            # a device sync on every step.
            epoch_losses.append(loss.detach())

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        if not epoch_losses:
            raise ValueError(
                "train_R1: the data loader yielded no batches; "
                "check the dataset path and contents")

        mean_loss = torch.stack(epoch_losses).mean().item()
        # LR actually used during this epoch (read before the scheduler steps).
        current_lr = optimizer.param_groups[0]["lr"]
        log_message = f'Epoch [{epoch+1}/{args.ep}], Loss: {mean_loss:.4f}, LR: {current_lr:.4f}'
        if scheduler is not None:
            scheduler.step()
            # Label the post-step value distinctly so the log is unambiguous.
            log_message += f', Next LR: {scheduler.get_last_lr()[0]:.4f}'
        print(log_message)

    print("---> R1 training done ")

    return R1.r1.data


def parser_gen():
    """Parse ``sys.argv`` into the training configuration and derive run paths.

    Besides the command-line options, the returned namespace gets:
      * ``bsz``: if ``--bsz`` is not given, 128 for models whose name
        contains ``13b`` (case-insensitive) and 256 otherwise;
      * ``run_time``: timestamp string formatted ``%Y-%m-%d_%H:%M:%S``;
      * ``save_folder``: ``<model basename>-r1``;
      * ``data_path``: training-data directory under ``default_path``;
      * ``save_path``: output directory for the learned R1 under ``default_path``.

    Returns:
        tuple ``(args, now)``, where ``now`` is the ``datetime`` that
        ``run_time`` was formatted from.
    """
    parser = argparse.ArgumentParser()

    # General Arguments
    parser.add_argument('--model', type=str, default='meta-llama/Llama-2-7b-hf',
                        help='model name.')
    parser.add_argument('--calib_dataset', type=str, default='wikitext2',
                        choices=['wikitext2', 'ptb', 'c4'],
                        help='Calibration dataset (default: wikitext2)')
    parser.add_argument('--calib_sample', type=int, default=500,
                        help='Number of calibration samples (default: 500).')
    parser.add_argument('--hf_token', type=str, default=None)
    parser.add_argument('--ep', type=int, default=100,
                        help='Number of epochs for training (default: 100)')
    # None means "choose per model" below, so an explicit --bsz is honoured.
    parser.add_argument('--bsz', type=int, default=None,
                        help='Batch-size for training (default: 256, or 128 for 13b models)')
    parser.add_argument('--lr', type=float, default=2,
                        help="Learning rate for training (default: 2, same as spinquant)")
    parser.add_argument('--mom', type=float, default=0.9,
                        help='Momentum for training (default: 0.9)')
    parser.add_argument('--seed', type=int, default=0,
                        help='Random Seed for HuggingFace and PyTorch')
    parser.add_argument('--cos_lr', action=argparse.BooleanOptionalAction, default=False,
                        help='Whether to use cosine learning rate scheduler (default: False)')
    parser.add_argument('--init_mode', type=str, default='hadamard', choices=['hadamard', 'random'],
                        help='Initialization of R1 (default: hadamard)')
    parser.add_argument('--save_model', action=argparse.BooleanOptionalAction, default=True,
                        help='Whether to save the learned R1 (default: True).')

    args = parser.parse_args()
    if args.bsz is None:
        # 13b models fall back to a smaller batch; others use 256.
        args.bsz = 128 if '13b' in args.model.lower() else 256

    now = datetime.now()
    args.run_time = now.strftime("%Y-%m-%d_%H:%M:%S")

    model_name = args.model.split("/")[-1]
    args.save_folder = f'{model_name}-r1'

    args.data_path = os.path.join(
        default_path,
        f"train_data/{args.calib_dataset}_{args.calib_sample}samples/{model_name}")
    print(f'---> data path: {args.data_path}')
    args.save_path = os.path.join(
        default_path,
        f"trained_rotation/{args.calib_dataset}_{args.calib_sample}samples/")

    return args, now


if __name__ == "__main__":
    # start_time is the datetime parser_gen captured; it is used to time the run.
    args, start_time = parser_gen()

    # --seed was parsed but never applied. Seed the global torch and numpy
    # RNGs so runs (e.g. --init_mode random) can be reproduced.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    # Only the hidden size is needed from the model, so release it right away.
    model = get_model(args.model, args.hf_token)
    args.hidden_size = model.config.hidden_size
    del model

    # Load the training samples (vectors of size hidden_size) from data_path.
    dataset = R1Dataset(args.data_path)
    print(f'---> start time: {start_time}')
    r1 = train_R1(dataset, args)
    train_time = datetime.now() - start_time
    print(f"---> Training time: {train_time}")

    # Append the timing to <default_path>/logs/train_time_log.txt.
    log_dir = os.path.join(default_path, 'logs')
    os.makedirs(log_dir, exist_ok=True)
    with open(os.path.join(log_dir, 'train_time_log.txt'), 'a', encoding='utf-8') as f:
        f.write(
            f'{args.model.split("/")[-1]} train r1 time: {train_time} {args.calib_sample}sample \n')

    if args.save_model:  # Save the learned rotation matrix.
        save_dir = Path(args.save_path)
        save_dir.mkdir(parents=True, exist_ok=True)
        save_name = save_dir / f"{args.save_folder}.pt"
        torch.save({'R1': r1}, save_name)

        print(f"---> R1 has been saved in: {save_name}")
