# Following code is adapted from:
# 1. The finetuning notebook from the QLoRA repository here: https://github.com/artidoro/qlora
# 2. Huggingface tutorial on training transformers for sequence classification here: https://huggingface.co/docs/transformers/tasks/sequence_classification

import torch
import transformers
from transformers import PreTrainedTokenizerBase
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer, AutoModelForSequenceClassification      # AutoModelForCausalLM, BitsAndBytesConfig
# from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model
from datasets import Dataset, DatasetDict
import numpy as np
import evaluate
import os
import json
import argparse

from dataset_tools import load_data

# Accuracy metric loaded from the `evaluate` library; compute_metrics uses it to score predictions.
accuracy = evaluate.load("accuracy")

# Number of output labels for each supported task (keyed by task name).
num_classes = dict(
    [
        ('sst2', 2),
        ('qqp', 2),
        ('mnli', 3),
        ('mnli-mm', 3),
        ('qnli', 2),
        ('rte', 2),
    ]
)

# Custom trainer for weak-to-strong finetuning with auxiliary loss
class WTS_Trainer(transformers.Trainer):
    """
    Trainer that combines cross-entropy losses on adversarial and original inputs,
    with an optional auxiliary confidence loss.

    Each batch is expected to contain 'adv_tokenized' and 'orig_tokenized'
    (keyword arguments for the model) and 'labels'. The base loss is
        lambda_coeff * CE(adv logits, labels) + (1 - lambda_coeff) * CE(orig logits, labels)
    When ft_mode == 'wts-aux-loss', the same mixture computed against pseudo-labels
    derived from the model's own logits is blended in with weight alpha.

    Args:
    ft_mode: str, finetuning mode; only 'wts-aux-loss' changes the loss
    alpha_max: float, maximum weight of the auxiliary loss
    lambda_coeff: float, weight of the adversarial loss (original loss gets 1 - lambda_coeff)
    warm_up: float, fraction of state.max_steps over which alpha ramps from 0 to alpha_max
             (0 disables the ramp)
    class_wts: dict mapping label (as a string convertible to int) to the fraction of
               instances that should receive that pseudo-label
    use_samples: str, 'adversarial' to return adversarial outputs during evaluate();
                 any other value returns the outputs on the original inputs
    """

    def __init__(self, *args, ft_mode=None, alpha_max=0.5, lambda_coeff=0.5, warm_up=0.2, class_wts=None, use_samples="adversarial", **kwargs):
        super().__init__(*args, **kwargs)

        self.ft_mode = ft_mode
        self.alpha_max = alpha_max
        self.class_wts = class_wts
        self.lambda_coeff = lambda_coeff
        self.warm_up = warm_up
        self.mode = None    # 'evaluate' or 'predict'; set by evaluate() / predict()
        self.use_samples = use_samples  # which outputs compute_loss returns while mode == 'evaluate'

    def _aux_loss_weight(self):
        """
        Returns the current weight (alpha) of the auxiliary loss.

        Alpha grows linearly with state.global_step from 0 to alpha_max over the
        first warm_up * state.max_steps steps and stays at alpha_max afterwards.
        """
        if self.warm_up == 0:
            return self.alpha_max

        warmup_steps = self.warm_up * self.state.max_steps
        # max_steps is 0 until training has been set up (e.g. evaluate() called
        # before train()); treat the ramp as finished instead of dividing by zero.
        if warmup_steps <= 0:
            return self.alpha_max

        return self.alpha_max * min(1.0, self.state.global_step / warmup_steps)

    def _threshold_pseudo_labels(self, logits):
        """
        Derives hard pseudo-labels from logits by thresholding based on class weights.

        Classes are processed in the iteration order of class_wts. For each class,
        the still-unlabelled instances whose logit for that class lies strictly above
        the (1 - fraction) quantile of the unlabelled logits receive that label, where
        fraction is the class weight renormalised by the weight not yet consumed.
        Instances left unlabelled at the end receive the last class processed.

        Args:
        logits: tensor of shape (num_instances, num_classes)

        Returns:
        LongTensor of pseudo-labels on the same device as logits

        Raises:
        ValueError: if class_wts is missing or empty
        """
        if not self.class_wts:
            raise ValueError("class_wts must be provided when ft_mode is 'wts-aux-loss'")

        scores = logits.detach()
        num_instances = scores.size(0)
        unassigned = torch.ones(num_instances, dtype=torch.bool, device=scores.device)
        pseudo_labels = torch.zeros(num_instances, dtype=torch.long, device=scores.device)
        remaining_fraction = 1.0
        label = None

        for label_str, class_wt in self.class_wts.items():
            label = int(label_str)

            if remaining_fraction > 0:
                class_fraction = min(class_wt / remaining_fraction, 1.0)
            else:
                # The weight budget is already used up: avoid dividing by zero.
                # A class with positive weight takes what is left, otherwise nothing.
                class_fraction = 1.0 if class_wt > 0 else 0.0

            class_scores = scores[:, label]
            remaining_scores = class_scores[unassigned]

            # Nothing left to label: the quantile of an empty set is undefined.
            if remaining_scores.numel() > 0:
                # One host transfer per class; the quantile is taken in float64 with numpy.
                threshold = np.quantile(remaining_scores.cpu().to(torch.float64).numpy(),
                                        1.0 - class_fraction)
                selected = unassigned & (class_scores > float(threshold))
                pseudo_labels[selected] = label
                unassigned &= ~selected

            remaining_fraction -= class_wt

        # Assign last label to the instances that never exceeded a threshold
        pseudo_labels[unassigned] = label
        return pseudo_labels

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Computes the mixed adversarial/original loss, plus the auxiliary loss if enabled.

        Args:
        model: model called as model(**inputs['adv_tokenized']) and model(**inputs['orig_tokenized'])
        inputs: dict with 'adv_tokenized', 'orig_tokenized' and 'labels' ('labels' is popped)
        return_outputs: bool, whether to also return model outputs
        num_items_in_batch: accepted because newer Trainer versions pass it to
                            compute_loss; not used here

        Returns:
        loss, or (loss, outputs) if return_outputs is True. The outputs are those on the
        adversarial inputs only when mode == 'evaluate' and use_samples == 'adversarial';
        otherwise those on the original inputs.
        """
        labels = inputs.pop("labels")
        outputs_adv = model(**inputs['adv_tokenized'])
        logits_adv = outputs_adv.logits
        outputs_orig = model(**inputs['orig_tokenized'])
        logits_orig = outputs_orig.logits

        cross_entropy = torch.nn.functional.cross_entropy
        loss_adv = cross_entropy(logits_adv, labels)
        loss_orig = cross_entropy(logits_orig, labels)
        loss = self.lambda_coeff * loss_adv + (1 - self.lambda_coeff) * loss_orig

        # Auxiliary confidence loss
        if self.ft_mode == 'wts-aux-loss':
            alpha = self._aux_loss_weight()
            logit_labels = self._threshold_pseudo_labels(logits_orig)

            aux_loss_adv = cross_entropy(logits_adv, logit_labels)
            aux_loss_orig = cross_entropy(logits_orig, logit_labels)
            aux_loss = self.lambda_coeff * aux_loss_adv + (1 - self.lambda_coeff) * aux_loss_orig
            loss = ((1 - alpha) * loss) + (alpha * aux_loss)

        if not return_outputs:
            return loss

        # NOTE(review): in 'predict' mode the original-input outputs are always
        # returned, regardless of use_samples -- confirm this is intended.
        if self.mode == 'evaluate' and self.use_samples == 'adversarial':
            return (loss, outputs_adv)
        return (loss, outputs_orig)

    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix: str = "eval"):
        """Marks the trainer as evaluating, then delegates to Trainer.evaluate."""
        self.mode = 'evaluate'
        print(f"Mode: {self.mode}, Use samples: {self.use_samples}")
        return super().evaluate(eval_dataset, ignore_keys, metric_key_prefix)

    def predict(self, test_dataset, ignore_keys=None, metric_key_prefix: str = "eval"):
        """Marks the trainer as predicting, then delegates to Trainer.predict."""
        self.mode = 'predict'
        print(f"Mode: {self.mode}, Use samples: {self.use_samples}")
        return super().predict(test_dataset, ignore_keys, metric_key_prefix)

def data_preprocessing(data, task):
    """
    Normalises raw task records into a common adversarial/original text format.

    Every returned record looks like:
    {
        'adv_text': '...',
        'orig_text': '...',
        'label': 0, 1, ...
    }

    For 'sst2' the two texts are the 'sentence' and 'original_sentence' fields.
    For the sentence-pair tasks both fields are rendered as "<Name>: <text>" lines;
    the original text substitutes whichever 'original_<field>' key is present
    (the first field is checked first).

    Args:
    data: list of dictionaries
    task: str, task name

    Returns:
    list of dictionaries in the above format

    Raises:
    ValueError: if a record has no 'original_*' field, or the task is not supported
    """

    # task -> (first display name, first field, second display name, second field, error message)
    mnli_spec = ('Premise', 'premise', 'Hypothesis', 'hypothesis',
                 "Original premise or hypothesis not found")
    pair_specs = {
        'qqp': ('Question 1', 'question1', 'Question 2', 'question2',
                "Original question not found"),
        'mnli': mnli_spec,
        'mnli-mm': mnli_spec,
        'qnli': ('Question', 'question', 'Sentence', 'sentence',
                 "Original question or sentence not found"),
        'rte': ('Sentence 1', 'sentence1', 'Sentence 2', 'sentence2',
                "Original sentence not found"),
    }

    records = []
    for item in data:
        if task == 'sst2':
            records.append({
                'adv_text': item['sentence'],
                'orig_text': item['original_sentence'],
                'label': item['label']
            })
            continue

        if task not in pair_specs:
            raise ValueError(f"Task {task} not supported")

        first_name, first_key, second_name, second_key, missing_msg = pair_specs[task]
        orig_first_key = 'original_' + first_key
        orig_second_key = 'original_' + second_key

        if orig_first_key in item:
            first_text = item[orig_first_key]
            second_text = item[second_key]
        elif orig_second_key in item:
            first_text = item[first_key]
            second_text = item[orig_second_key]
        else:
            raise ValueError(missing_msg)

        orig_text = f"{first_name}: {first_text}\n{second_name}: {second_text}"
        adv_text = f"{first_name}: {item[first_key]}\n{second_name}: {item[second_key]}"
        records.append({
            'adv_text': adv_text,
            'orig_text': orig_text,
            'label': item['label']
        })

    return records

def print_trainable_parameters(model):
    """
    Prints how many of the model's parameters are trainable, the total
    parameter count, and the trainable share as a percentage.

    Args:
    model: object whose named_parameters() yields (name, parameter) pairs
    """
    # (element count, requires_grad) for every parameter, gathered in one pass
    sizes = [(param.numel(), param.requires_grad) for _, param in model.named_parameters()]

    all_param = sum(count for count, _ in sizes)
    trainable_params = sum(count for count, needs_grad in sizes if needs_grad)
    trainable_pct = 100 * trainable_params / all_param

    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct}"
    )

def compute_metrics(eval_pred):
    """
    Computes accuracy for an evaluation result.

    Args:
    eval_pred: pair of (per-class scores, reference labels)

    Returns:
    result of accuracy.compute on the argmax class of each row versus the references
    """
    class_scores, references = eval_pred
    predicted_classes = np.argmax(class_scores, axis=1)
    return accuracy.compute(predictions=predicted_classes, references=references)

def tokenize_adv_orig_text(tokenizer: "PreTrainedTokenizerBase", samples, max_length=None):
    """
    Tokenizes the adversarial and original texts in the dataset.

    Args:
    tokenizer: PreTrainedTokenizerBase, tokenizer
    samples: samples from the dataset, with 'adv_text' and 'orig_text' entries
    max_length: int or None, truncation length. When None, the module-level
                `max_len` is used if the script has defined it, otherwise 1024
                (the value the script sets). This keeps the function usable
                when the module is imported rather than run as a script.

    Returns:
    dictionary with 'adv_input_ids', 'adv_attention_mask', 'orig_input_ids'
    and 'orig_attention_mask'
    """

    if max_length is None:
        # `max_len` only exists as a global when the file runs as a script.
        max_length = globals().get('max_len', 1024)

    adv_tokenized = tokenizer(samples['adv_text'], truncation=True, max_length=max_length)
    orig_tokenized = tokenizer(samples['orig_text'], truncation=True, max_length=max_length)

    return {
        'adv_input_ids': adv_tokenized['input_ids'],
        'adv_attention_mask': adv_tokenized['attention_mask'],
        'orig_input_ids': orig_tokenized['input_ids'],
        'orig_attention_mask': orig_tokenized['attention_mask']
    }

class CustomCollator:
    """Pads tokenized adversarial/original examples into batch tensors."""

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        self.tokenizer = tokenizer

    @staticmethod
    def _pad(sequences, fill_value):
        """Stacks variable-length sequences into one right-padded, batch-first tensor."""
        return pad_sequence([torch.tensor(seq) for seq in sequences],
                            batch_first=True, padding_value=fill_value)

    def __call__(self, batch):
        """
        Collates a list of examples into padded tensors.

        Args:
        batch: list of examples, each holding 'adv_input_ids', 'adv_attention_mask',
               'orig_input_ids', 'orig_attention_mask' and 'label'

        Returns:
        dictionary containing:
            1. 'adv_tokenized': padded adversarial input_ids and attention_mask
            2. 'orig_tokenized': padded original input_ids and attention_mask
            3. 'labels': tensor of labels
        """

        # Gather every column first, then pad
        field_names = ('adv_input_ids', 'adv_attention_mask', 'orig_input_ids', 'orig_attention_mask')
        columns = {name: [example[name] for example in batch] for name in field_names}
        label_column = [example['label'] for example in batch]

        # input_ids are padded with the tokenizer's pad token, attention masks with 0
        collated = {}
        for prefix in ('adv', 'orig'):
            padded_ids = self._pad(columns[f'{prefix}_input_ids'], self.tokenizer.pad_token_id)
            padded_mask = self._pad(columns[f'{prefix}_attention_mask'], 0)
            collated[f'{prefix}_tokenized'] = {
                'input_ids': padded_ids,
                'attention_mask': padded_mask
            }
        collated['labels'] = torch.tensor(label_column)

        return collated
    
def gpu_memory():
    """
    Sums memory over all visible CUDA devices.

    Returns:
    tuple (allocated, total) in GB, where allocated is the sum of
    torch.cuda.memory_allocated and total is the sum of each device's total_memory
    """
    bytes_per_gb = 1024**3
    device_ids = range(torch.cuda.device_count())

    allocated_bytes = sum(torch.cuda.memory_allocated(idx) for idx in device_ids)
    total_bytes = sum(torch.cuda.get_device_properties(idx).total_memory for idx in device_ids)

    return allocated_bytes / bytes_per_gb, total_bytes / bytes_per_gb

def _ensure_parent_dir(path):
    """Creates the parent directory of path if path has one and it does not exist yet."""
    parent = os.path.dirname(path)
    # dirname is '' for a bare filename, and os.makedirs('') raises.
    if parent:
        os.makedirs(parent, exist_ok=True)


def _load_json_or_empty(path):
    """Returns the JSON content of path, or {} if the file is missing or not valid JSON."""
    try:
        with open(path, 'r') as f:
            return json.load(f)
    except (FileNotFoundError, json.JSONDecodeError):
        return {}


def _records_to_dataset(records):
    """Builds a Huggingface Dataset with 'adv_text', 'orig_text' and 'label' columns from a list of dicts."""
    return Dataset.from_dict({
        'adv_text': [item['adv_text'] for item in records],
        'orig_text': [item['orig_text'] for item in records],
        'label': [item['label'] for item in records]
    })


if __name__ == "__main__":
    # Script flow: parse arguments -> load training data (or weak labels) ->
    # finetune -> evaluate on adversarial and original test samples ->
    # (weak mode only) write weak labels for the holdout split.

    # Parse arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_id", type=str, default="EleutherAI/pythia-14m", help="Model ID")
    parser.add_argument("--results_file", type=str, default="results/pythia-results.json", help="Results file")
    # Restricting the choices fails fast instead of raising KeyError on num_classes[task] later.
    parser.add_argument("--task", type=str, default="sst2", choices=list(num_classes), help="Task")
    parser.add_argument("--ft_mode", type=str, choices=['weak', 'strong', 'wts-naive', 'wts-aux-loss'], default="weak", help="Finetuning mode")
    parser.add_argument("--weak_labels_file", type=str, default="weak_labels/pythia-14m.json", help="Weak labels file for weak-to-strong finetuning")
    parser.add_argument("--num_epochs", type=float, default=6.0, help="Number of epochs")
    parser.add_argument("--batch_size", type=int, default=30, help="Batch size")
    parser.add_argument("--lambda_coeff", type=float, default=0.3,
                        help="Coefficient for controlling the balance between adversarial and clean losses. Loss = lambda * loss_adv + (1 - lambda) * loss_orig")
    parser.add_argument("--warm_up", type=float, default=0.2, help="Warmup period for the auxiliary confidence loss")
    parser.add_argument("--alpha_max", type=float, default=0.1, help="Maximum value of alpha for the auxiliary confidence loss")
    parser.add_argument("--validate_on", type=str, choices=['adversarial', 'original'], default='adversarial',
                        help="Whether to use adversarial or clean samples for validation")
    parser.add_argument("--tag", type=str, default="", help="Tag for the experiment")
    parser.add_argument("--glue_ds", type=str, default='advgluepp', choices=['glue', 'advglue', 'advgluepp'],
                        help='Dataset to use: glue, advglue or advgluepp')
    args = parser.parse_args()

    model_id = args.model_id
    results_file = args.results_file
    task = args.task
    ft_mode = args.ft_mode
    weak_labels_file = args.weak_labels_file
    num_epochs = args.num_epochs
    lambda_coeff = args.lambda_coeff
    warm_up = args.warm_up
    alpha_max = args.alpha_max
    validate_on = args.validate_on
    tag = args.tag
    glue_ds = args.glue_ds

    # Must stay a module-level name: tokenize_adv_orig_text reads `max_len` as a global.
    max_len = 1024
    batch_size = args.batch_size

    # Weak-to-strong modes train on weak labels instead of the dataset's training split
    uses_weak_labels = ft_mode in ('wts-naive', 'wts-aux-loss')

    # Print experiment details
    print("\n* * * * * Experiment Details * * * * *")
    print("Model ID:", model_id)
    print("Results file:", results_file)
    print("Task:", task)
    print("Finetuning mode:", ft_mode)
    if uses_weak_labels:
        print("Weak labels file:", weak_labels_file)
    if ft_mode == 'wts-aux-loss':
        print("Alpha max:", alpha_max)
    print("Validate on:", validate_on)
    print("Lambda coefficient:", lambda_coeff)
    print("Warmup period:", warm_up)
    print("Number of epochs:", num_epochs)
    print("Number of available GPUs:", torch.cuda.device_count())
    print("GPU names:", [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())])
    print("Dataset:", glue_ds)
    print("Batch size:", batch_size)
    print("* * * * * * * * * * * * * * * * * * * *\n", flush=True)

    # Load data. All dataset files follow the data/<glue_ds>_<split>.json pattern,
    # and argparse already restricts glue_ds to the supported names.
    if uses_weak_labels:
        # Weak labels are stored in the preprocessed format already, so no preprocessing here
        data = load_data(weak_labels_file)[task]
        print("Weak labels:", len(data))
    else:
        data = load_data(f'data/{glue_ds}_train.json')[task]
        print("Training samples:", len(data))
        data = data_preprocessing(data, task)

    # Load class weights from dataset info file
    with open(f'data/{glue_ds}_info.json', 'r') as f:
        class_wts = json.load(f)['class_wts']

    output_dir = f"models/{model_id}-{task}{tag}"

    # Load model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForSequenceClassification.from_pretrained(model_id,
                                                               device_map="auto",
                                                               num_labels=num_classes[task],
                                                               )

    # Split into training (80%) and validation (20%) datasets after shuffling in place
    train_len = int(0.8 * len(data))
    np.random.shuffle(data)

    train_data = data[:train_len]
    val_data = data[train_len:]

    # Create Huggingface datasets
    ds = DatasetDict({
        'train': _records_to_dataset(train_data),
        'val': _records_to_dataset(val_data)
    })

    # Set padding token to eos token
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = model.config.eos_token_id

    # Tokenize the sentences in the datasets
    tokenized_ds = ds.map(lambda samples: tokenize_adv_orig_text(tokenizer, samples), batched=True)

    # Fine-tune the model
    print("\nFine-tuning the model...\n", flush=True)
    trainer = WTS_Trainer(
        model=model,
        train_dataset=tokenized_ds['train'],
        eval_dataset=tokenized_ds['val'],
        args=transformers.TrainingArguments(
            remove_unused_columns=False,    # the collator needs the adv_*/orig_* columns
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            gradient_checkpointing=True,
            num_train_epochs=num_epochs,
            evaluation_strategy="epoch",
            save_strategy="epoch",
            load_best_model_at_end=True,
            save_total_limit=1,     # Only save the best model
            warmup_steps=2,
            learning_rate=1e-5,
            logging_steps=1,
            output_dir=output_dir,
        ),
        data_collator=CustomCollator(tokenizer),
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
        ft_mode=ft_mode,
        alpha_max=alpha_max,       # Maximum value of alpha for the auxiliary confidence loss
        lambda_coeff=lambda_coeff,     # Coefficient for controlling the balance between adversarial and clean losses
        warm_up=warm_up,        # Warmup period for the auxiliary confidence loss
        class_wts=class_wts[task],
        use_samples=validate_on   # Whether to use adversarial or clean samples for validation
    )
    model.config.use_cache = False  # silence the warnings. Please re-enable for inference!
    trainer.train()

    # Load test data
    test_data = load_data(f'data/{glue_ds}_test.json')[task]
    print("Test samples:", len(test_data))

    test_data = data_preprocessing(test_data, task)
    test_ds = _records_to_dataset(test_data)
    tokenized_test_ds = test_ds.map(lambda samples: tokenize_adv_orig_text(tokenizer, samples), batched=True)

    # Key under which each finetuning mode's accuracy is stored in the results file
    result_keys = {
        'weak': 'Weak Performance',
        'strong': 'Strong Performance',
        'wts-naive': 'WTS-Naive',
        'wts-aux-loss': 'WTS-Aux-Loss'
    }

    # Evaluate the model
    for use_samples in ['adversarial', 'original']:
        print(f'\nEvaluating the model on {use_samples} samples...\n', flush=True)
        trainer.use_samples = use_samples
        eval_results = trainer.evaluate(tokenized_test_ds)
        print(eval_results)

        # Create results directory if it doesn't exist
        _ensure_parent_dir(results_file)

        # Load the results file (re-read every time so existing entries are merged, not lost)
        results = _load_json_or_empty(results_file)

        # Add task results to the dictionary: results[use_samples][task][<mode key>] = accuracy in percent
        task_results = results.setdefault(use_samples, {}).setdefault(task, {})
        task_results[result_keys[ft_mode]] = eval_results['eval_accuracy'] * 100

        # Save the results
        with open(results_file, 'w') as f:
            json.dump(results, f, indent=2)

        print("Results saved to:", results_file)

    # Get predictions on holdout data
    if ft_mode == 'weak':
        data_holdout = load_data(f'data/{glue_ds}_holdout.json')[task]
        print("Holdout samples:", len(data_holdout))

        data_holdout = data_preprocessing(data_holdout, task)
        ds_holdout = _records_to_dataset(data_holdout)
        tokenized_ds_holdout = ds_holdout.map(lambda samples: tokenize_adv_orig_text(tokenizer, samples), batched=True)

        print("\nGenerating weak labels on holdout data...\n", flush=True)
        predictions = trainer.predict(tokenized_ds_holdout)

        # Get labels from predictions
        preds = np.argmax(predictions.predictions, axis=1)

        # Same record format as data_preprocessing output, with the predicted label
        weak_labels_dict_list = [
            {
                'adv_text': item['adv_text'],
                'orig_text': item['orig_text'],
                'label': int(preds[i])
            }
            for i, item in enumerate(data_holdout)
        ]

        # Save weak labels to file
        weak_labels_file = f'weak_labels/{glue_ds}/{model_id}{tag}.json'

        # Create directory if it doesn't exist
        _ensure_parent_dir(weak_labels_file)

        # Load the weak labels file so labels for other tasks are preserved
        weak_labels = _load_json_or_empty(weak_labels_file)
        weak_labels[task] = weak_labels_dict_list

        # Save the weak labels
        with open(weak_labels_file, 'w') as f:
            json.dump(weak_labels, f, indent=2)

        print("Weak labels saved to:", weak_labels_file, flush=True)