import os
import re
import random
import numpy as np
import torch
import datetime
import evaluate
from transformers import (
    PretrainedConfig,
    T5Tokenizer,
    T5Config,
    DataCollatorForSeq2Seq
)
from transformers.models.t5.modeling_t5 import T5ForConditionalGeneration
from models.modeling_t5 import (
    ReduT5ForConditionalGeneration
)
from transformers.utils import logging
from datasets import DatasetDict, Dataset, load_from_disk, load_metric
from typing import Optional, Dict, List
from trainer import T5GlueTrainer
from loguru import logger
from utils import T5Utils, GlueUtils

import argparse

# GLUE task name -> pair of dataset column names for that task. The second
# entry is ``None`` for tasks that list only one column (presumably the
# single-sentence tasks; confirm against the preprocessing helper).
task_to_keys = dict(
    cola=("sentence", None),
    mnli=("premise", "hypothesis"),
    mrpc=("sentence1", "sentence2"),
    qnli=("question", "sentence"),
    qqp=("question1", "question2"),
    rte=("sentence1", "sentence2"),
    sst2=("sentence", None),
    stsb=("sentence1", "sentence2"),
    wnli=("sentence1", "sentence2"),
)

# Timestamp taken once at import time; it is embedded in the log file name.
format_time = f"{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}"

def setup_seed(seed: int):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global generator, and PyTorch's
    CPU and (all-device) CUDA generators with the same value.

    Args:
        seed: The integer seed applied to each generator.
    """
    # The four calls are independent of each other, so their order is free.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Define this script's command-line options and parse them.

    Args:
        argv: Argument strings to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``, which matches the previous no-argument
            behaviour; passing a list allows programmatic use.

    Returns:
        The ``argparse.Namespace`` holding one attribute per option.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default="glue")
    parser.add_argument('--task_name', type=str, default=None)
    parser.add_argument('--model_name', type=str, default="t5-base")
    parser.add_argument('--weight_decay', type=float, default=0)
    parser.add_argument('--adam_epsilon', type=float, default=1e-8)
    # The flag was originally misspelled ("warump"). The old spelling and its
    # destination attribute are kept so existing command lines and any code
    # reading ``args.num_warump_steps`` keep working; the correctly spelled
    # ``--num_warmup_steps`` is accepted as an alias for the same value.
    parser.add_argument('--num_warump_steps', '--num_warmup_steps',
                        dest='num_warump_steps', type=int, default=0)
    parser.add_argument('--max_grad_norm', type=float, default=1.0)
    parser.add_argument('--shuffle', action='store_true')
    parser.add_argument('--device', type=str, default="cuda:0")
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--lambda_logits', type=float, default=0.5)
    parser.add_argument('--lambda_hiddens', type=float, default=0.5)

    parser.add_argument('--lr', type=float, default=1e-5)
    parser.add_argument('--num_epochs', type=int, default=2)
    parser.add_argument('--batch_size', type=int, default=16)

    parser.add_argument('--use_subset', action="store_true")
    parser.add_argument('--data_sample_num', type=int, default=2000)
    # The options below are copied onto the T5 config by ``run``.
    parser.add_argument('--comp_mode', type=int, default=0)
    parser.add_argument('--r_model', type=int, default=576)
    parser.add_argument('--r_kv', type=int, default=48)
    parser.add_argument('--r_ff', type=int, default=2304)
    parser.add_argument('--en_num_layers', type=int, default=12)
    parser.add_argument('--de_num_layers', type=int, default=12)

    parser.add_argument('--svd_results_dir', type=str, default="svd_results")
    parser.add_argument('--use_random', action="store_true")
    return parser.parse_args(argv)


def run():
    """Train and evaluate a compressed T5 student on one GLUE task, then save it.

    In order, this function:
      1. parses the CLI arguments, validates them and seeds the RNGs;
      2. loads the task's dataset from the local cache, tokenizes it and keeps
         only the ``train`` split and one validation split;
      3. builds the student (``ReduT5ForConditionalGeneration``) and a teacher
         (``T5ForConditionalGeneration``) from the same config, loading the
         teacher from the task checkpoint;
      4. loads the stored compression parameters (optionally replacing the
         encoder/decoder projections with random tensors) and hands them,
         with the checkpoint weights, to ``T5Utils.load_model_params``;
      5. evaluates the teacher, trains the student, optionally drops layers
         and trains again, runs a final validation and saves the student's
         ``state_dict``.

    Raises:
        ValueError: if ``--dataset`` is not ``"glue"`` or ``--task_name`` was
            not given.
    """
    args = parse_args()
    # Explicit raises rather than ``assert``: assertions are stripped by
    # ``python -O`` and must not be relied on for input validation.
    if args.dataset != "glue":
        raise ValueError("unsupported dataset: {!r} (only 'glue' is supported)".format(args.dataset))
    if args.task_name is None:
        # This used to be a bare ``raise`` with no active exception, which
        # itself fails with an unrelated RuntimeError.
        raise ValueError("--task_name is required")

    setup_seed(args.seed)
    logger.add("log/run_comp-task_name[{}]-time[{}].log".format(args.task_name, format_time))
    for key, value in vars(args).items():
        logger.info("{} -> {}".format(key, value))
    try:
        raw_datasets = load_from_disk('../../.cache/datasets/{}/{}'.format(args.dataset, args.task_name))
    except Exception:
        # Log for the run's log file, then let the original error propagate.
        logger.error('dataset: [{}] does not exist'.format(args.dataset))
        raise
    assert isinstance(raw_datasets, DatasetDict)

    tokenizer = T5Tokenizer.from_pretrained(args.model_name, cache_dir="../../.cache/models")
    config = T5Config.from_pretrained(
        args.model_name,
        cache_dir="../../.cache/models"
    )
    # Extra, non-standard config fields that carry the compression settings.
    config.comp_mode = args.comp_mode
    config.r_model = args.r_model
    config.r_kv = args.r_kv
    config.r_ff = args.r_ff
    config.en_num_layers = args.en_num_layers
    config.de_num_layers = args.de_num_layers

    checkpoint_path = "../../.cache/checkpoint/{}-{}-{}.pth".format(args.dataset, args.task_name, args.model_name)
    state_dict: Dict[str, torch.Tensor] = torch.load(checkpoint_path, map_location='cpu')

    model = ReduT5ForConditionalGeneration(config)
    teacher = T5ForConditionalGeneration(config)
    teacher.load_state_dict(state_dict)
    comp_utils = T5Utils()
    glue_utils = GlueUtils()

    preprocess_function = glue_utils.init_glue_preprocess_function(args, tokenizer)

    # Drop every split whose name starts with "test"; only train/validation
    # splits are used below.
    for key in list(raw_datasets.keys()):
        if key.startswith("test"):
            del raw_datasets[key]

    raw_datasets = raw_datasets.map(
        preprocess_function,
        batched=True,
        desc="Running tokenizer on dataset",
    )
    # Keep only the columns listed in ``glue_utils.columns``.
    for key, dataset in raw_datasets.items():
        raw_datasets[key] = dataset.remove_columns(set(dataset.column_names) - set(glue_utils.columns))

    datasets: Dict[str, Dataset] = {
        'train': raw_datasets['train'],
        'validation': raw_datasets["validation_matched" if args.task_name == "mnli" else "validation"]
    }

    data_collator = DataCollatorForSeq2Seq(tokenizer)

    path = "{}/{}-{}-{}".format(
        args.svd_results_dir,
        args.dataset,
        args.task_name,
        args.data_sample_num
    )
    param_path = "{}_{}_{}_{}_params.pt".format(
        args.model_name,
        args.r_model,
        args.r_kv,
        args.comp_mode
    )
    compression_params = torch.load(os.path.join(path, param_path), map_location="cpu")

    if args.use_random:
        # Replace the stored projections with random tensors of the same
        # shape and dtype.
        encoder_proj: torch.Tensor = torch.randn_like(compression_params["encoder_proj"])
        decoder_proj: torch.Tensor = torch.randn_like(compression_params["decoder_proj"])
        compression_params["encoder_proj"] = encoder_proj
        compression_params["decoder_proj"] = decoder_proj

    comp_utils.load_model_params(model, state_dict, compression_params)

    if args.use_subset:
        # Freeze the shared embedding and the LM head when training on a subset.
        for param in model.shared.parameters():
            param.requires_grad = False
        for param in model.lm_head.parameters():
            param.requires_grad = False

    metric = load_metric("metric/" + args.dataset + ".py", args.task_name)
    trainer = T5GlueTrainer(args, model, datasets, tokenizer, data_collator, metric, teacher=teacher)

    validation_results = trainer.evaluate(
        "validation",
        datasets["validation"],
        eval_teacher=True
    )
    # Loop variables are named so they do not shadow ``metric`` above.
    for metric_name, score in validation_results.items():
        logger.info("teacher-[validation] {:<6}: {:.5}".format(metric_name, score))

    if args.use_subset:
        # Never request more samples than the training split contains.
        args.data_sample_num = min(args.data_sample_num, len(datasets['train']))
        subset_indices = np.random.choice(
            len(datasets['train']), args.data_sample_num, replace=False).tolist()
    else:
        subset_indices = None

    trainer.train(eval=True, eval_before_train=True, subset_indices=subset_indices)

    # NOTE(review): both target depths are compared with ``config.num_layers``;
    # T5Config keeps the decoder depth in ``num_decoder_layers`` — confirm that
    # comparing ``de_num_layers`` with ``num_layers`` is intended.
    if config.num_layers != config.en_num_layers or config.num_layers != config.de_num_layers:
        model.drop_layers()
        trainer.train(eval=True, eval_before_train=False, subset_indices=subset_indices)

    validation_results = trainer.evaluate('final-validation', trainer.datasets['validation'])
    for metric_name, score in validation_results.items():
        logger.info("[final-validation] {:<6}: {:.5}".format(metric_name, score))

    # Save into the tmp folder so existing model checkpoints are not overwritten.
    model_path = "../../.cache/checkpoint/tmp/{}-{}-{}-TCSP-{}-{}-{}.pth"\
        .format(args.dataset, args.task_name, args.model_name,
                args.r_model, args.r_kv, args.r_ff)
    # Create the folder if needed so a finished run is not lost to a missing
    # directory at save time.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    torch.save(model.state_dict(), model_path)


# Start the pipeline only when this file is executed as a script, so that
# importing the module does not trigger a run.
if __name__ == '__main__':
    run()

