"""Evaluate pretrained or fine-tuned models (prompt-based, standard, or autoregressive) on GLUE-style classification tasks."""

import dataclasses
import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Callable, Dict, Optional, List
import torch

import numpy as np

import transformers
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction, GPT2LMHeadModel
from transformers import GlueDataTrainingArguments as DataTrainingArguments
from transformers import TrainingArguments, set_seed
from src.gpt_trainer import gptTrainer

from transformers import Trainer
from transformers.hf_argparser import HfArgumentParser

# from tools.hf_argparser import HfArgumentParser
from src.dataset import FewShotDataset
from src.models import BertForPromptFinetuning, RobertaForPromptFinetuning, resize_token_type_embeddings
# from src.trainer import Trainer
from src.processors import processors_mapping, num_labels_mapping, output_modes_mapping, compute_metrics_mapping, bound_mapping
from src.gptdataset import gptDataset
from run import ModelArguments, DynamicDataTrainingArguments, DynamicTrainingArguments

from filelock import FileLock
from datetime import datetime

from copy import deepcopy
from tqdm import tqdm
import json

# Module-level logger; its handlers and level are configured inside main().
logger = logging.getLogger(__name__)

# Switch off the Weights & Biases integration for every run of this script.
os.environ.update({"WANDB_DISABLED": "true"})

#print (compute_metrics_mapping['telephone_letters'])
#exit(0)


def _load_prompt_settings(data_args):
    """Fill the template/mapping fields of ``data_args`` from the configured files.

    ``data_args`` is mutated in place. Two sources are supported:

    * ``prompt_path``: one ``template<TAB>mapping`` pair per line; the pair at
      index ``prompt_id`` is selected.
    * ``template_path`` and/or ``mapping_path``: one template (resp. one
      label-word mapping) per line, selected by ``template_id`` /
      ``top_n_template`` and ``mapping_id``.

    Raises:
        ValueError: if ``prompt_path`` is given without ``prompt_id`` or
            ``mapping_path`` is given without ``mapping_id``.
    """
    if data_args.prompt_path is not None:
        if data_args.prompt_id is None:
            raise ValueError("prompt_id must be set when prompt_path is given")
        prompt_list = []
        with open(data_args.prompt_path) as f:
            for line in f:
                # Every line must hold exactly two tab-separated fields.
                template, mapping = line.strip().split('\t')
                prompt_list.append((template, mapping))

        data_args.template, data_args.mapping = prompt_list[data_args.prompt_id]
        logger.info(
            "Specify load the %d-th prompt: %s | %s",
            data_args.prompt_id, data_args.template, data_args.mapping,
        )
        return

    if data_args.template_path is not None:
        with open(data_args.template_path) as f:
            stripped_lines = (line.strip() for line in f)
            data_args.template_list = [line for line in stripped_lines if line]

        # Load top-n templates
        if data_args.top_n_template is not None:
            data_args.template_list = data_args.template_list[:data_args.top_n_template]
        logger.info("Load top-%d templates from %s", len(data_args.template_list), data_args.template_path)

        # ... or load i-th template
        if data_args.template_id is not None:
            data_args.template = data_args.template_list[data_args.template_id]
            data_args.template_list = None
            logger.info("Specify load the %d-th template: %s", data_args.template_id, data_args.template)

    if data_args.mapping_path is not None:
        # Only one label word mapping can be used.
        if data_args.mapping_id is None:
            raise ValueError("mapping_id must be set when mapping_path is given")
        with open(data_args.mapping_path) as f:
            # Blank lines are kept so that mapping_id keeps matching file line numbers.
            mapping_list = [line.strip() for line in f]

        data_args.mapping = mapping_list[data_args.mapping_id]
        logger.info("Specify using the %d-th mapping: %s", data_args.mapping_id, data_args.mapping)


def _build_compute_metrics_fn(task_name: str, dataset) -> Callable[["EvalPrediction"], Dict]:
    """Return a metrics callback for ``task_name`` bound to ``dataset``.

    ``dataset.num_sample`` is read lazily, when the callback runs, so building
    the callback does not require the dataset to expose that attribute.
    """
    def compute_metrics_fn(p: "EvalPrediction"):
        # Note: the eval dataloader is sequential, so the examples are in order.
        # We average the logits over each sample for using demonstrations.
        num_sample = dataset.num_sample
        predictions = p.predictions
        num_logits = predictions.shape[-1]
        logits = predictions.reshape([num_sample, -1, num_logits]).mean(axis=0)

        if num_logits == 1:
            preds = np.squeeze(logits)
        else:
            preds = np.argmax(logits, axis=1)

        # Just for sanity, assert label ids are the same across the samples.
        label_ids = p.label_ids.reshape([num_sample, -1])
        label_ids_avg = label_ids.mean(axis=0).astype(p.label_ids.dtype)
        assert (label_ids_avg - label_ids[0]).mean() < 1e-2
        label_ids = label_ids[0]

        # Some tasks are scored with the metric of another task.
        if 'telephone' in task_name or 'anli' in task_name:
            metric_task = 'mnli'
        elif 'imdb' in task_name:
            metric_task = 'sst-2'
        else:
            metric_task = task_name
        return compute_metrics_mapping[task_name](metric_task, preds, label_ids)

    return compute_metrics_fn


def _evaluate_and_record(trainer, dataset, split, output_dir, final_result):
    """Evaluate ``dataset`` with ``trainer``, write the metrics to disk and record them.

    Args:
        trainer: trainer whose ``evaluate`` returns a metrics mapping.
        dataset: dataset to evaluate; ``dataset.args.task_name`` names the task.
        split: ``"dev"`` or ``"test"``; selects file name, log header and key infix.
        output_dir: directory receiving ``eval_results_<task>.txt`` (dev) or
            ``test_results_<task>.txt`` (test).
        final_result: dict updated in place with ``<task>_<split>_<metric>`` entries.

    Returns:
        The metrics mapping returned by ``trainer.evaluate``.
    """
    task_name = dataset.args.task_name
    label = 'eval' if split == 'dev' else 'test'
    trainer.compute_metrics = _build_compute_metrics_fn(task_name, dataset)
    metrics = trainer.evaluate(eval_dataset=dataset)

    results_file = os.path.join(output_dir, f"{label}_results_{task_name}.txt")
    with open(results_file, "w") as writer:
        logger.info("***** %s results %s *****", label.capitalize(), task_name)
        for key, value in metrics.items():
            logger.info("  %s = %s", key, value)
            writer.write("%s = %s\n" % (key, value))
            final_result[task_name + '_' + split + '_' + key] = value
    return metrics


def _resolve_log_dir_name(model_name_or_path: str) -> str:
    """Pick the ``./logs/merged`` sub-directory from tags in the model name.

    Tags are tested in a fixed order; ``"other"`` is returned when none match
    so that a finished evaluation is never lost to an unbound variable.
    """
    for tag in ('large', 'shot', 'disjoint', 'base'):
        if tag in model_name_or_path:
            return tag
    return 'other'


def _append_summary_line(model_args, data_args, final_result):
    """Append ``<task>,<headline test metric>`` to the per-model summary file.

    The headline metric is F1 for mrpc/qqp, ``mnli/acc`` for mnli and accuracy
    otherwise. When that metric was not produced (for example because the test
    pass was skipped) nothing is written and a warning is logged.
    """
    task = str(data_args.task_name)
    if task in ('mrpc', 'qqp'):
        summary_key = task + '_test_eval_f1'
    elif task == 'mnli':
        summary_key = 'mnli_test_eval_mnli/acc'
    else:
        summary_key = task + '_test_eval_acc'

    if summary_key not in final_result:
        logger.warning("Metric %s not available; skipping the summary line", summary_key)
        return

    dir_name = _resolve_log_dir_name(model_args.model_name_or_path)
    model_name = model_args.model_name_or_path.split('/')[-1]
    summary_dir = f"./logs/merged/{dir_name}"
    os.makedirs(summary_dir, exist_ok=True)
    with open(f"{summary_dir}/{model_name}.txt", 'a') as f:
        f.write(f"{task},{final_result[summary_key]}\n")


def main():
    """Evaluate one model on the dev and test splits of a task and log the metrics.

    Arguments are parsed from the command line, or from a single ``.json``
    file path, into ``ModelArguments``, ``DynamicDataTrainingArguments`` and
    ``DynamicTrainingArguments``. Depending on ``few_shot_type`` either a
    masked-LM / sequence-classification model or an autoregressive LM is
    loaded. No training is started from this function.

    Side effects: writes per-task result files under ``output_dir``, appends a
    summary line under ``./logs/merged/`` and appends the full result record
    to ``model_args.log_file_store`` (guarded by a file lock).

    Returns:
        dict: the dev-set metrics (empty when ``do_eval`` is off).

    Raises:
        ValueError: for an unknown task, an already populated output directory
            that may not be overwritten, or inconsistent prompt arguments.
        NotImplementedError: for an unsupported model type / few-shot type.
    """
    parser = HfArgumentParser((ModelArguments, DynamicDataTrainingArguments, DynamicTrainingArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if 'prompt' in model_args.few_shot_type:
        data_args.prompt = True
    if 'autoregressive' in model_args.few_shot_type:
        data_args.autoregressive = True

    if training_args.no_train:
        training_args.do_train = False
    if training_args.no_predict:
        training_args.do_predict = False

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )

    # Load prompt/template/mapping file
    if data_args.prompt:
        _load_prompt_settings(data_args)

    # Check save path
    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(f"Output directory ({training_args.output_dir}) already exists.")

    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed
    set_seed(training_args.seed)

    try:
        num_labels = num_labels_mapping[data_args.task_name]
        output_mode = output_modes_mapping[data_args.task_name]
    except KeyError as err:
        raise ValueError("Task not found: %s" % (data_args.task_name)) from err
    logger.info("Task name: {}, number of labels: {}, output mode: {}".format(data_args.task_name, num_labels, output_mode))

    special_tokens = []

    # Create tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        additional_special_tokens=special_tokens,
        cache_dir=model_args.cache_dir,
    )

    # Autoregressive runs use their own cache directory next to the data directory.
    data_cache_dir = data_args.data_dir
    if data_args.autoregressive:
        data_cache_dir += '_autoregressive'
    os.makedirs(data_cache_dir, exist_ok=True)

    dataset_class = gptDataset if data_args.autoregressive else FewShotDataset
    use_demo = "demo" in model_args.few_shot_type

    def make_dataset(mode):
        return dataset_class(data_args, tokenizer=tokenizer, cache_dir=data_cache_dir, mode=mode, use_demo=use_demo)

    # Get our special datasets.
    train_dataset = make_dataset("train")
    eval_dataset = make_dataset("dev")
    # For RTE the dev split is reused as the test split.
    test_dataset = eval_dataset if data_args.task_name == "rte" else make_dataset("test")

    # Re-seed after the datasets are built (presumably so that model setup does
    # not depend on random numbers consumed during dataset construction).
    set_seed(training_args.seed)

    if not data_args.autoregressive:
        # Create config
        config = AutoConfig.from_pretrained(
            model_args.config_name if model_args.config_name else model_args.model_name_or_path,
            num_labels=num_labels,
            finetuning_task=data_args.task_name,
            cache_dir=model_args.cache_dir,
        )

        uses_linear_head = model_args.few_shot_type == 'finetune' and model_args.use_CLS_linearhead == 1
        if 'prompt' in model_args.few_shot_type or uses_linear_head:
            if config.model_type == 'roberta':
                model_fn = RobertaForPromptFinetuning
            elif config.model_type == 'bert':
                model_fn = BertForPromptFinetuning
            else:
                raise NotImplementedError
        elif model_args.few_shot_type == 'finetune':
            model_fn = AutoModelForSequenceClassification
        else:
            raise NotImplementedError

        model = model_fn.from_pretrained(
            model_args.model_name_or_path,
            from_tf=bool(".ckpt" in model_args.model_name_or_path),
            config=config,
            cache_dir=model_args.cache_dir,
        )

        # For BERT, increase the size of the segment (token type) embeddings
        if config.model_type == 'bert':
            model.resize_token_embeddings(len(tokenizer))
            resize_token_type_embeddings(model, new_num_types=10, random_segment=model_args.random_segment)

        # Pass dataset and argument information to the model
        if data_args.prompt:
            # Use the device selected by the training arguments instead of a hard-coded .cuda().
            model.label_word_list = torch.tensor(train_dataset.label_word_list).long().to(training_args.device)
        if output_modes_mapping[data_args.task_name] == 'regression':
            # lower / upper bounds
            model.lb, model.ub = bound_mapping[data_args.task_name]
        model.model_args = model_args
        model.data_args = data_args
        model.tokenizer = tokenizer

        # Snapshot of the loaded parameters, taken before the classifier head is replaced below.
        model.initial_parameters_copy = [p.detach().clone() for p in model.parameters()]
        if uses_linear_head:
            model.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)

    else:
        config = AutoConfig.from_pretrained(
            model_args.config_name if model_args.config_name else model_args.model_name_or_path,
            cache_dir=model_args.cache_dir,
        )

        if 'opt' in model_args.model_name_or_path:
            # Imported here so that transformers releases without OPT still work for GPT-2.
            from transformers import OPTForCausalLM
            model_fn = OPTForCausalLM
        elif 'gpt' in model_args.model_name_or_path:
            model_fn = GPT2LMHeadModel
        else:
            raise NotImplementedError(
                "Unsupported autoregressive model: %s" % model_args.model_name_or_path
            )

        model = model_fn.from_pretrained(
            model_args.model_name_or_path,
            config=config,
            cache_dir=model_args.cache_dir,
        )
        # NOTE(review): parallelize() is called for every autoregressive model;
        # confirm the OPT class of the installed transformers version provides it.
        model.parallelize()

    # Initialize our Trainer
    trainer_class = gptTrainer if data_args.autoregressive else Trainer
    trainer = trainer_class(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=_build_compute_metrics_fn(data_args.task_name, eval_dataset),
    )

    # Evaluation
    final_result = {
        'time': str(datetime.today()),
    }

    eval_results = {}
    if training_args.do_eval:
        logger.info("*** Validate ***")

        if data_args.autoregressive:
            eval_result = trainer.evaluate(eval_dataset=eval_dataset).compute()
            eval_results.update(eval_result)
            for key, value in eval_result.items():
                final_result[eval_dataset.args.task_name + '_dev_' + key] = value
            print (eval_results)
        else:
            eval_result = _evaluate_and_record(
                trainer, eval_dataset, 'dev', training_args.output_dir, final_result
            )
            eval_results.update(eval_result)

    test_results = {}
    if training_args.do_predict:
        logger.info("*** Test ***")
        if data_args.autoregressive:
            test_result = trainer.evaluate(eval_dataset=test_dataset).compute()
            test_results.update(test_result)
            for key, value in test_result.items():
                final_result[test_dataset.args.task_name + '_test_' + key] = value

            print (test_results)
        else:
            test_datasets = [test_dataset]
            if data_args.task_name == "mnli":
                # MNLI is additionally scored on its mismatched test set.
                mnli_mm_data_args = dataclasses.replace(data_args, task_name="mnli-mm")
                test_datasets.append(
                    FewShotDataset(mnli_mm_data_args, tokenizer=tokenizer, mode="test", use_demo=use_demo)
                )

            for test_set in test_datasets:
                test_result = _evaluate_and_record(
                    trainer, test_set, 'test', training_args.output_dir, final_result
                )

                if training_args.save_logit:
                    # evaluate() only yields metrics here, so the raw
                    # predictions are obtained through predict().
                    predictions = trainer.predict(test_set).predictions
                    num_logits = predictions.shape[-1]
                    logits = predictions.reshape([test_set.num_sample, -1, num_logits]).mean(axis=0)
                    np.save(os.path.join(training_args.save_logit_dir, "{}-{}-{}.npy".format(test_set.task_name, training_args.model_id, training_args.array_id)), logits)

                test_results.update(test_result)

    _append_summary_line(model_args, data_args, final_result)

    # Append the full record (metrics plus all arguments) to the shared log file.
    with FileLock(model_args.log_file_store + '.lock'):
        with open(model_args.log_file_store, 'a') as f:
            final_result.update(vars(model_args))
            final_result.update(vars(training_args))
            final_result.update(vars(data_args))
            final_result.pop('evaluation_strategy', None)
            f.write(str(final_result) + '\n')

    return eval_results

# Run the evaluation pipeline only when this file is executed as a script,
# not when it is imported; the dict returned by main() is discarded.
if __name__ == "__main__":
    main()