import argparse
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
import os
import json
import time
from modelscope import AutoModelForCausalLM, AutoTokenizer
from modelscope.msdatasets import MsDataset
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm
import math
import shutil

from src_1.utils import load_partition_result
from src_1.clustering import clustering_process, get_projection_activation_graphs
from src_1.data_processor import DataProcessor
from src_1.mask_select import MaskSelecter


def partition_list_to_tenosr(all_partition_list, layers, neuron_num):
    """
    Convert in-memory per-layer partition results to one-hot tensors.

    (The function name keeps its historical spelling so existing callers work.)

    :param all_partition_list: indexable by layer; ``all_partition_list[layer]`` is a
        sequence of clusters, each cluster being a sequence of neuron indices
    :param layers: the number of layers to convert (layers ``0 .. layers - 1``)
    :param neuron_num: the number of neurons in each layer
    :return: a list of ``layers`` float tensors; tensor ``i`` has shape
        ``(neuron_num, len(all_partition_list[i]))`` and row ``n`` is the one-hot
        encoding of the cluster that neuron ``n`` belongs to
    :raises ValueError: if a neuron of some layer is not assigned to any cluster
    """
    partition_tensors = []
    for layer in range(layers):
        partition_list = all_partition_list[layer]
        # Fill with -1 (not torch.empty) so neurons missing from every cluster are
        # detected instead of silently receiving an uninitialised cluster id.
        partition_tensor = torch.full((neuron_num,), -1, dtype=torch.long)
        for cluster_id, neuron_list in enumerate(partition_list):
            partition_tensor[neuron_list] = cluster_id

        unassigned = (partition_tensor < 0).nonzero().flatten()
        if unassigned.numel() > 0:
            raise ValueError(
                f"layer {layer}: {unassigned.numel()} of {neuron_num} neurons are not "
                f"assigned to any cluster (first few: {unassigned[:5].tolist()})"
            )

        # apply one-hot coding: (neuron_num,) ids -> (neuron_num, num_clusters)
        partition_tensor = F.one_hot(partition_tensor, num_classes=len(partition_list)).float()
        partition_tensors.append(partition_tensor)
    return partition_tensors


def load_all_partition_to_tensor(partition_folder_path, layers,neuron_num):
    """
    Load the partition result of each layer from disk and convert it to one-hot tensors.

    For layer ``i`` the file ``<partition_folder_path>/layer_<i>.parti`` is read with
    ``load_partition_result(..., index_from=0)``. The result is expected to be a
    sequence of clusters, each a sequence of neuron indices.

    :param partition_folder_path: the folder path of partition result
    :param layers: the number of layers
    :param neuron_num: the number of neurons in each layer
    :return: a list of ``layers`` float tensors; tensor ``i`` has shape
        ``(neuron_num, number_of_clusters_in_layer_i)``, one-hot per row
    :raises ValueError: if a layer's partition leaves some neuron unassigned
    """
    partition_tensors = []
    for layer in range(layers):
        partition_file_path = os.path.join(partition_folder_path, f'layer_{layer}.parti')
        partition_list = load_partition_result(partition_file_path, index_from=0)

        # Fill with -1 (not torch.empty) so neurons missing from the partition file
        # are reported instead of getting an uninitialised cluster id.
        partition_tensor = torch.full((neuron_num,), -1, dtype=torch.long)
        for cluster_id, neuron_list in enumerate(partition_list):
            partition_tensor[neuron_list] = cluster_id

        unassigned = (partition_tensor < 0).nonzero().flatten()
        if unassigned.numel() > 0:
            raise ValueError(
                f"{partition_file_path}: {unassigned.numel()} of {neuron_num} neurons are "
                f"not assigned to any cluster (first few: {unassigned[:5].tolist()})"
            )

        # apply one-hot coding
        partition_tensor = F.one_hot(partition_tensor, num_classes=len(partition_list)).float()
        partition_tensors.append(partition_tensor)
    return partition_tensors

def evaluate(
    model, 
    test_dataloader, 
    device, 
):
    """
    Compute the average language-modelling loss of ``model`` over ``test_dataloader``.

    Each batch must provide ``input_ids`` and ``attention_mask``. Labels are the input
    ids with padding positions (``attention_mask == 0``) replaced by -100, the ignore
    index of Hugging Face causal-LM losses, so padding does not count toward the loss.
    The model's train/eval mode is restored before returning.

    :param model: causal LM whose forward accepts ``input_ids``, ``attention_mask`` and
        ``labels`` and returns an object with a ``loss`` attribute
    :param test_dataloader: non-empty iterable of batches with ``len()``
    :param device: device the batch tensors are moved to
    :return: the mean of the per-batch losses (a Python float)
    :raises ValueError: if ``test_dataloader`` yields no batches
    """
    num_batches = len(test_dataloader)
    if num_batches == 0:
        raise ValueError("test_dataloader is empty; cannot compute a validation loss")

    was_training = model.training
    model.eval()  # Set model to evaluation mode
    total_eval_loss = 0.0
    try:
        with torch.no_grad():  # Disable gradient calculation for efficiency
            for batch in tqdm(test_dataloader, desc="Evaluating", leave=False):
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = input_ids.masked_fill(attention_mask == 0, -100)

                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )
                total_eval_loss += outputs.loss.item()
    finally:
        # Put the model back in whatever mode the caller had it in (e.g. training).
        model.train(was_training)

    avg_eval_loss = total_eval_loss / num_batches
    try:
        eval_perplexity = math.exp(avg_eval_loss)
    except OverflowError:
        # A very large loss must not crash the evaluation; report infinite perplexity.
        eval_perplexity = float("inf")
    print(f"Average Validation Loss: {avg_eval_loss:.4f} | Validation Perplexity: {eval_perplexity:.2f}")
    return avg_eval_loss

# Training loop for the 'normal' mode and the per-step gradient-mask modes.
def train(
    model, 
    train_dataloader,
    test_dataloader, 
    epochs, 
    lr, 
    p,
    device, 
    save_path, 
    mode='normal',
    neighbor_p=0.3,
    clustering_result=None, 
    mlp_keyword=None,
    tokenizer = None):
    """
    Fine-tune ``model`` with AdamW, optionally masking MLP gradients, and keep the best checkpoint.

    When ``mode != 'normal'``, after every backward pass the gradient of each trainable
    parameter whose name contains ``mlp_keyword`` is multiplied in place by a mask from
    ``MaskSelecter.select_mask``. The parameter's layer index is the first purely
    numeric dotted component of its name (e.g. ``model.layers.12.mlp.up_proj.weight``
    -> 12). ``clustering_result[layer]`` (if given) is passed to the selector.

    After each epoch the model is evaluated with ``evaluate``. If the returned loss
    improves, the model (and tokenizer, if given) is saved to
    ``<save_path>/epoch_<e>_test_loss_<loss>`` and the previous best directory is
    deleted. Otherwise training stops immediately.

    :param model: causal LM with ``config.num_hidden_layers`` / ``config.intermediate_size``
    :param train_dataloader: non-empty iterable of batches with ``input_ids`` / ``attention_mask``
    :param test_dataloader: batches used by ``evaluate``
    :param epochs: maximum number of epochs
    :param lr: AdamW learning rate
    :param p: mask parameter forwarded to ``select_mask``
    :param device: device the batch tensors are moved to
    :param save_path: parent directory for checkpoint directories
    :param mode: tuning mode; ``'normal'`` disables gradient masking
    :param neighbor_p: forwarded to ``select_mask``
    :param clustering_result: per-layer data indexed by layer, or None
    :param mlp_keyword: substring identifying MLP parameter names (required unless mode is 'normal')
    :param tokenizer: saved alongside each checkpoint when not None
    :raises ValueError: on an empty ``train_dataloader`` or a missing ``mlp_keyword``
    """
    num_train_batches = len(train_dataloader)
    if num_train_batches == 0:
        raise ValueError("train_dataloader is empty; nothing to train on")
    if mode != 'normal' and mlp_keyword is None:
        raise ValueError("mlp_keyword is required when mode != 'normal'")

    num_layers = model.config.num_hidden_layers
    neuron_dim = model.config.intermediate_size

    def _layer_index(param_name):
        # First purely numeric dotted component, or None. This is an exact match, unlike
        # a substring test where "1" also matches "layers.10" or "fc1".
        for part in param_name.split("."):
            if part.isdigit():
                return int(part)
        return None

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    best_loss = float("inf")
    mask_selecter = MaskSelecter(mode, neuron_dim=neuron_dim, device=device)

    previous_save_path = None
    for epoch in range(epochs):
        # evaluate() puts the model in eval mode; switch back at the start of every epoch.
        model.train()
        total_loss = 0
        progress_bar_train = tqdm(train_dataloader, desc="Training", leave=False)
        for batch in progress_bar_train:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            # Padding positions get label -100 so the loss ignores them.
            labels = input_ids.masked_fill(attention_mask == 0, -100)

            optimizer.zero_grad()
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            loss = outputs.loss
            total_loss += loss.item()
            loss.backward()

            if mode != 'normal':
                for name, param in model.named_parameters():
                    if not param.requires_grad or param.grad is None or mlp_keyword not in name:
                        continue
                    layer = _layer_index(name)
                    if layer is None or layer >= num_layers:
                        continue
                    layer_info = clustering_result[layer] if clustering_result is not None else None
                    mask = mask_selecter.select_mask(
                        param.grad,
                        p,
                        neuron_activation_graph=layer_info,
                        cluster_tensor=layer_info,
                        neighbor_p=neighbor_p,
                        neuron_dim=neuron_dim
                    )
                    param.grad.mul_(mask)

            optimizer.step()
            mask_rate = mask_selecter.get_mask_pass_ratio()
            progress_bar_train.set_postfix({'loss': loss.item(), 'mask_rate': mask_rate})

            torch.cuda.empty_cache()
        print(f"Epoch {epoch+1}, Training Loss: {total_loss/num_train_batches:.4f}")

        # Compute the loss on the test data
        test_loss = evaluate(model, test_dataloader, device)
        print(f"Epoch {epoch+1}, Test Loss: {test_loss:.4f}")

        # Save only when the test loss improves; otherwise stop (early stopping, patience 0).
        if test_loss < best_loss:
            model_save_path = os.path.join(save_path, f"epoch_{epoch+1}_test_loss_{test_loss:.4f}")
            os.makedirs(model_save_path, exist_ok=True)
            best_loss = test_loss
            model.save_pretrained(model_save_path)
            if tokenizer is not None:
                tokenizer.save_pretrained(model_save_path)
            print(f"模型已保存到临时路径 {model_save_path}，当前最佳测试 Loss: {best_loss:.4f}")
            # Keep only the newest best checkpoint on disk.
            if previous_save_path is not None:
                shutil.rmtree(previous_save_path)
            previous_save_path = model_save_path
        else:
            print("没有改善，不存模型")
            break

    print(f"最终模型已保存到 {save_path}")


def train_gmt(
    model, 
    train_dataloader,
    test_dataloader, 
    epochs, 
    lr, 
    p, # keep rate
    device, 
    save_path, 
    mode='gmt',
    mlp_keyword=None,
    N=3,
    tokenizer = None):
    """
    Fine-tune ``model`` with gradient accumulation and masking, keeping the best checkpoint.

    Each optimizer step accumulates gradients over ``N`` micro-batches, with each loss
    scaled by ``1/N``. It then multiplies, in place, the gradient of every trainable
    parameter whose name contains ``mlp_keyword`` by a mask from
    ``MaskSelecter.select_mask``. The layer index is the first purely numeric dotted
    component of the name.

    An epoch is ``len(train_dataloader) // N`` steps. The data iterator persists across
    epochs and restarts when exhausted. If the number of batches is not a multiple of
    ``N``, leftover batches roll over into the next epoch.

    After each epoch the model is evaluated. An improved loss saves the model (and
    tokenizer) to ``<save_path>/epoch_<e>_test_loss_<loss>`` and deletes the previous
    best directory. Otherwise training stops.

    :param p: keep rate forwarded to ``select_mask``
    :param mode: mode name passed to ``MaskSelecter``
    :param mlp_keyword: substring identifying MLP parameter names (required)
    :param N: number of micro-batches accumulated per optimizer step (>= 1)
    :param tokenizer: saved alongside each checkpoint when not None
    :raises ValueError: if ``N < 1``, ``mlp_keyword`` is None, or the loader has fewer than ``N`` batches
    """
    if N < 1:
        raise ValueError(f"N must be >= 1, got {N}")
    if mlp_keyword is None:
        raise ValueError("mlp_keyword is required to select the parameters to mask")
    # Every N mini-batches form one optimizer step.
    total_steps = len(train_dataloader) // N
    if total_steps == 0:
        raise ValueError(
            f"train_dataloader has {len(train_dataloader)} batches, fewer than N={N}; no step would run"
        )

    num_layers = model.config.num_hidden_layers
    neuron_dim = model.config.intermediate_size

    def _layer_index(param_name):
        # First purely numeric dotted component, or None. This is an exact match, unlike
        # a substring test where "1" also matches "layers.10" or "fc1".
        for part in param_name.split("."):
            if part.isdigit():
                return int(part)
        return None

    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    best_loss = float("inf")

    # Use an explicit iterator instead of the loader so batches can be pulled manually.
    train_iterator = iter(train_dataloader)

    mask_selecter = MaskSelecter(mode, neuron_dim=neuron_dim, device=device)
    previous_save_path = None
    for epoch in range(epochs):
        model.train()
        total_epoch_loss = 0
        progress_bar_train = tqdm(range(total_steps), desc=f"Epoch {epoch + 1} Training", leave=False)
        for _step in progress_bar_train:
            optimizer.zero_grad()
            total_accumulation_loss = 0
            # Gradient-accumulation loop
            for _ in range(N):
                try:
                    batch = next(train_iterator)
                except StopIteration:
                    # Iterator exhausted: start a new pass over the data.
                    train_iterator = iter(train_dataloader)
                    batch = next(train_iterator)

                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                # Padding positions get label -100 so the loss ignores them.
                labels = input_ids.masked_fill(attention_mask == 0, -100)

                outputs = model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels
                )
                loss = outputs.loss / N
                loss.backward()
                total_accumulation_loss += loss.item()

            for name, param in model.named_parameters():
                if not param.requires_grad or param.grad is None or mlp_keyword not in name:
                    continue
                layer = _layer_index(name)
                if layer is None or layer >= num_layers:
                    continue
                mask = mask_selecter.select_mask(
                    param.grad,
                    p,
                    neuron_activation_graph=None,
                    cluster_tensor=None,
                    neuron_dim=neuron_dim
                )
                param.grad.mul_(mask)

            optimizer.step()

            total_epoch_loss += total_accumulation_loss * N  # multiply by N to undo the 1/N scaling
            mask_rate = mask_selecter.get_mask_pass_ratio()
            progress_bar_train.set_postfix({'avg_loss': total_accumulation_loss, 'mask_rate': f'{mask_rate:.2%}'})
            torch.cuda.empty_cache()

        # Average over the micro-batches actually processed this epoch.
        print(f"Epoch {epoch+1}, Training Loss: {total_epoch_loss/(total_steps * N):.4f}")

        # Compute the loss on the test data
        test_loss = evaluate(model, test_dataloader, device)
        print(f"Epoch {epoch+1}, Test Loss: {test_loss:.4f}")

        # Save only when the test loss improves; otherwise stop (early stopping, patience 0).
        if test_loss < best_loss:
            this_save_path = f"epoch_{epoch+1}_test_loss_{test_loss:.4f}"
            model_save_path = os.path.join(save_path, this_save_path)
            os.makedirs(model_save_path, exist_ok=True)
            best_loss = test_loss
            model.save_pretrained(model_save_path)
            if tokenizer is not None:
                tokenizer.save_pretrained(model_save_path)
            print(f"模型已保存到临时路径 {model_save_path}，当前最佳测试 Loss: {best_loss:.4f}")
            # Keep only the newest best checkpoint on disk.
            if previous_save_path is not None:
                shutil.rmtree(previous_save_path)
            previous_save_path = model_save_path
        else:
            print("没有改善，不存模型")
            break

    print(f"最终模型已保存到 {save_path}")

# Command-line entry point and its helpers.
def _build_arg_parser():
    """Build the command-line interface of the fine-tuning script."""
    parser = argparse.ArgumentParser(description="Fine-tune a causal language model.")
    parser.add_argument("--data_path", type=str, required=True, help="Path to the training data.")
    parser.add_argument("--subset_name", type=str, required=False, help="Name of the subset to use.")
    parser.add_argument("--train_size", type=int, required=True, help="Number of training examples.")
    parser.add_argument("--test_size", type=int, required=True, help="Number of test examples.")
    parser.add_argument("--model_name", type=str, required=True, help="Name of the pre-trained model.")
    parser.add_argument("--mlp_keyword", type=str, required=True, help="Keyword to identify MLP parameters' name.")
    parser.add_argument("--batch_size", type=int, default=8, help="Batch size for training.")
    parser.add_argument("--max_length", type=int, default=512, help="Maximum length of input sequences.")
    parser.add_argument("--epochs", type=int, default=3, help="Number of training epochs.")
    parser.add_argument("--learning_rate", type=float, default=2e-4, help="Learning rate for training.")
    parser.add_argument("--p", type=float, default=0.1, help="Gradient mask rate, DECK parameter.")
    parser.add_argument("--save_path", type=str, required=True, help="Path to save the fine-tuned model.")
    parser.add_argument("--gpu", type=str, default="0", help="GPU device ID to use.")
    parser.add_argument("--mode", type=str, required=True, help="Mode of tuning, 'deck'/'normal'/'random'/'highest'/'deck_no_cluster'/'gmt'")
    parser.add_argument("--neighbor_p", type=float, default=0.3, help="Used in DECK, neighbor gradient contribution rate.")
    parser.add_argument("--cluster_mode", type=str, default="file", help="file/data, load/compute cluster")
    parser.add_argument("--cluster_folder", type=str, required=True, help="Path to the store/load clustering result folder.")
    parser.add_argument("--graph_path", type=str, required=False, help="Path to projected coactivation graphs.")
    parser.add_argument("--end_batch", type=int, default=3, help="Number of batches to used in coactivation measurement.")
    parser.add_argument("--min_rank", type=int, default=4, help="Activations of top min_rank tokens are not considers.")
    parser.add_argument("--k_nn", type=int, default=10, help="Number of nearest neighbors in constructing K-NN graph.")
    parser.add_argument("--k_cluster", type=int, default=16, help="Desired number of clusters.")
    parser.add_argument("--sigma", type=int, default=3, help="For pruning outliers in spectral clustering, ie 3-sigma rule.")
    parser.add_argument("--cluster_from_layer", type=int, default=0, help="Cluster from which layer.")
    parser.add_argument("--gmt_n", type=int, default=3, help="Number of mini-batches accumulated per optimizer step in 'gmt' mode.")
    return parser


def _build_dataloader(dataset, indices, data_processor, tokenizer, max_length, batch_size):
    """Tokenize the rows of ``dataset`` selected by ``indices`` and wrap them in an unshuffled DataLoader."""
    subset = dataset.select(indices)
    subset = subset.map(
        lambda x: data_processor.format_and_tokenize(x, tokenizer, max_length),
        batched=True,
    )
    subset.set_format(type='torch')
    return DataLoader(subset, batch_size=batch_size, shuffle=False)


def _load_clustering_result(args, model, tokenizer, train_dataloader):
    """Return the per-layer clustering/graph data required by ``args.mode``, or None if unused."""
    num_layers = model.config.num_hidden_layers
    neuron_num = model.config.intermediate_size

    if args.mode == 'deck' and args.cluster_mode == 'file':
        return load_all_partition_to_tensor(
            os.path.join(args.cluster_folder, 'partition'),
            num_layers,
            neuron_num
        )
    if args.mode == 'deck' and args.cluster_mode == 'data':
        clustering_list = clustering_process(
            model,
            train_dataloader,
            tokenizer,
            cluster_store_path=os.path.join(args.cluster_folder, 'partition'),
            eigen_store_path=os.path.join(args.cluster_folder, 'eigen'),
            graph_store_path=os.path.join(args.cluster_folder, 'graph'),
            end_batch_ind=args.end_batch,
            min_rank=args.min_rank,
            k_nn=args.k_nn,
            k_cluster=args.k_cluster,
            sigma=args.sigma,
            cluster_from_layer=args.cluster_from_layer,
        )
        return partition_list_to_tenosr(clustering_list, num_layers, neuron_num)
    if args.mode == 'deck_no_cluster':
        if args.graph_path is not None:
            return np.load(args.graph_path)
        return get_projection_activation_graphs(
            model,
            train_dataloader,
            tokenizer,
        )
    return None


def main():
    """Parse CLI arguments, prepare model/data/clustering, then run the selected training loop."""
    parser = _build_arg_parser()
    args = parser.parse_args()
    # Fail fast, before the expensive model load, instead of running DECK with no clustering.
    if args.mode == 'deck' and args.cluster_mode not in ('file', 'data'):
        parser.error(
            f"--cluster_mode must be 'file' or 'data' when --mode is 'deck', got {args.cluster_mode!r}"
        )

    # Select the GPU. Kept before any CUDA call below so the setting is applied in time.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    os.environ["TOKENIZERS_PARALLELISM"] = 'true'
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = AutoModelForCausalLM.from_pretrained(args.model_name, trust_remote_code=True, device_map="auto", torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # create save dir
    os.makedirs(args.save_path, exist_ok=True)

    # Load the dataset's 'train' split; convert to a HF dataset if it lacks .select().
    dataset = MsDataset.load(args.data_path, subset_name=args.subset_name, split='train', trust_remote_code=True)
    if not hasattr(dataset, 'select'):
        dataset = dataset.to_hf_dataset()

    # The first train_size rows are used for training, the next test_size rows for testing.
    data_processor = DataProcessor(args.data_path)
    train_dataloader = _build_dataloader(
        dataset, range(args.train_size),
        data_processor, tokenizer, args.max_length, args.batch_size,
    )
    test_dataloader = _build_dataloader(
        dataset, range(args.train_size, args.train_size + args.test_size),
        data_processor, tokenizer, args.max_length, args.batch_size,
    )

    clustering_result = _load_clustering_result(args, model, tokenizer, train_dataloader)

    # Start training
    if args.mode != 'gmt':
        train(
            model,
            train_dataloader,
            test_dataloader,
            args.epochs,
            args.learning_rate,
            args.p,
            device,
            args.save_path,
            mode=args.mode,
            neighbor_p=args.neighbor_p,
            clustering_result=clustering_result,
            mlp_keyword=args.mlp_keyword,
            tokenizer=tokenizer)
    else:
        train_gmt(
            model,
            train_dataloader,
            test_dataloader,
            args.epochs,
            args.learning_rate,
            args.p,
            device,
            args.save_path,
            mode=args.mode,
            mlp_keyword=args.mlp_keyword,
            N=args.gmt_n,
            tokenizer=tokenizer)

# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
