# Following code is adapted from:
# 1. The finetuning notebook from the QLoRA repository here: https://github.com/artidoro/qlora
# 2. Huggingface tutorial on training transformers for sequence classification here: https://huggingface.co/docs/transformers/tasks/sequence_classification

import torch
import transformers
from transformers import AutoTokenizer, BitsAndBytesConfig, AutoModelForSequenceClassification      # AutoModelForCausalLM
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model
from datasets import Dataset, DatasetDict
import numpy as np
import evaluate
import os
import json
import argparse

from dataset_tools import load_data

# Accuracy metric from the `evaluate` library. It is loaded once at import time
# and used by compute_metrics() to score predictions against reference labels.
accuracy = evaluate.load("accuracy")

# Number of output classes for each supported task; passed as `num_labels`
# when the sequence-classification model is loaded.
num_classes = {'sst2': 2, 'qqp': 2, 'mnli': 3, 'mnli-mm': 3, 'qnli': 2, 'rte': 2}

# Per-task class fractions, keyed by the label written as a string.
# WTS_Trainer.compute_loss walks each inner mapping in insertion order and uses
# the values as the share of a batch to assign to each label, so the key order
# matters. The fractions of each task sum to 1.
# NOTE(review): presumably these are empirical label frequencies -- confirm
# which data split they were measured on.
class_wts = {
    "sst2": {"0": 0.4155, "1": 0.5845},
    "qqp": {"0": 0.7038, "1": 0.2962},
    "mnli": {"0": 0.3512, "1": 0.3446, "2": 0.3042},
    "mnli-mm": {"0": 0.398, "1": 0.2741, "2": 0.3279},
    "qnli": {"0": 0.407, "1": 0.593},
    "rte": {"0": 0.4046, "1": 0.5954},
}

# Custom trainer for weak-to-strong finetuning with auxiliary loss
class WTS_Trainer(transformers.Trainer):
    """Trainer whose loss can be blended with an auxiliary confidence loss.

    The base loss is the cross-entropy between the model logits and the given
    labels. When ``ft_mode == 'wts-aux-loss'`` a second cross-entropy term is
    added, computed against labels derived from the model's own logits (see
    ``_labels_from_logits``). The two terms are mixed as
    ``(1 - alpha) * loss + alpha * aux_loss``.

    Args:
        *args, **kwargs: forwarded unchanged to ``transformers.Trainer``.
        ft_mode: finetuning mode; only the value 'wts-aux-loss' changes the loss.
        task: task name used to look up the class fractions in ``class_wts``.
        alpha_max: value alpha reaches once its linear warm-up is over.
    """

    def __init__(self, *args, ft_mode=None, task=None, alpha_max=0.5, **kwargs):
        super().__init__(*args, **kwargs)

        self.ft_mode = ft_mode
        self.task = task
        self.alpha_max = alpha_max

    @staticmethod
    def _labels_from_logits(logits, class_weights):
        """Derive one hard label per row of ``logits`` by per-class thresholding.

        Classes are visited in the iteration order of ``class_weights``. For
        each class, the still-unlabelled rows whose logit for that class lies
        strictly above the ``1 - fraction`` quantile (taken over the unlabelled
        rows only) receive that label. ``fraction`` is the class weight
        re-normalised by the weight mass not yet consumed. Rows left over at
        the end receive the last class visited.

        Args:
            logits: 2-D tensor indexed as [row, class].
            class_weights: mapping from label (as a string) to its target fraction.

        Returns:
            1-D int64 tensor of labels on the same device as ``logits``.
        """
        # One device-to-host copy instead of one `.item()` sync per element.
        # Detached: the derived labels are targets, not part of the graph.
        # Thresholds and comparisons are both evaluated in float64.
        scores = logits.detach().to(torch.float64).cpu().numpy()
        assigned = np.full(scores.shape[0], -1, dtype=np.int64)  # -1 == not labelled yet
        remaining_fraction = 1.0
        label = None

        for label_str, weight in class_weights.items():
            label = int(label_str)
            unassigned = assigned < 0
            if not unassigned.any():
                # Nothing left to label; a quantile of an empty set is undefined.
                break

            if remaining_fraction > 0:
                class_fraction = min(weight / remaining_fraction, 1.0)
            else:
                # Weight mass already used up (e.g. rounding): take everything left.
                class_fraction = 1.0

            class_scores = scores[:, label]
            threshold = np.quantile(class_scores[unassigned], 1.0 - class_fraction)
            assigned[unassigned & (class_scores > threshold)] = label

            remaining_fraction -= weight

        # The comparison above is strict, so the lowest-scoring rows can stay
        # unlabelled; they fall back to the last class visited.
        if label is not None:
            assigned[assigned < 0] = label

        return torch.from_numpy(assigned).to(logits.device)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the training loss, optionally blended with the auxiliary loss.

        Args:
            model: model called as ``model(**inputs)``; its output must expose ``.logits``.
            inputs: batch mapping; the "labels" entry is popped before the forward pass.
            return_outputs: when true, return ``(loss, outputs)`` instead of ``loss``.
            num_items_in_batch: accepted because newer ``transformers`` releases
                pass it to ``compute_loss``; it is not used here.

        Returns:
            The scalar loss, or ``(loss, outputs)`` when ``return_outputs`` is true.
        """
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        loss = torch.nn.functional.cross_entropy(logits, labels)

        # Auxiliary confidence loss
        if self.ft_mode == 'wts-aux-loss':
            # Warmup alpha from 0 to alpha_max over the first 20 percent of the training.
            # If max_steps is not set yet (zero), skip the warm-up instead of dividing by zero.
            warmup_steps = 0.2 * self.state.max_steps
            if warmup_steps > 0:
                progress = min(1.0, self.state.global_step / warmup_steps)
            else:
                progress = 1.0
            alpha = self.alpha_max * progress
            print(f"Alpha: {alpha}")

            # Get labels from logits by thresholding based on class weights
            logit_labels = self._labels_from_logits(logits, class_wts[self.task])

            # Compute auxiliary loss and blend it with the supervised loss
            aux_loss = torch.nn.functional.cross_entropy(logits, logit_labels)
            loss = ((1 - alpha) * loss) + (alpha * aux_loss)

        return (loss, outputs) if return_outputs else loss

def data_preprocessing(data, task):
    """
    Flattens raw task items into a uniform text/label layout:
    {
        'text': '...',
        'label': 0, 1, ...
    }

    Single-sentence tasks use the sentence as the text; sentence-pair tasks
    join both fields into one two-line prompt.

    Args:
    data: list of dictionaries holding the task-specific fields and a 'label'
    task: str, one of 'sst2', 'qqp', 'mnli', 'mnli-mm', 'qnli', 'rte'

    Returns:
    list of dictionaries in the above format, in input order

    Raises:
    ValueError: if an item is processed for a task that is not supported
    (the task is only checked per item, so empty data yields an empty list)
    """

    def _nli_text(example):
        return f"Premise: {example['premise']}\nHypothesis: {example['hypothesis']}"

    # Maps each task to the function that builds its text field
    renderers = {
        'sst2': lambda example: example['sentence'],
        'qqp': lambda example: f"Question 1: {example['question1']}\nQuestion 2: {example['question2']}",
        'mnli': _nli_text,
        'mnli-mm': _nli_text,
        'qnli': lambda example: f"Question: {example['question']}\nSentence: {example['sentence']}",
        'rte': lambda example: f"Sentence 1: {example['sentence1']}\nSentence 2: {example['sentence2']}",
    }

    converted = []
    for example in data:
        render = renderers.get(task)
        if render is None:
            raise ValueError(f"Task {task} not supported")
        converted.append({'text': render(example), 'label': example['label']})

    return converted

def print_trainable_parameters(model):
    """
    Reports how many of the model's parameters will be updated during training.

    Walks `model.named_parameters()`, sums element counts for all parameters
    and for those with `requires_grad` set, and prints both totals together
    with the trainable percentage.
    """
    sizes = [(weight.numel(), weight.requires_grad) for _, weight in model.named_parameters()]
    all_param = sum(count for count, _ in sizes)
    trainable_params = sum(count for count, is_trainable in sizes if is_trainable)
    summary = f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    print(summary)

def compute_metrics(eval_pred):
    """
    Scores a batch of evaluation predictions with the module-level accuracy metric.

    Args:
    eval_pred: pair of (per-class scores, reference labels); the predicted
    class of each row is the argmax along axis 1

    Returns:
    whatever `accuracy.compute` returns for those predictions and references
    """
    class_scores, references = eval_pred
    predicted_classes = np.argmax(class_scores, axis=1)
    return accuracy.compute(predictions=predicted_classes, references=references)

if __name__ == "__main__":
    # Script entry point: fine-tunes a 4-bit quantized model with LoRA on one
    # task, evaluates it on the test split, records the accuracy in a JSON
    # results file and, in 'weak' mode, exports predictions on the holdout
    # split as weak labels for later weak-to-strong runs.

    # Parse arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_id", type=str, default="EleutherAI/pythia-14m", help="Model ID")
    parser.add_argument("--results_file", type=str, default="results/pythia-results.json", help="Results file")
    # Restricting choices fails fast on a typo instead of raising KeyError at model load
    parser.add_argument("--task", type=str, choices=list(num_classes), default="sst2", help="Task")
    parser.add_argument("--ft_mode", type=str, choices=['weak', 'strong', 'wts-naive', 'wts-aux-loss'], default="weak", help="Finetuning mode")
    parser.add_argument("--weak_labels_file", type=str, default="weak_labels/pythia-14m.json", help="Weak labels file for weak-to-strong finetuning")
    parser.add_argument("--num_epochs", type=float, default=2.0, help="Number of epochs")
    parser.add_argument("--tag", type=str, default="", help="Tag for the experiment")
    args = parser.parse_args()

    model_id = args.model_id
    results_file = args.results_file
    task = args.task
    ft_mode = args.ft_mode
    weak_labels_file = args.weak_labels_file
    num_epochs = args.num_epochs
    tag = args.tag

    max_len = 1024

    # Both weak-to-strong modes train on weak labels instead of the ground-truth training set
    uses_weak_labels = ft_mode in ('wts-naive', 'wts-aux-loss')

    def to_dataset(items):
        """Build a Huggingface Dataset with 'text' and 'label' columns from a list of dicts."""
        return Dataset.from_dict({'text': [item['text'] for item in items], 'label': [item['label'] for item in items]})

    def load_json_or_empty(path):
        """Return the parsed JSON content of `path`, or {} if it cannot be read or parsed."""
        try:
            with open(path, 'r') as f:
                return json.load(f)
        except (OSError, ValueError):
            # Missing, unreadable or malformed file: start from scratch.
            # json.JSONDecodeError is a ValueError subclass.
            return {}

    def ensure_parent_dir(path):
        """Create the directory that will contain `path`, if the path has one."""
        parent = os.path.dirname(path)
        # dirname is '' for a bare file name, and os.makedirs('') raises
        if parent:
            os.makedirs(parent, exist_ok=True)

    # Print experiment details
    print("\n* * * * * Experiment Details * * * * *")
    print("Model ID:", model_id)
    print("Results file:", results_file)
    print("Task:", task)
    print("Finetuning mode:", ft_mode)
    if uses_weak_labels:
        print("Weak labels file:", weak_labels_file)
    print("Number of epochs:", num_epochs)
    print("GPU:", torch.cuda.get_device_name(0))
    print("* * * * * * * * * * * * * * * * * * * *\n", flush=True)

    # Load data
    if uses_weak_labels:
        # No preprocessing here: the items are expected to already carry
        # 'text' and 'label', which is the layout this script writes in 'weak' mode
        data = load_data(weak_labels_file)[task]
    else:
        data = load_data('data/advglue_train.json')[task]
        data = data_preprocessing(data, task)

    output_dir = f"models/{model_id}-{task}{tag}"

    # BnB configuration for 4-bit quantization
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )

    # Load model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForSequenceClassification.from_pretrained(model_id, quantization_config=bnb_config, device_map={"":0}, num_labels=num_classes[task])

    model.gradient_checkpointing_enable()
    model = prepare_model_for_kbit_training(model)

    # LoRA configuration for memory-efficient training
    config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules='all-linear',
        lora_dropout=0.05,
        bias="none",
        task_type='SEQ_CLS'
    )

    # Insert LoRA into the model
    model = get_peft_model(model, config)

    # Split into training and validation datasets (80/20 after an in-place shuffle)
    data_len = len(data)
    train_len = int(0.8 * data_len)
    np.random.shuffle(data)

    train_data = data[:train_len]
    val_data = data[train_len:]

    # Create Huggingface datasets
    ds = DatasetDict({
        'train': to_dataset(train_data),
        'val': to_dataset(val_data)
    })

    # Set padding token to eos token
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = model.config.eos_token_id

    def tokenize_batch(samples):
        """Tokenize the 'text' column of a batch, truncating to max_len tokens."""
        return tokenizer(samples['text'], truncation=True, max_length=max_len)

    # Tokenize the sentences in the datasets
    tokenized_ds = ds.map(tokenize_batch, batched=True)

    # 5 percent of the training set per batch, but never 0 for very small datasets
    batch_size = max(1, int(0.05 * train_len))

    # Fine-tune the model
    print("\nFine-tuning the model...\n")
    trainer = WTS_Trainer(
        model=model,
        train_dataset=tokenized_ds['train'],
        eval_dataset=tokenized_ds['val'],
        args=transformers.TrainingArguments(
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            num_train_epochs=num_epochs,
            evaluation_strategy="epoch",
            save_strategy="epoch",
            load_best_model_at_end=True,
            warmup_steps=2,
            learning_rate=2e-4,
            fp16=True,
            logging_steps=1,
            output_dir=output_dir,
            optim="paged_adamw_8bit"
        ),
        data_collator=transformers.DataCollatorWithPadding(tokenizer),
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
        ft_mode=ft_mode,
        task=task,
        alpha_max=0.5       # Maximum value of alpha for the auxiliary confidence loss
    )
    model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
    trainer.train()

    # Load test data
    test_data = load_data('data/advglue_test.json')[task]
    test_data = data_preprocessing(test_data, task)
    tokenized_test_ds = to_dataset(test_data).map(tokenize_batch, batched=True)

    # Evaluate the model
    print("\nEvaluating the model...\n")
    eval_results = trainer.evaluate(tokenized_test_ds)
    print(eval_results)

    # Create results directory if it doesn't exist
    ensure_parent_dir(results_file)

    # Load the results file (earlier runs for other tasks/modes are kept)
    results = load_json_or_empty(results_file)

    # Add task results to the dictionary, under the key for this finetuning mode
    result_keys = {
        'weak': 'Weak Performance',
        'strong': 'Strong Performance',
        'wts-naive': 'WTS-Naive',
        'wts-aux-loss': 'WTS-Aux-Loss'
    }
    task_results = results.setdefault(task, {})
    task_results[result_keys[ft_mode]] = eval_results['eval_accuracy'] * 100

    # Save the results
    with open(results_file, 'w') as f:
        json.dump(results, f, indent=2)

    print("Results saved to:", results_file)

    # Get predictions on holdout data
    if ft_mode == 'weak':
        data_holdout = load_data('data/advglue_holdout.json')[task]
        data_holdout = data_preprocessing(data_holdout, task)
        tokenized_ds_holdout = to_dataset(data_holdout).map(tokenize_batch, batched=True)

        print("\nGenerating weak labels on holdout data...\n")
        predictions = trainer.predict(tokenized_ds_holdout)

        # Get labels from predictions
        preds = np.argmax(predictions.predictions, axis=1)

        # Pair each holdout text with the predicted label (int() makes it JSON-serializable)
        weak_labels_dict_list = [
            {'text': item['text'], 'label': int(preds[i])}
            for i, item in enumerate(data_holdout)
        ]

        # Save weak labels to file
        weak_labels_file = f'weak_labels/{model_id}{tag}.json'

        # Create directory if it doesn't exist
        ensure_parent_dir(weak_labels_file)

        # Load the weak labels file so labels of other tasks are preserved
        weak_labels = load_json_or_empty(weak_labels_file)

        weak_labels[task] = weak_labels_dict_list

        # Save the weak labels
        with open(weak_labels_file, 'w') as f:
            json.dump(weak_labels, f, indent=2)

        print("Weak labels saved to:", weak_labels_file)
