# Adapted from https://github.com/princeton-nlp/LM-BFF/blob/main/run.py

import dataclasses
import logging
import os
import sys
import warnings

warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=DeprecationWarning)
from dataclasses import dataclass, field
from typing import Callable, Dict, Optional
import torch
from scipy.special import softmax
import numpy as np

from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, EvalPrediction
from transformers import GlueDataTrainingArguments as DataTrainingArguments
from transformers import HfArgumentParser, TrainingArguments, set_seed

from cocolm.configuration_cocolm import COCOLMConfig
from cocolm.tokenization_cocolm import COCOLMTokenizer

from src.dataset import SuperGenDataset
from src.models import COCOLMForSequenceClassification, COCOLMForPromptFinetuning, RobertaForSequenceClassification, RobertaForPromptFinetuning 
from src.trainer import SuperGenTrainer
from src.processors import num_labels_mapping, output_modes_mapping, compute_metrics_mapping

from filelock import FileLock
from datetime import datetime


logger = logging.getLogger(__name__)


@dataclass
class ModelArguments:
    """Command-line options describing which model to load and how to fine-tune it.

    ``model_name_or_path`` is the only required option. ``config_name`` and
    ``tokenizer_name`` override where the config / tokenizer are loaded from
    when they differ from the model. ``finetune_type`` picks prompt-based or
    standard fine-tuning. ``smooth`` (label smoothing) and ``reg_weight``
    (maximum temporal-ensemble regularization weight) are loss settings; this
    object is attached to the model by ``main`` so they can be read there.
    """

    # ---- where to load the checkpoint / config / tokenizer from ----
    model_name_or_path: str = field(metadata=dict(help="Path to pretrained model or model identifier from huggingface.co/models"))
    config_name: Optional[str] = field(default=None, metadata=dict(help="Pretrained config name or path if not the same as model_name"))
    tokenizer_name: Optional[str] = field(default=None, metadata=dict(help="Pretrained tokenizer name or path if not the same as model_name"))
    cache_dir: Optional[str] = field(default=None, metadata=dict(help="Where do you want to store the pretrained models downloaded from s3"))

    # ---- fine-tuning regime and loss settings ----
    finetune_type: str = field(default='prompt', metadata=dict(help="Fine-tuning type. Choice: finetune, prompt"))
    smooth: Optional[float] = field(default=0, metadata=dict(help="label smoothing"))
    reg_weight: Optional[float] = field(default=10, metadata=dict(help="Temporal ensemble regularization loss weight max"))

@dataclass
class DynamicDataTrainingArguments(DataTrainingArguments):
    """
    Arguments for dynamic training.

    Extends the GLUE data arguments from transformers with the task/dataset
    location, the prompt template and label-word mapping (given inline or
    loaded from files by ``main``), sequence-length limits and logging tags.
    ``prompt`` and ``template_list`` are filled in by ``main`` at start-up and
    should not be set on the command line.
    """

    task_name: str = field(
        default=None,
        metadata={"help": "Task name"}
    )
    # For prompting
    template: str = field(
        default=None,
        metadata={"help": "Template"}
    )
    mapping: str = field(
        default=None,
        metadata={"help": "Label word mapping"}
    )
    data_dir: str = field(
        default=None,
        metadata={"help": "Path to dataset"}
    )
    template_path: str = field(
        default=None,
        metadata={"help": "Path to a txt file that stores all the templates, one per line. Do not set this when prompt_path is used"}
    )
    mapping_path: str = field(
        default=None,
        metadata={"help": "Path to a txt file that stores all the label word mappings, one per line. Do not set this when prompt_path is used"}
    )
    prompt_path: str = field(
        default=None,
        metadata={"help": "Path to a txt file that stores all the prompts (templates and mappings), one per line"}
    )
    template_id: int = field(
        default=None,
        metadata={"help": "Template id if using template_path"}
    )
    mapping_id: int = field(
        default=None,
        metadata={"help": "Mapping id if using mapping_path"}
    )
    prompt_id: int = field(
        default=None,
        metadata={"help": "Prompt id if using prompt_path"}
    )
    top_n_template: int = field(
        default=None,
        metadata={"help": "Use top-n template in the template path"}
    )
    # For logging
    tag: str = field(
        default='',
        metadata={"help": "Set the tag and find the result easier in the log."}
    )
    debug_mode: bool = field(
        default=False,
        metadata={"help": "Debug mode"}
    )

    # For max length

    first_sent_limit: int = field(
        default=None,
        metadata={"help": "Limit the length of the first sentence (i.e., sent_0)"}
    )
    other_sent_limit: int = field(
        default=None,
        metadata={"help": "Limit the length of sentences other than the first sentence"}
    )
    use_full_length: bool = field(
        default=None,
        metadata={"help": "Use the full length (512)"}
    )
    truncate_head: bool = field(
        default=False,
        metadata={"help": "When exceeding the maximum length, truncate the head instead of the tail."}
    )
    # Do not set up the next two fields (prompt, template_list); main() sets them automatically.
    prompt: bool = field(
        default=False,
        metadata={"help": "Whether to use prompt-based fine-tuning"}
    )
    template_list: list = field(
        default=None,
        metadata={"help": "List of templates (DO NOT set; only initialized after the program starts)"}
    )

    # User-settable options again from here on.
    eval_on_test: bool = field(
        default=False, metadata={"help": "Evaluate on test set instead of dev set"}
    )

    # Passed to SuperGenDataset as the `mode` of the training split.
    train_mode: str = field(
        default='train',
        metadata={"help": "train or gen-train."}
    )

@dataclass
class DynamicTrainingArguments(TrainingArguments):
    """
    Training arguments extending transformers' ``TrainingArguments``.

    Adds run identifiers, layer/embedding freezing switches, checkpointing and
    train/predict toggles (``no_train`` / ``no_predict`` override ``do_train`` /
    ``do_predict`` in ``main``), plus the noisy-sample threshold and the
    temporal-ensemble settings (momentum and ramp-up length).
    """
    # For ensemble
    array_id: int = field(
        default=-1,
        metadata={"help": "Array ID (contains seed and hyper-parameter search) to identify the model"}
    )
    model_id: int = field(
        default=-1,
        metadata={"help": "Model ID (contains template information) to identify the model"}
    )
    # Regularization
    fix_layers: int = field(
        default=0,
        metadata={"help": "Fix bottom-n layers when optimizing"}
    )
    freeze_emb: bool = field(
        default=False,
        metadata={"help": "Fix embeddings"}
    )
    # Training
    save_at_last: bool = field(
        default=False,
        metadata={"help": "Instead of saving the best (dev performance) checkpoint, save the last checkpoint"}
    )
    # Turn off train/test
    evaluate_during_training: bool = field(
        default=False,
        metadata={"help": "Eval during train"}
    )
    no_train: bool = field(
        default=False,
        metadata={"help": "No training"}
    )
    no_predict: bool = field(
        default=False,
        metadata={"help": "No test"}
    )
    warmup_ratio: float = field(
        default=0,
        metadata={"help": "Warm up ratio"}
    )
    # Noisy-sample filtering and temporal ensembling
    threshold: Optional[float] = field(
        default=0.8,
        metadata={"help": "Threshold for filtering out noisy samples"}
    )
    momentum: Optional[float] = field(
        default=0.8,
        metadata={"help": "Momentum parameter in temporal ensemble"}
    )
    temp_ensemble_rampup: Optional[int] = field(
        default=10,
        metadata={"help": "Number of intervals to ramp-up temporal ensemble regularization weight"}
    )

def _load_prompt_settings(data_args):
    """Fill ``template`` / ``mapping`` / ``template_list`` on *data_args* from files (in place).

    If ``prompt_path`` is set, it must contain one ``template<TAB>mapping`` pair
    per line and ``prompt_id`` selects the pair. Otherwise ``template_path``
    (one template per line, blank lines skipped, optionally truncated to
    ``top_n_template`` and/or narrowed to ``template_id``) and ``mapping_path``
    (one mapping per line, selected by ``mapping_id``) are consulted.

    Raises:
        ValueError: if a file is given without the id that selects from it.
    """
    if data_args.prompt_path is not None:
        if data_args.prompt_id is None:
            raise ValueError("prompt_id must be set when prompt_path is used")
        prompt_list = []
        with open(data_args.prompt_path) as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Blank lines (e.g. a trailing newline) have no tab and
                    # would break the unpacking below.
                    continue
                template, mapping = line.split('\t')
                prompt_list.append((template, mapping))

        data_args.template, data_args.mapping = prompt_list[data_args.prompt_id]
        logger.info("Specify load the %d-th prompt: %s | %s", data_args.prompt_id, data_args.template, data_args.mapping)
        return

    if data_args.template_path is not None:
        with open(data_args.template_path) as f:
            data_args.template_list = [line.strip() for line in f if line.strip()]

        # Load top-n templates
        if data_args.top_n_template is not None:
            data_args.template_list = data_args.template_list[:data_args.top_n_template]
        logger.info("Load top-%d templates from %s", len(data_args.template_list), data_args.template_path)

        # ... or load i-th template (a single template replaces the list)
        if data_args.template_id is not None:
            data_args.template = data_args.template_list[data_args.template_id]
            data_args.template_list = None
            logger.info("Specify load the %d-th template: %s", data_args.template_id, data_args.template)

    if data_args.mapping_path is not None:
        # Only one label word mapping can be used at a time.
        if data_args.mapping_id is None:
            raise ValueError("mapping_id must be set when mapping_path is used")
        with open(data_args.mapping_path) as f:
            mapping_list = [line.strip() for line in f]

        data_args.mapping = mapping_list[data_args.mapping_id]
        logger.info("Specify using the %d-th mapping: %s", data_args.mapping_id, data_args.mapping)


def _load_config_and_tokenizer(model_args, data_args, num_labels):
    """Load the COCO-LM config and tokenizer for this run.

    The config is always loaded with ``COCOLMConfig`` (the AutoConfig /
    AutoTokenizer path is not used by this script). The tokenizer comes from
    ``tokenizer_name`` when given; otherwise from ``model_name_or_path`` if it
    contains "cocolm", and from 'microsoft/cocolm-large' for any other path.

    Returns:
        (config, tokenizer)
    """
    special_tokens = []
    config = COCOLMConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=data_args.task_name,
        cache_dir=model_args.cache_dir,
    )

    if model_args.tokenizer_name:
        tokenizer_source = model_args.tokenizer_name
    elif "cocolm" in model_args.model_name_or_path:
        tokenizer_source = model_args.model_name_or_path
    else:
        # Paths without "cocolm" in their name fall back to the public COCO-LM
        # large vocabulary (same behaviour as before when --tokenizer_name is unset).
        tokenizer_source = 'microsoft/cocolm-large'

    tokenizer = COCOLMTokenizer.from_pretrained(
        tokenizer_source,
        additional_special_tokens=special_tokens,
        cache_dir=model_args.cache_dir,
    )
    return config, tokenizer


def _select_model_class(finetune_type, model_type):
    """Return the model class for a (finetune_type, config.model_type) pair.

    Raises:
        NotImplementedError: for any unsupported combination.
    """
    model_classes = {
        ('prompt', 'roberta'): RobertaForPromptFinetuning,
        ('prompt', 'cocolm'): COCOLMForPromptFinetuning,
        ('finetune', 'roberta'): RobertaForSequenceClassification,
        ('finetune', 'cocolm'): COCOLMForSequenceClassification,
    }
    try:
        return model_classes[(finetune_type, model_type)]
    except KeyError:
        raise NotImplementedError(
            f"Unsupported finetune_type={finetune_type!r} / model_type={model_type!r}"
        ) from None


def _attach_run_info(model, model_args, data_args, tokenizer, train_dataset, device):
    """Attach the argument objects, tokenizer and (for prompting) label-word ids to *model*.

    ``label_word_list`` is a plain attribute, not a registered buffer, so it is
    placed on *device* explicitly rather than relying on ``model.to``.
    """
    if data_args.prompt:
        model.label_word_list = torch.tensor(train_dataset.label_word_list).long().to(device)
    model.model_args = model_args
    model.data_args = data_args
    model.tokenizer = tokenizer


def _build_compute_metrics_fn(task_name: str) -> Callable[[EvalPrediction], Dict]:
    """Return a metrics callback for *task_name* usable as ``Trainer.compute_metrics``."""
    def compute_metrics_fn(p: EvalPrediction):
        # The eval dataloader is sequential, so predictions line up with labels.
        predictions = p.predictions
        num_logits = predictions.shape[-1]
        logits = predictions.reshape([-1, num_logits])

        if num_logits == 1:
            # Single output column: treat as a regression score.
            preds = np.squeeze(logits)
        else:
            preds = np.argmax(logits, axis=1)

        label_ids = p.label_ids.reshape(-1)
        return compute_metrics_mapping[task_name](task_name, preds, label_ids)

    return compute_metrics_fn


def _evaluate_and_record(trainer, dataset, output_dir, split, final_result):
    """Evaluate *dataset*, log its metrics, write them to a file and record them.

    Args:
        split: "dev" (writes ``eval_results_<task>.txt``, keys ``<task>_dev_<metric>``)
            or "test" (writes ``test_results_<task>.txt``, keys ``<task>_test_<metric>``).
        final_result: summary dict updated in place.

    Returns:
        The metrics dict returned by the trainer.
    """
    file_prefix, log_label = ("eval", "Eval") if split == "dev" else ("test", "Test")
    task_name = dataset.args.task_name

    trainer.compute_metrics = _build_compute_metrics_fn(task_name)
    metrics = trainer.evaluate(eval_dataset=dataset).metrics

    output_file = os.path.join(output_dir, f"{file_prefix}_results_{task_name}.txt")
    with open(output_file, "w") as writer:
        logger.info("***** {} results {} *****".format(log_label, task_name))
        for key, value in metrics.items():
            logger.info("  %s = %s", key, value)
            writer.write("%s = %s\n" % (key, value))
            final_result[task_name + '_' + split + '_' + key] = value
    return metrics


def main():
    """Parse arguments, optionally train, evaluate, and append a run summary to ./log.

    Arguments come from a single ``.json`` file argument or from the command
    line. Per-split metrics are written to ``<output_dir>/{eval,test}_results_<task>.txt``
    and one summary line (metrics plus all arguments) is appended to the file
    ``log`` in the current directory under a file lock.

    Returns:
        dict: merged test-set metrics (empty when ``do_predict`` is off).

    Raises:
        ValueError: unknown task, missing prompt/mapping id, or a non-empty
            output directory when training without ``overwrite_output_dir``.
    """
    parser = HfArgumentParser((ModelArguments, DynamicDataTrainingArguments, DynamicTrainingArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # A single .json argument holds all of the arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Any finetune_type containing "prompt" switches data processing to prompting.
    if 'prompt' in model_args.finetune_type:
        data_args.prompt = True

    if training_args.no_train:
        training_args.do_train = False
    if training_args.no_predict:
        training_args.do_predict = False

    # Setup logging (only the main process logs at INFO level)
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )

    # Load prompt/template/mapping file
    if data_args.prompt:
        _load_prompt_settings(data_args)

    # Refuse to train into a non-empty output directory unless overwriting is allowed.
    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(f"Output directory ({training_args.output_dir}) already exists.")

    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    logger.info("Training/evaluation parameters %s", training_args)

    set_seed(training_args.seed)

    try:
        num_labels = num_labels_mapping[data_args.task_name]
        output_mode = output_modes_mapping[data_args.task_name]
    except KeyError:
        raise ValueError("Task not found: %s" % (data_args.task_name)) from None
    logger.info("Task name: {}, number of labels: {}, output mode: {}".format(data_args.task_name, num_labels, output_mode))

    config, tokenizer = _load_config_and_tokenizer(model_args, data_args, num_labels)
    model_fn = _select_model_class(model_args.finetune_type, config.model_type)

    train_dataset = SuperGenDataset(data_args, tokenizer=tokenizer, mode=data_args.train_mode)
    dev_dataset = (
        SuperGenDataset(data_args, tokenizer=tokenizer, mode="dev")
        if training_args.do_eval
        else None
    )
    # The test split is also needed when evaluation is redirected to it via eval_on_test.
    need_test_split = training_args.do_predict or (training_args.do_eval and data_args.eval_on_test)
    test_dataset = (
        SuperGenDataset(data_args, tokenizer=tokenizer, mode="test")
        if need_test_split
        else None
    )

    # Re-seed right before model creation (presumably so initialisation does
    # not depend on how much randomness dataset construction consumed).
    set_seed(training_args.seed)

    model = model_fn.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
    )
    _attach_run_info(model, model_args, data_args, tokenizer, train_dataset, training_args.device)

    eval_dataset = test_dataset if data_args.eval_on_test else dev_dataset
    trainer = SuperGenTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=_build_compute_metrics_fn(data_args.task_name),
    )

    # Training
    if training_args.do_train:
        trainer.train(model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None)
        if training_args.save_at_last:
            trainer.save_model(training_args.output_dir)

        if trainer.is_world_process_zero():
            torch.save(model_args, os.path.join(training_args.output_dir, "model_args.bin"))
            torch.save(data_args, os.path.join(training_args.output_dir, "data_args.bin"))

        # Reload the checkpoint saved in output_dir for evaluation.
        model = model_fn.from_pretrained(training_args.output_dir)
        model = model.to(training_args.device)
        trainer.model = model
        _attach_run_info(model, model_args, data_args, tokenizer, train_dataset, training_args.device)

    final_result = {
        'time': str(datetime.today()),
    }

    # Evaluation (on dev, or on test when eval_on_test is set)
    if training_args.do_eval:
        logger.info("*** Validate ***")
        _evaluate_and_record(trainer, eval_dataset, training_args.output_dir, "dev", final_result)

    test_results = {}
    if training_args.do_predict:
        logger.info("*** Test ***")
        test_datasets = [test_dataset]
        if data_args.task_name == "mnli":
            # MNLI is also reported on the mismatched test set.
            mnli_mm_data_args = dataclasses.replace(data_args, task_name="mnli-mm")
            test_datasets.append(
                SuperGenDataset(mnli_mm_data_args, tokenizer=tokenizer, mode="test")
            )

        for dataset in test_datasets:
            test_results.update(
                _evaluate_and_record(trainer, dataset, training_args.output_dir, "test", final_result)
            )

    # Append one summary line per run; the lock guards concurrent runs sharing ./log.
    with FileLock('log.lock'):
        with open('log', 'a') as f:
            final_result.update(vars(model_args))
            final_result.update(vars(training_args))
            final_result.update(vars(data_args))
            final_result.pop('evaluation_strategy', None)
            f.write(str(final_result) + '\n')

    return test_results

if __name__ == "__main__":
    main()
