# The following code is adapted from:
# 1. The fine-tuning notebook from the QLoRA repository: https://github.com/artidoro/qlora
# 2. The Hugging Face tutorial on training transformers for sequence classification: https://huggingface.co/docs/transformers/tasks/sequence_classification

from dataset_tools import load_data
from utils import gpu_memory, print_trainable_parameters, fix_seeds
from modules.WTS_Trainer import WTS_Trainer_Naive, WTS_Trainer_FreeLB, WTS_Trainer_GaussianNoise

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Trainer, TrainingArguments, DataCollatorWithPadding
from transformers import AutoTokenizer, BitsAndBytesConfig, AutoModelForSequenceClassification
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model

from datasets import Dataset, DatasetDict
import numpy as np
import evaluate
import os
import json
import argparse
import wandb
from typing import List, Dict, Union, Any
import pandas as pd
import math
import warnings
warnings.filterwarnings("ignore")
from transformers import logging
logging.set_verbosity_error()


# Hugging Face `evaluate` accuracy metric, loaded once when this module is imported and used by compute_metrics().
accuracy = evaluate.load("accuracy")

# Number of output classes for each supported task; used as `num_labels` for the classification head.
num_classes = dict(DecodingTrust=2, BOSS=3, SST2=2)

# Per-task values keyed by the label *as a string*; handed to the WTS trainers as `task2class`.
# NOTE(review): these look like class priors / label proportions (each row sums to ~1) — confirm against
# WTS_Trainer before changing them.
task2class = dict(
    DecodingTrust={"0": 0.4908, "1": 0.5092},
    BOSS={label: 0.3333 for label in ("0", "1", "2")},
    SST2={"0": 0.5, "1": 0.5},
)


def data_preprocessing(data, text_field="sentence", label_field="label", csv=False):
    """
    Convert raw examples into a list of records of the form:
    {
        'text': str: '...',
        'label': int: 0, 1, ...
    }

    Args:
    data: either an iterable of mappings (e.g. [{"sentence": "...", "label": "1"}, ...]) or,
        when `csv` is True, a pandas DataFrame with named columns.
    text_field: str, key / column holding the input text (converted with `str`).
    label_field: str, key / column holding the label (converted with `int`).
    csv: bool, True when `data` is a pandas DataFrame rather than an iterable of mappings.

    Returns:
    data_list: list of dictionaries in the above format, in input order.

    Raises:
    KeyError: if `text_field` or `label_field` is missing.
    ValueError / TypeError: if a label cannot be converted with `int`.
    """
    if csv:
        # Read whole columns instead of DataFrame.iterrows(): iterrows builds one Series per row,
        # which is slow and can upcast the row's values to a common dtype (e.g. int text -> "5.0").
        pairs = zip(data[text_field], data[label_field])
    else:
        pairs = ((item[text_field], item[label_field]) for item in data)

    return [{"text": str(text), "label": int(label)} for text, label in pairs]


def compute_metrics(eval_pred):
    """Score a Trainer evaluation pass with the module-level `accuracy` metric.

    `eval_pred` is unpacked as (logits, labels); the predicted class of each example is the
    argmax over axis 1 of the logits.
    """
    logits, references = eval_pred
    return accuracy.compute(
        predictions=np.argmax(logits, axis=1),
        references=references,
    )


def _sst2_splits(data_src_path):
    """SST2 layout: raw["base"] is in-distribution; every other top-level key is an OOD split."""
    raw = load_data(data_src_path)
    in_dist = data_preprocessing(raw["base"])
    ood = {}
    for split_name in raw.keys():
        if split_name == "base":
            continue
        ood[split_name] = data_preprocessing(raw[split_name])
    return in_dist, ood


def _decodingtrust_splits(data_src_path):
    """DecodingTrust layout: dev/base plus train_demo/base_{0,1,2} are in-distribution;
    every non-"base" key under "dev" is an OOD split."""
    raw = load_data(data_src_path)
    in_dist = []
    for section, split_name in (
        ("dev", "base"),
        ("train_demo", "base_0"),
        ("train_demo", "base_1"),
        ("train_demo", "base_2"),
    ):
        in_dist.extend(data_preprocessing(raw[section][split_name]))
    dev = raw["dev"]
    ood = {}
    for split_name in dev.keys():
        if split_name != "base":
            ood[split_name] = data_preprocessing(dev[split_name])
    return in_dist, ood


def _read_boss_tsv(data_src_path, dataset_name, split):
    """Read <data_src_path>/<dataset_name>/<split>.tsv (tab-separated) into a DataFrame."""
    return pd.read_csv(os.path.join(data_src_path, dataset_name, f"{split}.tsv"), sep='\t')


def _boss_records(frame):
    """Convert a BOSS DataFrame (columns "Text" / "Label") into {"text", "label"} records."""
    return data_preprocessing(frame, text_field="Text", label_field="Label", csv=True)


def _balance_by_label(records):
    """Cut every label group down to the size of the smallest group, then shuffle.

    Each group keeps its first items; groups are emitted in order of first appearance before
    the in-place shuffle with NumPy's global RNG.
    """
    groups = {}
    for record in records:
        groups.setdefault(record["label"], []).append(record)
    keep = min(len(group) for group in groups.values())
    balanced = [record for group in groups.values() for record in group[:keep]]
    np.random.shuffle(balanced)
    return balanced


def _boss_splits(data_src_path):
    """BOSS layout: amazon train+test is in-distribution (class-balanced); the test splits of
    dynasent, semeval and sst5 are the OOD splits."""
    assert "SentimentAnalysis" in data_src_path, "BOSS dataset is only available for SentimentAnalysis task"
    amazon_train = _read_boss_tsv(data_src_path, 'amazon', 'train')
    amazon_test = _read_boss_tsv(data_src_path, 'amazon', 'test')
    in_dist = _boss_records(amazon_train) + _boss_records(amazon_test)
    ood = {
        name: _boss_records(_read_boss_tsv(data_src_path, name, 'test'))
        for name in ('dynasent', 'semeval', 'sst5')
    }
    return _balance_by_label(in_dist), ood


def load_id_ood_data(data_src_path):
    """Load in-distribution and OOD data for the dataset named in `data_src_path`.

    The dataset is picked by the first of "SST2", "DecodingTrust", "BOSS" that occurs as a
    substring of the path.

    Returns:
    (in_distrib_data, ood_data_grouped): a list of {"text", "label"} records and a dict mapping
    each OOD split name to such a list.

    Raises:
    NotImplementedError: if the path matches none of the supported datasets.
    """
    loaders = (
        ("SST2", _sst2_splits),
        ("DecodingTrust", _decodingtrust_splits),
        ("BOSS", _boss_splits),
    )
    for marker, loader in loaders:
        if marker in data_src_path:
            return loader(data_src_path)
    raise NotImplementedError(f"Data source path {data_src_path} not implemented")
                

def _records_to_columns(records, n_labels):
    """Validate {"text", "label"} records and regroup them column-wise for Dataset.from_dict.

    Args:
    records: list of dictionaries {"text": str, "label": int}.
    n_labels: int, number of classes; valid labels are 0 .. n_labels - 1.

    Returns:
    dict with parallel lists under "text" and "label".

    Raises:
    AssertionError: on empty / non-string text or an out-of-range label.
    """
    columns = {"text": [], "label": []}
    for item in records:
        text = item["text"]
        assert isinstance(text, str) and len(text) > 0, f"Invalid text {text}"
        label = item["label"]
        # Validate against the task's real class count; a fixed [0, 1, 2] would let label 2 through for binary tasks.
        assert label in range(n_labels), f"Invalid label {label}"
        columns["text"].append(text)
        columns["label"].append(label)
    return columns


def _ensure_parent_dir(path):
    """Create the parent directory of `path` if it has one (a bare file name has none)."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def _load_json_or_empty(path):
    """Return the JSON object stored at `path`, or {} if the file is missing or unparsable."""
    try:
        with open(path, "r") as f:
            return json.load(f)
    except FileNotFoundError:
        return {}
    except ValueError:  # json.JSONDecodeError / UnicodeDecodeError
        print(f"==> Warning: could not parse {path}; starting from an empty dict (the file will be overwritten).")
        return {}


def _ood_average(results):
    """Average every metric over the OOD splits of `results` ({split_name: {metric_name: value}}).

    The in-distribution split and any previously stored average are excluded ("base" is kept for
    results files written with the older split naming).
    """
    non_ood_splits = {"base", "In-Dist.", "OOD avg."}
    per_metric = {}
    for split_name, split_results in results.items():
        if split_name in non_ood_splits:
            continue
        for metric_name, value in split_results.items():
            per_metric.setdefault(metric_name, []).append(value)
    return {name: sum(values) / len(values) for name, values in per_metric.items()}


def main(args):
    """Fine-tune a sequence classifier (weak / strong / weak-to-strong) and record its accuracy.

    1. Load in-distribution and OOD data; in "wts" mode the training pool is replaced by the weak
       labels stored for `args.task` in `args.weak_labels_file`.
    2. Shuffle the pool, hold out half of it, and split the other half 95/5 into train / validation.
    3. Load model and tokenizer (optionally wrapped with LoRA), tokenize, and train with the trainer
       selected by `args.trainer_class`.
    4. Evaluate on the validation split ("In-Dist.") and every OOD split, merge the accuracies into
       the JSON file `args.results_file`, and store the per-metric OOD average under "OOD avg.".
    5. In "weak" mode (unless disabled), predict on the held-out half and save those predictions
       as weak labels for `args.task` in `args.weak_labels_file`.
    """
    model_id = args.model_id
    results_file = args.results_file
    task = args.task
    ft_mode = args.ft_mode
    weak_labels_file = args.weak_labels_file
    enable_aux_loss = args.enable_aux_loss
    alpha_max = args.alpha_max
    burn_in_period = args.burn_in_period
    tag = args.tag
    lr = args.lr
    num_epochs = args.n_epochs
    max_len = args.max_len
    batch_size = args.batch_size

    # Print experiment details
    print("\n* * * * * Experiment Details * * * * *")
    print("Model ID:", model_id)
    print("Results file:", results_file)
    print("Task:", task)
    print("Finetuning mode:", ft_mode)
    if ft_mode == "wts":
        print("Weak labels file:", weak_labels_file)
    print("Number of epochs:", num_epochs)
    print("Number of available GPUs:", torch.cuda.device_count())
    print("GPU names:", [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())])
    print("Batch size:", batch_size)
    print("Learning rate:", lr)
    if args.lora:
        print("Using LoRA. Are you sure about this?")
    print("* * * * * * * * * * * * * * * * * * * *\n", flush=True)

    n_labels = num_classes[task]
    generate_weak_labels = ft_mode == "weak" and not args.disable_weak_labels_generation
    # Fail fast: weak labels are written at the very end, so a missing path would otherwise only
    # surface after the whole training run.
    if generate_weak_labels:
        assert weak_labels_file is not None, "Weak labels file must be provided to save generated weak labels"

    in_distrib_data, ood_data_grouped = load_id_ood_data(args.data_src_path)

    # Load data
    if ft_mode == "wts":
        assert weak_labels_file is not None, "Weak labels file must be provided for weak-to-strong finetuning"
        assert os.path.exists(weak_labels_file), f"Weak labels file does not exist: {weak_labels_file}"
        weak_labels_data = load_data(weak_labels_file)[args.task]
        print("==> Number of weak labels:", len(weak_labels_data))
        train_val_heldout_data = weak_labels_data
    else:
        train_val_heldout_data = in_distrib_data

    # Shuffle, keep the first half for training/validation and the second half as held-out data
    # (used below to generate weak labels in "weak" mode); then split train/validation 95/5.
    np.random.shuffle(train_val_heldout_data)

    cutoff_idx = int(0.5 * len(train_val_heldout_data))
    train_val_data, heldout_data = train_val_heldout_data[:cutoff_idx], train_val_heldout_data[cutoff_idx:]

    cutoff_idx = int(0.95 * len(train_val_data))
    train_data = train_val_data[:cutoff_idx]
    val_data = train_val_data[cutoff_idx:]

    print("==> Number of training samples:", len(train_data))
    print("==> Number of validation samples:", len(val_data))
    print("==> Number of heldout samples:", len(heldout_data))
    print("==> Data size of each class:")
    print(pd.Series([item["label"] for item in train_data]).value_counts())

    # NOTE(review): in "wts" mode val_data comes from the weak-label pool, so "In-Dist." accuracy is
    # measured against weak labels rather than ground truth — confirm this is intended.
    test_data_grouped = {"In-Dist.": val_data}
    test_data_grouped.update(ood_data_grouped)

    output_dir = f"{args.models_save_root}/{model_id}/{task}---{tag}"

    if args.lora:
        # BnB configuration for 4-bit quantization
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16
        )
    else:
        bnb_config = None

    # Load model and tokenizer
    print("==> Loading model and tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_id,
        device_map="auto",  # device_map={"":0},
        num_labels=n_labels,
        # NOTE(review): bnb_config is built above but not passed here, so the model is loaded
        # unquantized even with --lora.
        # quantization_config=bnb_config  #? not sure
    )

    if args.lora:
        model = prepare_model_for_kbit_training(model)
        lora_config = LoraConfig(
            r=16,
            lora_alpha=32,
            target_modules='all-linear',
            # target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj'],     # ["query_key_value"],
            lora_dropout=0.05,
            bias="none",
            task_type='SEQ_CLS'     # "CAUSAL_LM"
        )
        model = get_peft_model(model, lora_config)
        print_trainable_parameters(model)

    memory = gpu_memory()
    print(f"==> GPU memory used: {memory[0]:.2f} / {memory[1]:.2f} GB")

    # Create Huggingface datasets
    ds = DatasetDict(
        {
            "train": Dataset.from_dict(_records_to_columns(train_data, n_labels)),
            "val": Dataset.from_dict(_records_to_columns(val_data, n_labels)),
        }
    )

    # Fall back to the EOS token only when the tokenizer has no pad token: overwriting an existing
    # pad token is unnecessary, and for tokenizers without an EOS token it would leave padding undefined.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Keep the model's pad id identical to the id the tokenizer actually pads with.
    model.config.pad_token_id = tokenizer.pad_token_id

    def tokenize_padded(samples):
        return tokenizer(samples["text"], padding=True, truncation=True, max_length=max_len)

    # Tokenize the sentences in the datasets
    print("==> Tokenizing dataset...")
    tokenized_ds = ds.map(tokenize_padded, batched=True)

    # Fine-tune the model
    print(f"\n==> Fine-tuning model {model_id}...\n")
    total_steps = math.ceil(len(train_data) / batch_size) * num_epochs

    # TrainingArguments has its own default learning rate; only override it when --lr was given,
    # since forwarding None would replace that default with an invalid value.
    optional_args = {} if lr is None else {"learning_rate": lr}

    training_args = TrainingArguments(
        report_to='wandb' if args.enable_wandb else "none",
        run_name=args.exp_name,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        gradient_accumulation_steps=1,
        num_train_epochs=num_epochs,
        eval_strategy="steps",
        # About 10 evaluations per run; at least one step apart so very short runs get a valid interval.
        eval_steps=max(1, int(total_steps // 10)),
        logging_strategy="steps",
        logging_steps=1,
        save_strategy="no",
        # save_steps=total_steps // 2,
        # save_total_limit=1,  # Only save the best model
        # load_best_model_at_end=True,
        warmup_ratio=0.2,
        output_dir=output_dir,
        **optional_args,
    )

    base_kwargs = {
        "model": model,
        "train_dataset": tokenized_ds["train"],
        "eval_dataset": tokenized_ds["val"],
        "args": training_args,
        "data_collator": DataCollatorWithPadding(tokenizer),
        "tokenizer": tokenizer,
        "compute_metrics": compute_metrics,
    }
    wts_kwargs = {
        "enable_aux_loss": enable_aux_loss,
        "task": task,
        "alpha_max": alpha_max,
        "burn_in_period": burn_in_period,
        "task2class": task2class,
    }

    if args.trainer_class == "base":
        print("==> Using base trainer")
        trainer = WTS_Trainer_Naive(
            **base_kwargs,
            **wts_kwargs,
        )
    elif args.trainer_class == "gauss":
        print("==> Using GaussianNoise trainer")
        gauss_kwargs = {
            "noise_std": args.noise_std,
        }
        trainer = WTS_Trainer_GaussianNoise(
            **base_kwargs,
            **wts_kwargs,
            **gauss_kwargs,
        )
    elif args.trainer_class == "freelb":
        print("==> Using FreeLB trainer")
        freelb_kwargs = {
            "adv_K": args.adv_K,
            "adv_lr": args.adv_lr,
            "adv_init_mag": args.adv_init_mag,
            "adv_max_norm": args.adv_max_norm,
            "adv_norm_type": args.adv_norm_type,
        }
        trainer = WTS_Trainer_FreeLB(
            **base_kwargs,
            **wts_kwargs,
            **freelb_kwargs,
        )
    else:
        raise NotImplementedError(f"Trainer class {args.trainer_class} not implemented")

    model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
    trainer.train()

    # Merge into any results already stored for this experiment name.
    _ensure_parent_dir(results_file)
    results = _load_json_or_empty(results_file)

    print("\n==> Evaluating the model...\n")
    for test_split_name, test_split_data in test_data_grouped.items():
        test_ds = Dataset.from_dict(
            {"text": [item["text"] for item in test_split_data], "label": [item["label"] for item in test_split_data]}
        )
        tokenized_test_ds = test_ds.map(tokenize_padded, batched=True)

        # Evaluate the model
        eval_results = trainer.evaluate(tokenized_test_ds)

        # Add results to the dictionary (one entry per fine-tuning mode under each split)
        split_results = results.setdefault(test_split_name, {})
        if ft_mode == "weak":
            split_results["Weak Performance"] = eval_results["eval_accuracy"] * 100
        elif ft_mode == "strong":
            split_results["Strong Performance"] = eval_results["eval_accuracy"] * 100
        elif ft_mode == "wts":
            wts_name = args.trainer_class
            if args.enable_aux_loss:
                wts_name += " + aux"
            split_results[f"WTS ({wts_name}) Performance"] = eval_results["eval_accuracy"] * 100

    # Overall OOD performance: excludes "In-Dist." and the previously stored "OOD avg." entry.
    results["OOD avg."] = _ood_average(results)

    print(f"==> Results of {model_id}: \n{results}\n")

    # Save the results
    with open(results_file, "w") as f:
        json.dump(results, f, indent=4)

    print("==> Results saved to:", results_file)

    # Get predictions on holdout data
    if generate_weak_labels:
        _ensure_parent_dir(weak_labels_file)

        data_holdout = heldout_data
        ds_holdout = Dataset.from_dict(
            {"text": [item["text"] for item in data_holdout], "label": [item["label"] for item in data_holdout]}
        )
        # No padding here: DataCollatorWithPadding pads each prediction batch.
        tokenized_ds_holdout = ds_holdout.map(
            lambda samples: tokenizer(samples["text"], truncation=True, max_length=max_len), batched=True
        )

        print("\n==> Generating weak labels on holdout data...\n")
        predictions = trainer.predict(tokenized_ds_holdout)

        # Get labels from predictions
        preds = np.argmax(predictions.predictions, axis=1)

        weak_labels_dict_list = [
            {"text": item["text"], "label": int(pred)} for item, pred in zip(data_holdout, preds)
        ]

        # Keep weak labels of other tasks already stored in the same file.
        weak_labels = _load_json_or_empty(weak_labels_file)
        weak_labels[task] = weak_labels_dict_list

        # Save the weak labels
        with open(weak_labels_file, "w") as f:
            json.dump(weak_labels, f, indent=2)

        print("==> Weak labels saved to:", weak_labels_file)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Basics
    parser.add_argument("--model_id", type=str, required=True, help="Model ID")
    parser.add_argument("--data_src_path", type=str, required=True, help="Data source path")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--results_dir", type=str, default="results", help="Results file save directory")
    parser.add_argument("--task", type=str, required=True, help="Task")
    parser.add_argument("--tag", type=str, default="", help="Tag for the experiment")
    parser.add_argument("--models_save_root", type=str, default="models", help="Root directory for saving models")

    # Weak-to-Strong Settings
    parser.add_argument(
        "--ft_mode",
        type=str,
        choices=["weak", "strong", "wts",],
        default="weak",
        help="Finetuning mode",
    )
    parser.add_argument(
        "--weak_labels_file",
        type=str,
        default=None,
        help="Weak labels file for weak-to-strong finetuning",
    )
    parser.add_argument("--disable_weak_labels_generation", action="store_true", help="Disable weak labels generation")
    parser.add_argument("--exp_info", default="default", type=str)

    parser.add_argument("--n_epochs", type=float, default=2.0, help="Number of epochs")
    parser.add_argument("--lr", type=float, default=None, help="Learning rate")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
    parser.add_argument("--lora", type=int, default=0, choices=[0, 1], help="Use LoRA")

    # WandB
    parser.add_argument("--enable_wandb", type=int, default=0, choices=[0, 1], help="Use WandB")

    # OOD Trainer
    parser.add_argument("--trainer_class", type=str, default="base", choices=["base", "freelb", "gauss"])
    ## Auxiliary loss
    parser.add_argument("--enable_aux_loss", type=int, default=0, choices=[0, 1], help="Enable auxiliary loss")
    parser.add_argument(
        "--alpha_max", type=float, default=0.25, help="Maximum value of alpha for the auxiliary confidence loss"
    )
    parser.add_argument("--burn_in_period", type=float, default=0.2)
    ## FreeLB
    parser.add_argument("--adv_K", type=int, default=2)
    parser.add_argument("--adv_lr", type=float, default=0.5)
    parser.add_argument("--adv_init_mag", type=float, default=0.6)
    parser.add_argument("--adv_max_norm", type=float, default=0.0)
    parser.add_argument("--adv_norm_type", type=str, default="l2")
    ## GaussianNoise
    parser.add_argument("--noise_std", type=float, default=0.002)

    # Model
    parser.add_argument("--max_len", type=int, default=1024)

    args = parser.parse_args()

    # Results land in <results_dir>/<exp_info>--<tag>--<seed>.json; the same name is the W&B run name.
    args.exp_name = f"{args.exp_info}--{args.tag}--{args.seed}"
    os.makedirs(args.results_dir, exist_ok=True)
    args.results_file = os.path.join(args.results_dir, f"{args.exp_name}.json")

    if args.disable_weak_labels_generation:
        assert args.ft_mode == "weak", "--disable_weak_labels_generation only applies to --ft_mode weak"

    # Only start a W&B run when reporting to W&B was requested; otherwise runs with
    # --enable_wandb 0 would still need a W&B login / run.
    if args.enable_wandb:
        wandb.init()

    fix_seeds(args.seed)

    main(args)
    