import numpy as np
import torch

from model import *
import argparse
import os
import shutil, random
from train import train_wikitext
from datasets import load_from_disk
from torch.utils.data import DataLoader

def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True

def load_tokenized_dataset(dataset_path = None, train_dataset_path = None, test_dataset_path = None, batch_size = 4, 
                           need_return_vocab_size = False, world_size = 1,
                           train_length = None, test_length = None, only_dataset = None, test_samples = None):
    """
    Load tokenized dataset from the specified path.
    
    Args:
        dataset_path (str): Path to the dataset directory.
        train_dataset_path (str, optional): Path to the training dataset. Defaults to None.
        test_dataset_path (str, optional): Path to the testing dataset. Defaults to None.
    
    Returns:
        datasets.Dataset: Loaded tokenized dataset.
    """
    if train_dataset_path is not None:
        train_dataset = load_from_disk(train_dataset_path)
    else:
        train_dataset = None

    if test_dataset_path is not None:
        test_dataset = load_from_disk(test_dataset_path)
    else:
        test_dataset = None
        
    if train_dataset is None and test_dataset is None:
        dataset = load_from_disk(dataset_path)
        train_dataset = dataset['train'] if 'train' in dataset else None
        test_dataset = dataset['test'] if 'test' in dataset else None
    if test_samples is not None and test_dataset is not None:
        # 如果指定了test_samples，则从测试集中随机抽取test_samples个样本
        if test_samples <= len(test_dataset):
            test_dataset = test_dataset.shuffle(seed=42).select(range(test_samples))
        else:
            print(f"Warning: test_samples ({test_samples}) is greater than the size of the test dataset ({len(test_dataset)}). Using the entire test dataset instead.")
            test_samples = len(test_dataset) // world_size * world_size
            test_dataset = test_dataset.shuffle(seed=42).select(range(test_samples))
    train_dataset.set_format(
        type="torch", 
        columns=["input_ids",]
    )
    test_dataset.set_format(
        type="torch", 
        columns=["input_ids",]
    )
    train_dataset_input_ids = train_dataset['input_ids']
    test_dataset_input_ids = test_dataset['input_ids'] 
    lens = [len(test_dataset_input_ids[i]) for i in range(len(test_dataset_input_ids))]
    print(min(lens), max(lens), sum(lens)/len(lens))
    if train_length is not None:
        # 如果指定了train_length，则将输入序列 reshape；例如原序列长度均为 2048，新序列长度为 128，则 reshape 后每个样本的长度为 128
        train_dataset_input_ids = train_dataset_input_ids.reshape(-1, train_length)
        test_length = train_length if test_length is None else test_length
        batch_nums = test_dataset_input_ids.numel() // test_length
        test_dataset_input_ids = test_dataset_input_ids.flatten()[:batch_nums * test_length].reshape(-1, test_length)
        # test_dataset_input_ids = test_dataset_input_ids.reshape(-1, test_length)
    
    if only_dataset is not None:
        return_tuple = ()
        if only_dataset == 'train':
            test_dataloader = DataLoader(
                test_dataset_input_ids, 
                batch_size=batch_size, 
                shuffle=False, 
            ) 
            return_tuple = (test_dataloader, train_dataset_input_ids)
        elif only_dataset == 'test':
            dataloader = DataLoader(
                train_dataset_input_ids, 
                batch_size=batch_size, 
                shuffle=True, 
            )
            return_tuple = (dataloader, test_dataset_input_ids)
        elif only_dataset == 'both':
            return_tuple = (train_dataset_input_ids, test_dataset_input_ids)
        if need_return_vocab_size:
            vocab_size = max(train_dataset_input_ids.max() + 1, test_dataset_input_ids.max() + 1)
            return_tuple += (vocab_size,)
        return return_tuple
    else:        
        
        dataloader = DataLoader(
            train_dataset_input_ids, 
            batch_size=batch_size, 
            shuffle=True, 
        )
        test_dataloader = DataLoader(
            test_dataset_input_ids, 
            batch_size=batch_size, 
            shuffle=False, 
        )
        if need_return_vocab_size:
            vocab_size = max(train_dataset_input_ids.max() + 1, test_dataset_input_ids.max() + 1)
            return dataloader, test_dataloader, vocab_size
        
        return dataloader, test_dataloader


def main_wikitext(args, **kwargs):
    # 设置随机种子
    setup_seed(args.seed)
    train_dataloader, test_loader, vocab_size = load_tokenized_dataset(
        dataset_path='./data/ripe_wikitext-103-v1',
        need_return_vocab_size=True,
        # train_length=128 
    )

    args.vocab_size = vocab_size.item()
    args.seq_len = len(train_dataloader.dataset[0])
    args.max_pos = args.seq_len
    train_wikitext(args, datas=None, train_dataloader=train_dataloader, test_dataloader=test_loader)


def dmode_prepocess(dmode):
    """
    将传入的 dmode 进行预处理。输入的 dmode 存在几种可能:
    1. 一个字符串化的数字构成的列表的字符串，如 '['1', '2', '3', '4']'
    2. 一个字符串化的数字构成的列表的嵌套的字符串，如 '['12', '21', ['34', '43']]'
    均转化为单纯的列表
    """
    return eval(dmode)


if __name__ == '__main__':

    parser = argparse.ArgumentParser(description="Pytorch distributed")

    # 数据集参数; test_data 允许多个输入
    parser.add_argument('-mode', '--mode', type = str, default = 'default', choices=['default', 'add_task', 'star_graph','knowledge_graph','wikitext'], help='模式') 
    parser.add_argument('-train_data', '--train_data', type = str, default = None, help='训练数据集路径')
    parser.add_argument('-test_data', '--test_data', type = str, default = None, help='测试数据集路径')
    parser.add_argument('-valid_data', '--valid_data', type = str, default = None, help='验证数据集路径')
    parser.add_argument('-LTP_last_pos', '--LTP_last_pos', type = int, default = 1, help='LTP 的 last token position')
    parser.add_argument('-symmetric_data', '--symmetric_data', type = str, default = None, help='对称数据集路径')
    parser.add_argument('-data_size', '--data_size', type = int, default = 1000) 
    parser.add_argument('-sl', '--seq_len', type = int, default = 10, help='句子长度')
    parser.add_argument('-dmin', '--data_min', type = int, default = 20, help='数据集中数据的最小值')
    parser.add_argument('-dmax', '--data_max', type = int, default = 100, help='数据集中数据的最大值')
    parser.add_argument('-bs', '--batch_size', type = int, default = 10) 
    parser.add_argument('-seed', '--seed', type = int, default = 1)  

    parser.add_argument('-dmode', '--data_mode', nargs='*', type=str, default = [1], help='各类数据集的模式，不同任务中的数据集模式不同')
    parser.add_argument('-dp', '--data_percent', nargs='*', type=float, default = [1], help='各类数据集占比')
    parser.add_argument('-dn', '--data_name', nargs='*', type=str, default = ['full data'], help='各类数据集名称')
    parser.add_argument('-dtrain', '--data_train', nargs='*', type=int, default = [0], help='该类是否参与训练')
    parser.add_argument('-dshow', '--data_show', nargs='*', type=int, default = [0], help='画图时是否显示该类数据集，1表示显示，0表示不显示')
    parser.add_argument('-dataset_on_cuda', '--dataset_on_cuda', type = int, default = 1, help='dataset 是否在 cuda 上')
    
    # Debugging options
    parser.add_argument('-debug_data_only', '--debug_data_only', type = int, default = 0, 
                      help='只生成数据，不训练模型，用于调试数据生成过程')

    # 目标函数
    parser.add_argument('-func', '--target', type = str, default = '', help='任务')

    # 网络结构与超参数
    parser.add_argument('-m', '--model', type = str, default = 'GPT', help='模型') 
    parser.add_argument('-vs', '--vocab_size', type = int, default = 201) 
    parser.add_argument('-dm', '--d_model', type = int, default = 400)
    parser.add_argument('-d_ff', '--d_feedforward', type = int, default = 1200)
    parser.add_argument('-dk', '--d_k', type = int, default = 64)
    parser.add_argument('-dv', '--d_v', type = int, default = 64)
    parser.add_argument('-nl', '--n_layers', type = int, default = 2)
    parser.add_argument('-nh', '--n_heads', type = int, default = 1)
    parser.add_argument('-cl', '--clip', type = float, default = 1, help='梯度裁剪')
    # parser.add_argument('-sr', '--std_rate', type = float, default = 1, help='标准差的幂次') 
    
    parser.add_argument('-dk_list', '--dk_list', nargs='*', type=int, default = None, help='d_k 的 LIST 版本')
    
    # 训练超参数
    parser.add_argument('-ne', '--n_epoch', type = int, default = 3000) 
    parser.add_argument('-lr', '--lr', type = float, default = 1.e-4, help='初始学习率') 
    parser.add_argument('-op', '--optim', type=str, default = 'AdamW', help='优化器')  
    parser.add_argument('-scheduler', '--scheduler', type = str, default = 'StepLR', help='调度器')
    parser.add_argument('-eps', '--eps', type = float, default = 1.e-8, help='adam epsilon') 
    parser.add_argument('-beta1', '--beta1', type = float, default = 0.9, help='adam beta1')
    parser.add_argument('-beta2', '--beta2', type = float, default = 0.999, help='adam beta2') 
    parser.add_argument('-type', '--dtype', type = str, default = 'float32', help='数据类型')
    
    parser.add_argument('-lds', '--lr_decay_step', type = int, default = 20, help='使用StepLR调度器时，每隔多少epoch学习率衰减') 
    parser.add_argument('-ldr', '--lr_decay_rate', type = float, default = 0.9, help='使用StepLR调度器时，学习率变为原来的多少倍') 
    
    parser.add_argument('-optim_total_epoch', '--optim_total_epoch', type = int, default = 400, help='使用GradualWarmupScheduler时的预热的周期数')
    parser.add_argument('-optim_multiplier', '--optim_multiplier', type = float, default = 5, help='使用GradualWarmupScheduler时的最大学习率与初始学习率的比值')
    parser.add_argument('-optim_T_max', '--optim_T_max', type = int, default = 4000, help='使用CosineAnnealingLR时的周期长度，即从当前学习率下降到最小学习率所需的epoch，若继续训练则会按照cosine继续上升到最大学习率，然后再下降')
    parser.add_argument('-optim_eta_min', '--optim_eta_min', type = float, default = 1e-5, help='使用CosineAnnealingLR下降到的最小学习率')
    parser.add_argument('-use_train_with_spike_reshuffle', '--use_train_with_spike_reshuffle', type = str, default = 'NO', help='是否使用 train_with_spike_reshuffle')
    parser.add_argument('-prediction_pattern', '--prediction_pattern',  type=str, default='LTP', choices = ['NTP', 'LTP','CTP','CoT'], help='选择 last token prediction 还是 next token prediction')
    parser.add_argument('-label_smoothing', '--label_smoothing', type = float, default = 0., help='标签平滑')
    parser.add_argument('-weight_decay', '--weight_decay', type = float, default = 0., help='权重衰减')
    # 保存、输出信息和画图的间隔
    parser.add_argument('-sme', '--save_model_epoch', type = int, default = 100, help='每隔多少epoch保存一次模型') 
    parser.add_argument('-ple', '--print_loss_epoch', type = int, default = 10, help='每隔多少epoch输出一次loss')
    parser.add_argument('-pae', '--print_acc_epoch', type = int, default = 100, help='每隔多少epoch输出一次acc')
    parser.add_argument('-plae', '--plot_loss_acc_epoch', type = int, default = 500, help='每隔多少epoch画一次loss和acc')
    parser.add_argument('-save_optimizer', '--save_optimizer', type = int, default = 0, help='是否需要保存优化器参数')
    parser.add_argument('-SAM_rho', '--SAM_rho', type = float, default = 0.1, help='Sharp Aware Minimization 的 rho 参数')
    
    # 前缀与后缀
    parser.add_argument('-prefix', '--prefix', type = str, default = ' ', help='文件夹前缀')
    parser.add_argument('-suffix', '--suffix', type = str, default = ' ', help='文件夹后缀')
    parser.add_argument('-rn', '--runner', type = str, default = ' ', help='程序运行者')

    # 大文件夹的后缀
    parser.add_argument('-dir_suffix', '--dir_suffix', type = str, default = ' ', help='上级文件夹的后缀')

    # scaling law
    parser.add_argument('-tm', '--train_method', type = str, default = 'train_last_token', help='训练方式，写train_scaling_law则调用train_scaling_law.py进行训练')
    parser.add_argument('-n_batch', '--n_batch', type = int, default = 10000, help='仅在train_scaling_law中使用，表示训练多少个batch') 
    parser.add_argument('-gdm', '--gen_data_mode', type = str, default = 'fix', help='仅在train_scaling_law中使用，表示生成数据的模式，可选on_the_fly或fix')

    #condense
    parser.add_argument('-sr', '--std_rate', type = float, default = 1, help='标准差的幂次') 

    parser.add_argument('-embedding_std', '--embedding_std', type = float, default = 0.5, help='标准差的幂次') 
    parser.add_argument('-qk_std', '--qk_std', type = float, default = 0.5, help='标准差的幂次') 
    parser.add_argument('-vo_std', '--vo_std', type = float, default = 0.5, help='标准差的幂次') 
    parser.add_argument('-mlp_std', '--mlp_std', type = float, default = 0.5, help='标准差的幂次') 
    parser.add_argument('-all_std', '--all_std', type = float, default = None, help='所有标准差的幂次，若不为None，则embedding_std、qk_std、vo_std、mlp_std均为该值')

    parser.add_argument('-freeze_embedding', '--freeze_embedding', type = int, default = 0, help='是否冻结 embedding 层，1表示冻结，0表示不冻结')
    parser.add_argument('-activation', '--activation', type = str, default = 'gelu', help='激活函数')
    parser.add_argument('-embedding_mean', '--embedding_mean', type = float, default = 0.5, help='embedding 层的均值')
    # # gpu
    # parser.add_argument('-gpu', '--gpu', type = int, default = 0, help='使用的gpu编号')

    # 解析已知的参数和未知的参数
    args, remaining = parser.parse_known_args()
    print(remaining)
    # 将未知的参数转化为字典
    remaining_dict = {}
    for i in range(0, len(remaining), 2):
        key = remaining[i].lstrip('-')
        value = remaining[i+1]
        remaining_dict[key] = value
    # remaining_dict 合并到 args 中
    args.__dict__.update(remaining_dict)

    # 生成主文件夹目录
    working_dir = f'{args.target}-N_{int(args.data_size)}'
    
    if args.prefix != ' ':
        working_dir = f'{args.prefix}-{working_dir}'
    if args.suffix != ' ':
        working_dir = f'{working_dir}-{args.suffix}'
    
    if args.dir_suffix != ' ':
        if hasattr(args, 'proj_name') and args.proj_name != ' ':
            args.working_dir = f'./result/{args.proj_name}/{args.model}_{args.dir_suffix}-{working_dir}'
        else:
            args.working_dir = f'./result/{args.model}_{args.dir_suffix}-{working_dir}'
    else:
        if hasattr(args, 'proj_name') and args.proj_name != ' ':
            args.working_dir = f'./result/{args.proj_name}/{args.model}-{working_dir}'
        else:
            # 如果没有 proj_name，则直接使用 model 和 working_dir
            args.working_dir = f'./result/{args.proj_name}/{args.model}-{working_dir}'

    # 保存源代码
    for file in ['pic', 'loss', 'src', 'data', 'model']:
        os.makedirs(f'{args.working_dir}/{file}', exist_ok=True)
    for file in ['main.py', 'data.py', 'train.py', args.script_file_name]:
        shutil.copy(file, f'{args.working_dir}/src/{file}')
    for dir in ['utils', 'model', 'data_generator']:
        shutil.copytree(dir, f'{args.working_dir}/src/{dir}', dirs_exist_ok=True)   

    # vocab_size 增加 10 放置 START 和 END token
    args.vocab_size += 10
    main_wikitext(args)