import os
import random
import numpy as np
import torch
import datetime
import evaluate
from transformers import (
    PretrainedConfig,
    BertTokenizer, 
    BertConfig,
    T5Tokenizer,
    T5Config,
    DataCollatorForSeq2Seq
)
from transformers.models.t5.modeling_t5 import T5ForConditionalGeneration
from transformers.utils import logging
from datasets import DatasetDict, Dataset, load_from_disk, load_metric
from typing import Optional, Dict, List
from trainer.t5_glue_trainer import T5GlueTrainer
from loguru import logger
from utils import T5Utils, GlueUtils

import argparse

# Timestamp string (e.g. "2024-01-31_13:45:00") captured once, when this
# module is imported. It is not referenced by the code in this file.
# NOTE(review): presumably kept for importers or left over; confirm before removing.
# If it is ever used in a file/dir name, note that ":" is invalid on Windows.
format_time = datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S")

def setup_seed(seed: int):
    """Seed every random number generator this script depends on.

    The generators are seeded in this order: torch's CPU generator, all CUDA
    generators, NumPy's global RNG, then Python's ``random`` module. Calling
    it twice with the same ``seed`` reproduces the same random streams.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Build the command-line parser for the collection script and parse it.

    Args:
        argv: Argument strings to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before; passing an explicit list lets
            other code and tests drive the parser without touching ``sys.argv``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default="glue",
                        help="dataset family; only 'glue' is supported")
    parser.add_argument('--task_name', type=str, default=None,
                        help="GLUE task name (e.g. rte, mnli, stsb); required")
    parser.add_argument('--model_name', type=str, default="t5-base",
                        help="T5 model name for tokenizer/config and checkpoint file name")
    # Not read directly in this file; forwarded to the trainer through `args`.
    parser.add_argument('--device', type=str, default="cuda:0",
                        help="torch device string, e.g. cuda:0")
    parser.add_argument('--seed', type=int, default=42,
                        help="seed for torch, numpy and random")

    # Options for the collection step. Several of them (use_disk, use_tmp,
    # token_sample_num, per_sample_token_num) are not read in this file;
    # presumably consumed by the collection utilities via `args` - verify there.
    parser.add_argument('--use_disk', action="store_true")
    parser.add_argument('--use_tmp', action="store_true")
    parser.add_argument('--data_sample_num', type=int, default=2000,
                        help="number of training examples to sample (clamped to split size)")
    parser.add_argument('--token_sample_num', type=int, default=2000)
    parser.add_argument('--per_sample_token_num', type=int, default=10)
    parser.add_argument('--svd_results_dir', type=str, default="svd_results",
                        help="root directory for the saved params file")
    # comp_mode, r_model and r_kv are also embedded in the saved file name.
    parser.add_argument('--comp_mode', type=int, default=0)
    parser.add_argument('--r_model', type=int, default=576)
    parser.add_argument('--r_kv', type=int, default=48)

    return parser.parse_args(argv)

def run():
    """Collect and save compression parameters for a fine-tuned T5 GLUE model.

    Pipeline, driven entirely by ``parse_args()``:
      1. Validate arguments and seed all RNGs.
      2. Load the cached GLUE ``DatasetDict``, drop test splits, tokenize, and
         keep only the columns listed in ``GlueUtils().columns``.
      3. Build a T5 model from its config and load the fine-tuned state_dict
         from ``../../.cache/checkpoint/{dataset}-{task}-{model}.pth``.
      4. Randomly subsample ``data_sample_num`` training examples.
      5. Run ``T5Utils.collect`` with a ``T5GlueTrainer`` and ``torch.save`` the
         result under ``{svd_results_dir}/{dataset}-{task}-{data_sample_num}/``.

    Raises:
        ValueError: ``--dataset`` is not "glue" or ``--task_name`` is missing.
        FileNotFoundError: the cached dataset directory does not exist.
        TypeError: the cached dataset is not a ``DatasetDict``.
    """
    args = parse_args()

    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    if args.dataset != "glue":
        raise ValueError("only the 'glue' dataset is supported, got {!r}".format(args.dataset))
    # Validate before touching the disk. Otherwise a missing task name first
    # shows up as a misleading ".../glue/None does not exist" error.
    if args.task_name is None:
        raise ValueError("--task_name is required (e.g. --task_name rte)")

    setup_seed(args.seed)

    dataset_path = '../../.cache/datasets/{}/{}'.format(args.dataset, args.task_name)
    try:
        raw_datasets = load_from_disk(dataset_path)
    except FileNotFoundError:
        # Only a missing directory means "does not exist"; any other error
        # propagates unchanged with its own traceback.
        logger.error('dataset: [{}/{}] does not exist at {}'.format(
            args.dataset, args.task_name, dataset_path))
        raise
    if not isinstance(raw_datasets, DatasetDict):
        raise TypeError('expected a DatasetDict at {}, got {}'.format(
            dataset_path, type(raw_datasets).__name__))

    model_cache_dir = "../../.cache/models"
    tokenizer = T5Tokenizer.from_pretrained(args.model_name, cache_dir=model_cache_dir)
    config = T5Config.from_pretrained(args.model_name, cache_dir=model_cache_dir)
    model_path = "../../.cache/checkpoint/{}-{}-{}.pth".format(args.dataset, args.task_name, args.model_name)
    # The model is built from the config only; its weights come from the local
    # fine-tuned checkpoint, mapped onto CPU regardless of where it was saved.
    # NOTE: torch.load unpickles the file - only point this at trusted checkpoints.
    model = T5ForConditionalGeneration(config)
    model.load_state_dict(torch.load(model_path, map_location='cpu'))

    comp_utils = T5Utils()
    glue_utils = GlueUtils()

    preprocess_function = glue_utils.init_glue_preprocess_function(args, tokenizer)

    # Test splits are not used here; drop them before tokenizing.
    for split in [name for name in raw_datasets if name.startswith("test")]:
        del raw_datasets[split]

    raw_datasets = raw_datasets.map(
        preprocess_function,
        batched=True,
        desc="Running tokenizer on dataset",
    )

    # Keep only the columns GlueUtils declares. Pass a list (in column order)
    # rather than a set so the call is deterministic and matches the
    # List[str] parameter that remove_columns documents.
    keep_columns = set(glue_utils.columns)
    for split in list(raw_datasets):
        dataset = raw_datasets[split]
        raw_datasets[split] = dataset.remove_columns(
            [column for column in dataset.column_names if column not in keep_columns])

    validation_split = "validation_matched" if args.task_name == "mnli" else "validation"
    datasets: Dict[str, Dataset] = {
        'train': raw_datasets['train'],
        'validation': raw_datasets[validation_split],
    }

    data_collator = DataCollatorForSeq2Seq(tokenizer)

    # Subsample the training split without replacement. The clamped count is
    # written back to `args` because it names the output directory below and
    # `args` is handed on to the trainer and collector.
    args.data_sample_num = min(args.data_sample_num, len(datasets["train"]))
    indices: List[int] = np.random.choice(
        len(datasets['train']), args.data_sample_num, replace=False).tolist()
    datasets['train'] = datasets['train'].select(indices)

    trainer = T5GlueTrainer(args, model, datasets, tokenizer, data_collator, None)

    # collect() takes separate data/model argument objects; both are the CLI namespace here.
    data_args = args
    model_args = args
    params = comp_utils.collect(data_args, model_args, model, trainer)

    dir_path = f"{args.svd_results_dir}/{args.dataset}-{args.task_name}-{data_args.data_sample_num}"
    os.makedirs(dir_path, exist_ok=True)

    param_path = "{}_{}_{}_{}_params.pt".format(
        model_args.model_name,
        model_args.r_model,
        model_args.r_kv,
        model_args.comp_mode,
    )
    torch.save(params, os.path.join(dir_path, param_path))


# Script entry point: parse command-line arguments and run the collection pipeline.
if __name__ == '__main__':
    run()
