# Standard library imports
import os
import sys
from typing import Optional
from dataclasses import dataclass, field

# Third-party imports
import numpy as np
import torch
from tqdm import tqdm
from datasets import load_from_disk
from transformers import (
    HfArgumentParser,
    TrainingArguments,
    AutoConfig,
    AutoTokenizer,
    PretrainedConfig,
    AutoModelForSequenceClassification,
)
from accelerate import PartialState

# Directory of preprocessed datasets; main() joins it under the training
# output directory and appends the task name to locate one task's data.
save_path = "./processed_datasets"

# Text-column layouts shared by several GLUE tasks:
# (first text column, second text column or None for single-sentence tasks).
_SINGLE_SENTENCE = ("sentence", None)
_PREMISE_HYPOTHESIS = ("premise", "hypothesis")
_SENTENCE_PAIR = ("sentence1", "sentence2")

# GLUE task name -> names of the text column(s) passed to the tokenizer.
# Insertion order matters: it is the order in which tasks are listed in help
# and error messages.
GLUE_TASK_TO_KEYS = dict(
    [
        ("cola", _SINGLE_SENTENCE),
        ("mnli", _PREMISE_HYPOTHESIS),
        ("mnli-m", _PREMISE_HYPOTHESIS),
        ("mnli-mm", _PREMISE_HYPOTHESIS),
        ("mrpc", _SENTENCE_PAIR),
        ("qnli", ("question", "sentence")),
        ("qqp", ("question1", "question2")),
        ("rte", _SENTENCE_PAIR),
        ("sst2", _SINGLE_SENTENCE),
        ("stsb", _SENTENCE_PAIR),
        ("wnli", _SENTENCE_PAIR),
    ]
)


@dataclass
class DataTrainingArguments:
    """
    Options controlling which GLUE task is used and how its examples are prepared.

    Attributes:
    -----------
    task_name : Optional[str]
        GLUE task to use. After construction it is lower-cased and must be a key of
        `GLUE_TASK_TO_KEYS`; otherwise a ValueError is raised.
    max_seq_length : int
        Token budget per example: longer inputs are truncated and, when padding to a
        fixed length is enabled, shorter ones are padded up to it.
    pad_to_max_length : bool
        When True, tokenization pads every example to `max_seq_length`. When False,
        tokenization does not pad and any padding is left to batch construction.
    max_train_samples, max_val_samples, max_test_samples : Optional[int]
        Optional caps on how many examples are kept from the train, validation and
        test splits respectively, for debugging or quicker runs.
    """

    task_name: Optional[str] = field(
        default=None,
        metadata=dict(
            help="The name of the task to train on: " + ", ".join(GLUE_TASK_TO_KEYS)
        ),
    )
    max_seq_length: int = field(
        default=128,
        metadata=dict(
            help="The maximum length of input sequences after tokenization. Sequences longer than this will be truncated, and shorter sequences will be padded."
        ),
    )
    pad_to_max_length: bool = field(
        default=True,
        metadata=dict(
            help="If True, all samples are padded to `max_seq_length`. If False, padding is applied dynamically based on the batch's maximum length."
        ),
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata=dict(
            help="If set, truncates the number of training samples to this value for debugging or quicker training."
        ),
    )
    max_val_samples: Optional[int] = field(
        default=None,
        metadata=dict(
            help="If set, truncates the number of validation samples to this value for debugging or quicker training."
        ),
    )
    max_test_samples: Optional[int] = field(
        default=None,
        metadata=dict(
            help="If set, truncates the number of test samples to this value for debugging or quicker training."
        ),
    )

    def __post_init__(self):
        # An unset (or empty) task name is left untouched.
        if not self.task_name:
            return
        normalized = self.task_name.lower()
        self.task_name = normalized
        if normalized in GLUE_TASK_TO_KEYS:
            return
        choices = ", ".join(GLUE_TASK_TO_KEYS)
        raise ValueError(
            f"Unknown task '{normalized}'. Please pick one from: {choices}."
        )


@dataclass
class ModelArguments:
    """
    Selects the pretrained model, config and tokenizer to load, and how to fetch them.

    Attributes:
    -----------
    model_name_or_path : str
        Required. Local path or huggingface.co model identifier of the pretrained model.
    config_name : Optional[str]
        Config path/name to use instead of `model_name_or_path`, if any.
    tokenizer_name : Optional[str]
        Tokenizer path/name to use instead of `model_name_or_path`, if any.
    cache_dir : Optional[str]
        Where downloads from huggingface.co are cached.
    use_fast_tokenizer : bool
        Request a fast (tokenizers-library) tokenizer. Defaults to False.
    model_revision : str
        Branch, tag or commit id of the model to load. Defaults to "main".
    token : bool
        Authenticate with the token saved by `huggingface-cli login` (for private models).
    """

    model_name_or_path: str = field(
        metadata=dict(
            help="Path to pretrained model or model identifier from huggingface.co/models"
        )
    )
    config_name: Optional[str] = field(
        default=None,
        metadata=dict(
            help="Pretrained config name or path if different from model_name_or_path"
        ),
    )
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata=dict(
            help="Pretrained tokenizer name or path if different from model_name_or_path"
        ),
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata=dict(
            help="Directory to store pretrained models downloaded from huggingface.co"
        ),
    )
    use_fast_tokenizer: bool = field(
        default=False,
        metadata=dict(
            help="Whether to use a fast tokenizer (backed by the tokenizers library) or not."
        ),
    )
    model_revision: str = field(
        default="main",
        metadata=dict(
            help="The specific version of the model to use (can be a branch name, tag name, or commit ID)."
        ),
    )
    token: bool = field(
        default=False,
        metadata=dict(
            help="Use the token from `huggingface-cli login` for private models."
        ),
    )


def _first_present_split(raw_datasets, candidates, error_message):
    """Return the first split name in ``candidates`` that ``raw_datasets`` contains.

    Args:
        raw_datasets: mapping of split name -> dataset (e.g. a ``DatasetDict``).
        candidates: split names to try, in order of preference.
        error_message: message of the ``ValueError`` raised when none is present.

    Raises:
        ValueError: if no candidate split exists.
    """
    for split_name in candidates:
        if split_name in raw_datasets:
            return split_name
    raise ValueError(error_message)


def _truncate(dataset, max_samples):
    """Return the first ``max_samples`` rows of ``dataset`` (all rows when ``max_samples`` is None)."""
    if max_samples is None:
        return dataset
    return dataset.select(range(min(len(dataset), max_samples)))


def _match_model_labels(model_label2id, label_list):
    """Map dataset label indices to the model's own label ids when their names agree.

    Args:
        model_label2id (dict): the model config's ``label2id`` mapping.
        label_list (list[str]): class names of the dataset, indexed by dataset label id.

    Returns:
        dict | None: ``{dataset_index: model_id}``. Returns None when the model
        still carries the default ``PretrainedConfig`` mapping, or when its label
        names (lower-cased) differ from ``label_list``. The mismatch case is
        reported on stdout.
    """
    if model_label2id == PretrainedConfig(num_labels=len(label_list)).label2id:
        return None
    label_name_to_id = {name.lower(): idx for name, idx in model_label2id.items()}
    model_labels = sorted(label_name_to_id)
    dataset_labels = sorted(label_list)
    if model_labels != dataset_labels:
        print(
            "Your model seems to have been trained with labels, but they don't match the dataset:",
            f"model labels: {model_labels}, dataset labels: {dataset_labels}.",
            "\nIgnoring the model labels as a result.",
        )
        return None
    return {i: int(label_name_to_id[label]) for i, label in enumerate(label_list)}


def _masked_mean(hidden_states, attention_mask=None):
    """Average every layer's hidden states over the sequence axis, ignoring padding.

    Args:
        hidden_states: sequence of tensors, each assumed (batch, seq_len, hidden),
            as in Hugging Face ``output_hidden_states`` outputs.
        attention_mask: optional (batch, seq_len) tensor, non-zero for real tokens.
            When None, every position is averaged.

    Returns:
        torch.Tensor: float32 tensor of shape (num_layers, batch, hidden).
    """
    stacked = torch.stack(tuple(hidden_states)).float()  # (layers, batch, seq, hidden)
    if attention_mask is None:
        return stacked.mean(dim=2)
    weights = attention_mask.to(stacked.dtype)[None, :, :, None]  # (1, batch, seq, 1)
    totals = (stacked * weights).sum(dim=2)
    # clamp guards against division by zero for a row with no unmasked tokens.
    counts = weights.sum(dim=2).clamp(min=1)
    return totals / counts


def _store_hidden_states(
    model,
    tokenizer,
    dataset,
    batch_size,
    device,
    path_to_activations,
    task_name,
    file_tag,
):
    """Run ``dataset`` through ``model`` and save mean-pooled hidden states of every layer.

    Each batch is padded to its longest sequence with ``tokenizer.pad``; this
    changes nothing when rows were already padded to ``max_length``. Each layer
    is then averaged over the non-padding positions given by ``attention_mask``.
    The stacked result has shape (num_hidden_state_layers, len(dataset), hidden)
    and is written to ``<path_to_activations>/activations_full_<file_tag>.npy``.

    Args:
        model: sequence-classification model to run.
        tokenizer: tokenizer used to pad batches and name the model inputs.
        dataset: tokenized dataset split.
        batch_size (int): rows per forward pass.
        device: torch device to run on.
        path_to_activations (str): output directory, created if missing.
        task_name (str): task name used in progress messages.
        file_tag (str): suffix of the saved file name.

    Raises:
        ValueError: if ``dataset`` is empty (nothing to concatenate or save).
    """
    if len(dataset) == 0:
        raise ValueError(f"No training examples to process for {task_name}.")

    model = model.to(device)
    model.eval()
    print(f"Size of training data for {task_name}: {len(dataset)}")
    print("Storing activations...")

    # Feed exactly the inputs the model expects. This keeps token_type_ids for
    # models whose tokenizer produces them (relevant for sentence pairs) and
    # drops text and label columns.
    input_names = {"input_ids", "attention_mask", *tokenizer.model_input_names}
    activations = []
    for start in tqdm(range(0, len(dataset), batch_size), desc="Processing Batches"):
        rows = dataset[start : start + batch_size]
        features = {name: values for name, values in rows.items() if name in input_names}
        batch = tokenizer.pad(features, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**batch, output_hidden_states=True)

        pooled = _masked_mean(outputs.hidden_states, batch.get("attention_mask"))
        activations.append(pooled.cpu().numpy())

        del outputs
        torch.cuda.empty_cache()

    os.makedirs(path_to_activations, exist_ok=True)
    output_file = os.path.join(path_to_activations, f"activations_full_{file_tag}.npy")
    # Concatenate along the example axis: (layers, total_examples, hidden).
    np.save(output_file, np.concatenate(activations, axis=1))
    print(f"Activations saved to {path_to_activations}")


def main():
    """Extract per-layer, mean-pooled hidden states for the train split of a GLUE task.

    Arguments are parsed into ``ModelArguments``, ``DataTrainingArguments`` and
    ``TrainingArguments``. They come either from a single ``.json`` file given as
    the only command-line argument, or from regular command-line flags.

    The processed dataset is read from ``<output_dir>/<save_path>/<task_name>``.
    Activations of the (optionally truncated) train split are written to
    ``./dataset_activations/activations_full_<task_name>.npy``.

    Raises:
        ValueError: if ``--task_name`` is not given, or if the processed dataset
            lacks a train, validation or test split.
    """
    parser = HfArgumentParser(
        (ModelArguments, DataTrainingArguments, TrainingArguments)
    )
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        model_args, data_args, training_args = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1])
        )
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # The task name selects the dataset directory, so it cannot be omitted
    # (os.path.join would otherwise fail on None).
    if not data_args.task_name:
        raise ValueError(
            "--task_name is required: it names the processed dataset directory to load."
        )

    # Keep the original name (e.g. "mnli-mm") for the dataset directory and the
    # output file. The MNLI variants share the "mnli" label set and input keys.
    dataset_name = data_args.task_name
    if "mnli" in data_args.task_name:
        data_args.task_name = "mnli"

    raw_datasets = load_from_disk(
        os.path.join(training_args.output_dir, save_path, dataset_name)
    )

    # STS-B is the only regression task; the others read class names from the
    # train split's "label" feature.
    is_regression = data_args.task_name == "stsb"
    if is_regression:
        label_list = None
        num_labels = 1
    else:
        label_list = raw_datasets["train"].features["label"].names
        num_labels = len(label_list)

    token = True if model_args.token else None
    config = AutoConfig.from_pretrained(
        model_args.config_name or model_args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=data_args.task_name,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        token=token,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name or model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        use_fast=model_args.use_fast_tokenizer,
        revision=model_args.model_revision,
        token=token,
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        token=token,
    )

    # DataTrainingArguments.__post_init__ guarantees a known task name.
    sentence1_key, sentence2_key = GLUE_TASK_TO_KEYS[data_args.task_name]

    # With pad_to_max_length=False, rows stay unpadded here and each batch is
    # padded to its own longest sequence in _store_hidden_states.
    padding = "max_length" if data_args.pad_to_max_length else False

    # Respect the model's own label order when its label names match the dataset.
    label_to_id = None
    if not is_regression:
        label_to_id = _match_model_labels(model.config.label2id, label_list)
        if label_to_id is not None:
            model.config.label2id = label_to_id
        else:
            model.config.label2id = {label: i for i, label in enumerate(label_list)}
        # Derive id2label from the mapping just assigned, not from `config`:
        # from_pretrained may have stored a copy of it on the model.
        model.config.id2label = {
            idx: label for label, idx in model.config.label2id.items()
        }

    if data_args.max_seq_length > tokenizer.model_max_length:
        print(
            f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
            f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
        )
    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

    def preprocess_function(examples):
        """Tokenize a batch of examples and, if needed, remap their labels.

        Args:
            examples (dict): column name -> list of values for one batch.

        Returns:
            dict: tokenizer outputs. When ``label_to_id`` is set, also a remapped
            ``label`` column.
        """
        texts = (
            (examples[sentence1_key],)
            if sentence2_key is None
            else (examples[sentence1_key], examples[sentence2_key])
        )
        result = tokenizer(
            *texts, padding=padding, max_length=max_seq_length, truncation=True
        )
        if label_to_id is not None and "label" in examples:
            # Labels absent from label_to_id are mapped to -1.
            result["label"] = [label_to_id.get(label, -1) for label in examples["label"]]
        return result

    raw_datasets = raw_datasets.map(
        preprocess_function, batched=True, desc="Running tokenizer on dataset"
    )

    train_split = _first_present_split(
        raw_datasets, ("train",), "--do_train requires a train dataset"
    )
    train_dataset = _truncate(raw_datasets[train_split], data_args.max_train_samples)

    # Validation and test splits are not used for extraction, so
    # max_val_samples and max_test_samples have no effect here. The processed
    # dataset is still required to provide these splits (MNLI names them *_matched).
    _first_present_split(
        raw_datasets,
        ("validation", "validation_matched"),
        "--do_eval requires a validation dataset",
    )
    _first_present_split(
        raw_datasets, ("test", "test_matched"), "--do_predict requires a test dataset"
    )

    _store_hidden_states(
        model=model,
        tokenizer=tokenizer,
        dataset=train_dataset,
        batch_size=32,
        # A torch.device valid on this machine (CPU or this process's local GPU),
        # unlike the global process rank.
        device=PartialState().device,
        path_to_activations="./dataset_activations",
        task_name=data_args.task_name,
        file_tag=dataset_name,
    )


# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
