import os
import re
import random
import numpy as np
import torch
import datetime
import evaluate
from transformers import (
    PretrainedConfig,
    BertTokenizer,
    BertConfig,
    T5Tokenizer,
    T5Config,
    DataCollatorWithPadding
)
from transformers.models.bert.modeling_bert import BertForSequenceClassification
from models.modeling_bert import (
    ReduBertForSequenceClassification,
)
from transformers.utils import logging
from datasets import DatasetDict, Dataset, load_from_disk, load_metric
from typing import Optional, Dict, List
from trainer.bert_glue_trainer import BertGlueTrainer
from loguru import logger
from utils import BertUtils

import argparse

# Per-task pair of field names, keyed by GLUE task name. The second entry is
# None for tasks that list only one field.
# NOTE(review): presumably these are the dataset text columns fed to the
# tokenizer -- this table is not referenced elsewhere in this file; confirm
# against the preprocessing helper.
task_to_keys = dict(
    [
        ("cola", ("sentence", None)),
        ("mnli", ("premise", "hypothesis")),
        ("mrpc", ("sentence1", "sentence2")),
        ("qnli", ("question", "sentence")),
        ("qqp", ("question1", "question2")),
        ("rte", ("sentence1", "sentence2")),
        ("sst2", ("sentence", None)),
        ("stsb", ("sentence1", "sentence2")),
        ("wnli", ("sentence1", "sentence2")),
    ]
)

# Timestamp taken once at import time; it is embedded in the log file name.
format_time = format(datetime.datetime.now(), "%Y-%m-%d_%H:%M:%S")

def setup_seed(seed: int):
    """Seed every random number generator used by this script.

    The same ``seed`` is applied to Python's ``random`` module, NumPy's
    global generator, PyTorch's default generator and all CUDA devices.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Build this script's command-line parser and parse the arguments.

    Args:
        argv: Argument strings to parse. When ``None`` (the default) argparse
            falls back to ``sys.argv[1:]``, which is the original behaviour.
            Passing an explicit list allows programmatic use.

    Returns:
        The parsed ``argparse.Namespace``. Attribute names match the flag
        names; in particular the warm-up step count is stored under the
        historical (misspelled) attribute ``num_warump_steps``.
    """
    parser = argparse.ArgumentParser()

    # Data / model selection.
    parser.add_argument('--dataset', type=str, default="glue",
                        help="dataset family; run() only accepts 'glue'")
    parser.add_argument('--task_name', type=str, default=None,
                        help="task within the dataset; required by run()")
    parser.add_argument('--model_name', type=str, default="bert-base-uncased",
                        help="name passed to from_pretrained for tokenizer and config")

    # Optimisation settings.
    parser.add_argument('--weight_decay', type=float, default=0)
    parser.add_argument('--adam_epsilon', type=float, default=1e-8)
    # The original flag name contains a typo. It is kept (together with its
    # destination attribute) so existing command lines and readers of
    # ``args.num_warump_steps`` keep working; the correctly spelled flag is
    # accepted as an alias for the same value.
    parser.add_argument('--num_warump_steps', '--num_warmup_steps',
                        dest='num_warump_steps', type=int, default=0,
                        help="warm-up steps (stored as args.num_warump_steps)")
    parser.add_argument('--max_grad_norm', type=float, default=1.0)
    parser.add_argument('--shuffle', action='store_true')
    parser.add_argument('--device', type=str, default="cuda:0")
    parser.add_argument('--seed', type=int, default=42,
                        help="seed applied to random, numpy and torch")
    parser.add_argument('--lambda_logits', type=float, default=0.5)
    parser.add_argument('--lambda_hiddens', type=float, default=0.5)
    parser.add_argument('--use_subset', action="store_true",
                        help="train on a random subset of the training split")

    parser.add_argument('--lr', type=float, default=1e-5)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--num_epochs', type=int, default=2)

    # Sizes describing the reduced model and the precomputed parameter file.
    parser.add_argument('--data_sample_num', type=int, default=2000,
                        help="sample count; part of the svd_results directory name")
    parser.add_argument('--redu_hidden_size', type=int, default=576)
    parser.add_argument('--redu_attention_size', type=int, default=768)
    parser.add_argument('--redu_intermediate_size', type=int, default=3072)

    return parser.parse_args(argv)


def run():
    """Train a reduced BERT classifier on one GLUE task and save its weights.

    Steps, in order:

    1. Parse and validate the command-line arguments, seed the RNGs and add
       a log file sink.
    2. Load the task's dataset from the on-disk cache, tokenize it and drop
       every column not listed in ``BertUtils.columns``.
    3. Build a ``ReduBertForSequenceClassification`` model and a
       ``BertForSequenceClassification`` teacher from the same config. The
       teacher loads the cached checkpoint; the reduced model is initialised
       by ``BertUtils.load_model_params`` from that checkpoint plus the
       parameters stored under ``svd_results/``.
    4. Evaluate the teacher on the validation split, train the reduced model
       (optionally on a random subset of the training split) and save its
       ``state_dict`` under ``../../.cache/checkpoint/tmp/``.

    Raises:
        ValueError: If ``--dataset`` is not ``"glue"`` or ``--task_name`` is
            missing.
        TypeError: If the cached dataset is not a ``DatasetDict``.
        Exception: Whatever ``load_from_disk`` raised, re-raised after logging.
    """
    args = parse_args()
    # Validate inputs explicitly: ``assert`` is stripped under ``python -O``.
    if args.dataset != "glue":
        raise ValueError("unsupported dataset [{}]; only 'glue' is supported".format(args.dataset))
    # The task name is part of every path below, so reject a missing one
    # before touching the file system.
    if args.task_name is None:
        raise ValueError("--task_name is required")

    setup_seed(args.seed)
    logger.add("log/run_comp-task_name[{}]-time[{}].log".format(args.task_name, format_time))
    for key, value in vars(args).items():
        logger.info("{} -> {}".format(key, value))

    dataset_path = '../../.cache/datasets/{}/{}'.format(args.dataset, args.task_name)
    try:
        raw_datasets = load_from_disk(dataset_path)
    except Exception:
        # Log which dataset/task/path failed, then let the original error
        # propagate unchanged.
        logger.error('failed to load dataset [{}] task [{}] from [{}]'.format(
            args.dataset, args.task_name, dataset_path))
        raise
    if not isinstance(raw_datasets, DatasetDict):
        raise TypeError("expected a DatasetDict at [{}], got {}".format(
            dataset_path, type(raw_datasets).__name__))

    # Labels: "stsb" is treated as regression with a single output; every
    # other task takes its class count from the dataset's label feature.
    if args.task_name == "stsb":
        num_labels = 1
    else:
        num_labels = len(raw_datasets["train"].features["label"].names)

    tokenizer = BertTokenizer.from_pretrained(args.model_name, cache_dir="../../.cache/models")
    config = BertConfig.from_pretrained(
        args.model_name,
        cache_dir="../../.cache/models",
        num_labels=num_labels,
    )
    # Extra size attributes attached to the config before the reduced model
    # is constructed from it.
    config.redu_hidden_size = args.redu_hidden_size
    config.redu_attention_size = args.redu_attention_size
    config.redu_intermediate_size = args.redu_intermediate_size

    comp_utils = BertUtils()
    preprocess_function = comp_utils.init_preprocess_function(args, tokenizer)

    raw_datasets = raw_datasets.map(
        preprocess_function,
        batched=True,
        desc="Running tokenizer on dataset",
    )
    # Keep only the columns listed by BertUtils. A list is passed because
    # that is the documented input type of ``remove_columns``.
    keep_columns = set(comp_utils.columns)
    for split, dataset in raw_datasets.items():
        raw_datasets[split] = dataset.remove_columns(
            [name for name in dataset.column_names if name not in keep_columns])

    datasets: Dict[str, Dataset] = {
        'train': raw_datasets['train'],
        'validation': raw_datasets["validation_matched" if args.task_name == "mnli" else "validation"]
    }

    data_collator = comp_utils.DataCollator(tokenizer)

    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint_path = "../../.cache/checkpoint/{}-{}-{}.pth".format(args.dataset, args.task_name, args.model_name)
    state_dict: Dict[str, torch.Tensor] = torch.load(checkpoint_path, map_location='cpu')

    # NOTE(review): a fresh BertUtils instance is created here, as in the
    # original code; whether the first instance could be reused depends on
    # BertUtils internals that are not visible in this file.
    comp_utils = BertUtils()
    model = ReduBertForSequenceClassification(config)
    teacher = BertForSequenceClassification(config)
    teacher.load_state_dict(state_dict)

    path = "svd_results/{}-{}-{}".format(
        args.dataset, args.task_name, args.data_sample_num)
    compression_params = torch.load(
        os.path.join(path, f"{args.model_name}_{args.redu_hidden_size}_{args.redu_attention_size}_params.pt"),
        map_location="cpu"
    )
    comp_utils.load_model_params(model, state_dict, compression_params)

    # NOTE(review): ``datasets.load_metric`` is deprecated in newer
    # ``datasets`` releases in favour of ``evaluate.load``. It is kept
    # because the format of the local metric script is not visible here.
    metric = load_metric("metric/" + args.dataset + ".py", args.task_name)
    trainer = BertGlueTrainer(args, model, datasets, tokenizer, data_collator, metric, teacher=teacher)

    validation_results = trainer.evaluate(
        "validation",
        datasets["validation"],
        eval_teacher=True
    )
    # Distinct loop names so the ``metric`` object above is not shadowed.
    for metric_name, score in validation_results.items():
        logger.info("teacher-[validation] {:<6}: {:.5}".format(metric_name, score))

    subset_indices: Optional[List[int]] = None
    if args.use_subset:
        train_size = len(datasets['train'])
        # The CLI flag is --data_sample_num; the original code read the
        # non-existent ``args.data_sample_size`` and crashed here.
        sample_size = min(args.data_sample_num, train_size)
        # The original code wrote the clamped value to this attribute, so it
        # is still published in case the trainer reads it (not visible here).
        args.data_sample_size = sample_size
        subset_indices = np.random.choice(train_size, sample_size, replace=False).tolist()
        model.frozen_params()

    trainer.train(eval=True, eval_before_train=True, subset_indices=subset_indices)

    # Save into the tmp folder so existing model checkpoints are not overwritten.
    model_path = "../../.cache/checkpoint/tmp/{}-{}-{}-TCSP-{}-{}-{}.pth"\
        .format(args.dataset, args.task_name, args.model_name,
                args.redu_hidden_size, args.redu_attention_size, args.redu_intermediate_size)
    # Create the target directory first so a finished training run is not
    # lost to a missing folder (this also covers model names containing "/").
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    torch.save(model.state_dict(), model_path)


# Script entry point: start the run only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    run()
