import os
import random
import numpy as np
import torch
import datetime
import evaluate
from transformers import (
    PretrainedConfig,
    BertTokenizer, 
    BertConfig,
    T5Tokenizer,
    T5Config,
    DataCollatorWithPadding,
)
from transformers.models.bert.modeling_bert import (
    BertForSequenceClassification,
)
from transformers.utils import logging
from datasets import DatasetDict, Dataset, load_from_disk, load_metric
from typing import Optional, Dict, List
from trainer import BertGlueTrainer
from utils import BertUtils
from loguru import logger

import argparse

# GLUE task name -> name(s) of the input text column(s) in that task's dataset.
# Single-sentence tasks carry None as the second entry.
task_to_keys = dict(
    cola=("sentence", None),
    mnli=("premise", "hypothesis"),
    mrpc=("sentence1", "sentence2"),
    qnli=("question", "sentence"),
    qqp=("question1", "question2"),
    rte=("sentence1", "sentence2"),
    sst2=("sentence", None),
    stsb=("sentence1", "sentence2"),
    wnli=("sentence1", "sentence2"),
)

# Timestamp string captured once, at import time, from the local (naive) clock,
# formatted as YYYY-MM-DD_HH:MM:SS.
# NOTE(review): contains ':' characters, which are not valid in Windows file
# names if this is ever used to build paths.
format_time = datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S")

def setup_seed(seed: int):
    """Seed every random number generator used by this script.

    The same ``seed`` is applied, in order, to PyTorch's CPU generator,
    PyTorch's generators on all CUDA devices, NumPy's global RNG and
    Python's ``random`` module.

    Args:
        seed: Seed value passed unchanged to each generator.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seeder in seeders:
        seeder(seed)

def parse_args(argv: Optional[List[str]] = None):
    """Parse the command-line options for GLUE fine-tuning.

    Args:
        argv: Argument strings to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before this parameter existed.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default="glue")
    parser.add_argument('--task_name', type=str, default=None)
    parser.add_argument('--model_name', type=str, default="bert-base-uncased")
    parser.add_argument('--device', type=str, default="cuda:0")
    parser.add_argument('--seed', type=int, default=42)

    # for train
    parser.add_argument('--lr', type=float, default=2e-5)
    # Float default: argparse does not apply `type` to non-string defaults, so an
    # int 0 here would leave args.weight_decay as an int when the flag is omitted.
    parser.add_argument('--weight_decay', type=float, default=0.0)
    parser.add_argument('--adam_epsilon', type=float, default=1e-8)
    # Accept the correctly spelled flag as well as the historical misspelling.
    # The attribute name stays `num_warump_steps`, presumably because code that
    # receives `args` (e.g. the trainer) reads it under that name — verify before renaming.
    parser.add_argument(
        '--num_warmup_steps',
        '--num_warump_steps',
        dest='num_warump_steps',
        type=int,
        default=0,
    )
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--num_epochs', type=int, default=3)
    parser.add_argument('--max_grad_norm', type=float, default=1.0)
    parser.add_argument('--shuffle', action='store_true')

    return parser.parse_args(argv)

def run():
    """Fine-tune ``BertForSequenceClassification`` on one GLUE task and save it.

    Reads options from the command line (``parse_args``) and seeds the RNGs.
    Loads the task's ``DatasetDict`` from
    ``../../.cache/datasets/<dataset>/<task_name>`` and tokenizes it with the
    preprocess function supplied by ``BertUtils``. Trains with
    ``BertGlueTrainer``, then writes the model's ``state_dict`` under
    ``../../.cache/checkpoint/tmp/``. All paths are relative to the current
    working directory.

    Raises:
        ValueError: ``--dataset`` is not ``"glue"``, or ``--task_name`` is
            missing or not a key of ``task_to_keys``.
        FileNotFoundError: re-raised from ``load_from_disk`` (after logging)
            when the dataset is not found on disk.
    """
    args = parse_args()
    # Validate CLI input with real exceptions. `assert` is stripped under
    # `python -O`, and a bare `raise` outside an except block only produces a
    # confusing "No active exception to reraise" RuntimeError.
    if args.dataset != "glue":
        raise ValueError(f"unsupported dataset {args.dataset!r}; only 'glue' is supported")
    if args.task_name not in task_to_keys:
        raise ValueError(
            f"--task_name must be one of {sorted(task_to_keys)}, got {args.task_name!r}"
        )

    setup_seed(args.seed)

    cache_root = "../../.cache"
    model_cache_dir = f"{cache_root}/models"

    dataset_path = f"{cache_root}/datasets/{args.dataset}/{args.task_name}"
    try:
        raw_datasets = load_from_disk(dataset_path)
    except FileNotFoundError:
        # Only a missing dataset is reported as "does not exist"; other errors
        # propagate unchanged. The log names the full path, including the task.
        logger.error('dataset: [{}/{}] does not exist at {}'.format(args.dataset, args.task_name, dataset_path))
        raise
    assert isinstance(raw_datasets, DatasetDict)  # type narrowing for the code below

    # Labels: STS-B is treated as regression with a single output. For every
    # other task, the label count comes from the train split's "label" feature names.
    if args.task_name == "stsb":
        num_labels = 1
    else:
        num_labels = len(raw_datasets["train"].features["label"].names)

    tokenizer = BertTokenizer.from_pretrained(args.model_name, cache_dir=model_cache_dir)
    config = BertConfig.from_pretrained(args.model_name, cache_dir=model_cache_dir, num_labels=num_labels)
    model = BertForSequenceClassification.from_pretrained(
        args.model_name,
        config=config,
        cache_dir=model_cache_dir,
    )
    comp_utils = BertUtils()

    preprocess_function = comp_utils.init_preprocess_function(args, tokenizer)
    raw_datasets = raw_datasets.map(
        preprocess_function,
        batched=True,
        desc="Running tokenizer on dataset",
    )

    # Keep only the columns listed in `comp_utils.columns`. Pass a sorted list
    # rather than a set: it matches the documented `remove_columns` argument
    # type and gives a deterministic order, since a set's iteration order over
    # strings can vary between runs.
    keep_columns = set(comp_utils.columns)
    raw_datasets = DatasetDict({
        split: dataset.remove_columns(sorted(set(dataset.column_names) - keep_columns))
        for split, dataset in raw_datasets.items()
    })

    # MNLI's validation split is named "validation_matched"; other tasks use "validation".
    splits: Dict[str, Dataset] = {
        'train': raw_datasets['train'],
        'validation': raw_datasets["validation_matched" if args.task_name == "mnli" else "validation"],
    }

    data_collator = DataCollatorWithPadding(tokenizer)

    # Metric is loaded from a local script (metric/<dataset>.py). The hub-based
    # alternative would be `evaluate.load(args.dataset, args.task_name)`.
    # NOTE(review): `datasets.load_metric` is deprecated and removed in
    # datasets>=3.0 — confirm the pinned version.
    metric = load_metric("metric/" + args.dataset + ".py", args.task_name)
    trainer = BertGlueTrainer(args, model, splits, tokenizer, data_collator, metric)

    trainer.train(model)

    # Save into the tmp folder so that existing checkpoints are not overwritten.
    model_path = f"{cache_root}/checkpoint/tmp/{args.dataset}-{args.task_name}-{args.model_name}.pth"
    # Create the target directory first, including any sub-directory introduced
    # by a '/' in hub model ids such as "org/model". Otherwise the save could
    # fail only after training has finished.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    torch.save(model.state_dict(), model_path)


if __name__ == '__main__':
    # Script entry point: runs training only when executed directly, not on import.
    run()
