# Standard library imports
import os
import sys
import pickle
import time
import gc

# Third-party imports
import numpy as np
import matplotlib.pyplot as plt
from dataclasses import dataclass, field
from typing import Optional
import torch
from transformers import HfArgumentParser
from skdim import id as skdim
from tqdm import tqdm


def compute_intrinsic_dimension(embeddings, method="two_nn"):
    """
    Estimate the intrinsic dimension of every layer's embeddings.

    Parameters:
    -----------
    embeddings : numpy.ndarray
        Expected to be a 3D array of shape (n_layers, n_points, n_features).
        It is iterated along its first axis and each (n_points, n_features)
        slice is handed to the estimator unchanged.

    method : str, optional
        Name of the estimator to use. Only "two_nn" is implemented, which is
        also the default.

    Returns:
    --------
    intrinsic_dims : list
        One estimate per layer, in the same order as the input layers.

    Raises:
    -------
    ValueError
        If `method` is anything other than "two_nn".
    """
    if method == "two_nn":
        # A fresh TwoNN estimator is built for every layer so that no fitted
        # state is carried over from one layer to the next.
        return [
            skdim.TwoNN().fit_transform(layer_points)
            for layer_points in tqdm(embeddings)
        ]

    raise ValueError(
        f"Method '{method}' is not supported. Currently, only 'two_nn' is implemented."
    )


def plot_intrinsic_dims(
    intrinsic_dims, std=None, num_batches=None, file_path="intrinsic_dims.png"
):
    """
    Plot the intrinsic dimension of each hidden state and save the figure.

    Parameters:
    intrinsic_dims (list or numpy array): Intrinsic dimension of each hidden state. Must be non-empty.
    std (list or numpy array, optional): Standard deviation of the intrinsic dimension of each hidden state. Default is None.
    num_batches (int, optional): Number of batches the estimates were computed over. Together with std it defines the 95% confidence interval. Default is None.
    file_path (str, optional): Where the plot is saved. Default is "intrinsic_dims.png".

    Returns:
    None

    The saved plot contains:
    - A line plot of the intrinsic dimensions, with hidden states numbered from 1.
    - A shaded 95% confidence interval (mean +/- 1.96 * std / sqrt(num_batches))
      when both std and num_batches are provided.
    - A red marker on the minimum intrinsic dimension, annotated with its value.
    """
    # Convert up front: the confidence-interval arithmetic below is
    # element-wise and would raise TypeError on plain Python lists.
    dims = np.asarray(intrinsic_dims)
    positions = range(1, len(dims) + 1)
    min_id = np.argmin(dims)

    fig = plt.figure(figsize=(10, 5))
    try:
        plt.plot(positions, dims, marker="o")

        if std is not None and num_batches is not None:
            half_width = 1.96 * np.asarray(std) / np.sqrt(num_batches)
            plt.fill_between(
                positions,
                dims - half_width,
                dims + half_width,
                alpha=0.3,
                label="95% CI",
            )

        # min_id is a 0-based index, while the x axis is 1-based.
        plt.plot(min_id + 1, dims[min_id], "ro", label="Minimum ID")
        plt.annotate(
            f"{dims[min_id]:.2f}",
            (min_id + 1, dims[min_id]),
            textcoords="offset points",
            xytext=(0, 10),
            ha="center",
        )
        plt.xlabel("Hidden state")
        plt.ylabel("Intrinsic dimension")
        plt.title("Intrinsic dimensions of hidden states")
        plt.grid()
        plt.legend(loc="best")
        plt.savefig(file_path)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly, so
        # release it here even if plotting or saving failed.
        plt.close(fig)


# Adaptive ranks and alphas
def set_adaptive_ranks_alphas(id_path, alpha_rank_ratio=2):
    """
    Build per-module LoRA rank and alpha patterns from saved intrinsic dimensions.

    Parameters:
    id_path (str): Path to a .npy file holding one intrinsic-dimension estimate per hidden state.
    alpha_rank_ratio (int, optional): Multiplier applied to each rank to obtain its alpha. Default is 2.

    Returns:
    tuple: (rank_pattern, alpha_pattern), two dicts keyed by module name
    ("deberta.encoder.layer.{i}.<module>"). Nothing is written to disk here;
    saving is left to the caller.

    The intrinsic dimensions are rounded to integers. For each pair of
    consecutive hidden states i and i + 1, the rank of layer i is the increase
    in rounded intrinsic dimension plus one, or 1 when the intrinsic dimension
    does not increase. The same rank is assigned to the query, key, value and
    attention-output projections of that layer. An input with N entries
    therefore yields patterns for N - 1 layers.
    """
    # Note: np.round rounds exact halves to the nearest even integer.
    ranks = np.round(np.load(id_path)).astype(int)

    # Module names, relative to a layer, that share the layer's rank.
    target_modules = (
        "attention.self.query_proj",
        "attention.self.key_proj",
        "attention.self.value_proj",
        "attention.output.dense",
    )

    rank_pattern = {}
    for i, (current, following) in enumerate(zip(ranks[:-1], ranks[1:])):
        growth = int(following - current)
        # Floor at rank 1 so every targeted module keeps a valid adapter.
        rank = growth + 1 if growth > 0 else 1
        for module in target_modules:
            rank_pattern[f"deberta.encoder.layer.{i}.{module}"] = rank

    alpha_pattern = {
        key: int(value * alpha_rank_ratio) for key, value in rank_pattern.items()
    }

    return rank_pattern, alpha_pattern


def get_total_glora_budget(rank_pattern):
    """
    Add up every rank in a rank pattern to get the total adaptive LoRA budget.

    Parameters:
    rank_pattern (dict): Maps module names to their assigned ranks.

    Returns:
    int: The sum of all ranks in the pattern (0 for an empty pattern).
    """
    budget = 0
    for module_rank in rank_pattern.values():
        budget += module_rank
    return budget


def get_mean_rank(rank_pattern):
    """
    Average the ranks stored in a rank pattern.

    Parameters:
    rank_pattern (dict): Maps module names to their assigned ranks.

    Returns:
    float: The arithmetic mean over all ranks in the pattern.
    """
    all_ranks = np.asarray(list(rank_pattern.values()))
    return all_ranks.mean()


# Clearing CUDA Cache
def clear_cache(sleep_seconds=5):
    """
    Run garbage collection and clear the CUDA cache to free up memory.

    This function performs the following steps:
    1. Performs garbage collection using gc.collect().
    2. Clears the CUDA cache using torch.cuda.empty_cache().
    3. Pauses execution using time.sleep(sleep_seconds).

    Garbage collection runs first so that memory held by unreachable objects
    is handed back to the caching allocator before the cache is emptied;
    in the opposite order that memory would stay cached.

    Parameters:
    sleep_seconds (float, optional): How long to pause after freeing memory. Default is 5.

    Returns:
    None

    Example usage:
    clear_cache()
    """
    gc.collect()
    torch.cuda.empty_cache()
    time.sleep(sleep_seconds)


# Prompt format keys shared by several GLUE tasks. Each entry is a
# (first_key, second_key) pair; second_key is None for single-sentence tasks.
_SINGLE_SENTENCE = ("sentence", None)
_PREMISE_HYPOTHESIS = ("premise", "hypothesis")
_SENTENCE_PAIR = ("sentence1", "sentence2")

# Maps each GLUE task name to its prompt format keys. Insertion order matters:
# it is the order in which tasks are listed in help and error messages.
GLUE_TASK_TO_KEYS = {
    "cola": _SINGLE_SENTENCE,
    "mnli": _PREMISE_HYPOTHESIS,
    "mnli-m": _PREMISE_HYPOTHESIS,
    "mnli-mm": _PREMISE_HYPOTHESIS,
    "mrpc": _SENTENCE_PAIR,
    "qnli": ("question", "sentence"),
    "qqp": ("question1", "question2"),
    "rte": _SENTENCE_PAIR,
    "sst2": _SINGLE_SENTENCE,
    "stsb": _SENTENCE_PAIR,
    "wnli": _SENTENCE_PAIR,
}


def get_task_keys(task):
    """
    Look up the prompt format keys registered for a GLUE task.

    Parameters:
    -----------
    task : str
        The name of the GLUE task, exactly as it appears in GLUE_TASK_TO_KEYS.

    Returns:
    --------
    tuple
        The prompt format keys registered for the task.

    Raises:
    -------
    ValueError
        If the task has no entry in GLUE_TASK_TO_KEYS.
    """
    try:
        return GLUE_TASK_TO_KEYS[task]
    except KeyError:
        # Report unknown tasks as a ValueError and hide the internal KeyError.
        raise ValueError(f"Task '{task}' not found in GLUE_TASK_TO_KEYS.") from None


# define dataset arguments
@dataclass
class DataTrainingArguments:
    """
    Data-related arguments for model training and evaluation.

    Attributes:
    -----------
    task_name : Optional[str]
        Name of the GLUE task to train on. When provided, it is lowercased and
        must be one of the keys of GLUE_TASK_TO_KEYS.
    """

    # The help text lists every supported task name.
    task_name: Optional[str] = field(
        default=None,
        metadata={
            "help": "The name of the task to train on: "
            + ", ".join(GLUE_TASK_TO_KEYS)
        },
    )

    def __post_init__(self):
        """
        Normalize the task name to lowercase and validate it.

        A task name of None is left untouched and not validated.

        Raises:
        -------
        ValueError:
            If the lowercased task_name is not a known GLUE task.
        """
        if self.task_name is None:
            return

        self.task_name = self.task_name.lower()
        if self.task_name in GLUE_TASK_TO_KEYS:
            return

        valid_tasks = ", ".join(GLUE_TASK_TO_KEYS)
        raise ValueError(
            f"Unknown task '{self.task_name}'. Please choose from: {valid_tasks}"
        )


def main():
    """
    Compute intrinsic dimensions for one GLUE task and derive LoRA patterns.

    Arguments are read either from a JSON file (when the only command-line
    argument ends in ".json") or from the command line. For the given task
    name the function:

    1. Loads "./dataset_activations/activations_full_{task}.npy".
    2. Estimates the intrinsic dimension of every layer and saves the result
       to "./intrinsic_dims/ID_{task}.npy", plus a plot (first layer skipped)
       to "./intrinsic_dims/ID_{task}.png".
    3. For alpha/rank ratios 2, 4, 8, 16 and 32, saves the rank pattern to
       "./rank_pattern/rank_pattern_{task}.pkl" and the alpha pattern to
       "./alpha_pattern/{task}/alpha_pattern_{task}_{ratio}.pkl".
    4. Saves the mean rank to "./mean_rank/mean_rank_{task}.pkl".

    Raises:
    ValueError: If no task name was provided.
    """
    # Define the directory to save the intrinsic dimensions
    intrinsic_dims_dir = "./intrinsic_dims"
    os.makedirs(intrinsic_dims_dir, exist_ok=True)

    # Initialize parser for DataTrainingArguments
    parser = HfArgumentParser(DataTrainingArguments)

    # Both parser entry points return a tuple with one item per dataclass,
    # so the single result has to be unpacked in either branch.
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # Parse arguments from the JSON file
        (data_args,) = parser.parse_json_file(
            json_file=os.path.abspath(sys.argv[1])
        )
    else:
        # Parse arguments from the command line
        (data_args,) = parser.parse_args_into_dataclasses()

    # Every file name below embeds the task name, so fail early with a clear
    # message instead of looking for "activations_full_None.npy".
    if data_args.task_name is None:
        raise ValueError(
            "A task name is required. Please choose from: "
            + ", ".join(GLUE_TASK_TO_KEYS.keys())
        )

    # Standardize task name for any variation of "mnli". output_dir is
    # captured first, so file names keep the exact variant that was requested
    # (e.g. "mnli-mm").
    output_dir = data_args.task_name
    if "mnli" in data_args.task_name:
        data_args.task_name = "mnli"

    # Load the embeddings from the file corresponding to the task
    embeddings_file = f"./dataset_activations/activations_full_{output_dir}.npy"
    embeddings = np.load(embeddings_file)

    # Compute intrinsic dimensions
    intrinsic_dims = compute_intrinsic_dimension(embeddings)

    # Save the intrinsic dimensions to a numpy file
    intrinsic_dims_file = f"{intrinsic_dims_dir}/ID_{output_dir}.npy"
    np.save(intrinsic_dims_file, intrinsic_dims)

    # Plot intrinsic dimensions, skipping the first layer (embedding layer output)
    plot_intrinsic_dims(
        intrinsic_dims[1:],
        file_path=f"{intrinsic_dims_dir}/ID_{output_dir}.png",
    )

    # Confirm successful save
    print(f"Intrinsic dimensions saved to {intrinsic_dims_file}")

    # Output locations do not depend on the ratio, so set them up once.
    rank_pattern_dir = "./rank_pattern"
    os.makedirs(rank_pattern_dir, exist_ok=True)
    rank_pattern_file = f"{rank_pattern_dir}/rank_pattern_{output_dir}.pkl"

    alpha_pattern_dir = f"./alpha_pattern/{output_dir}"
    os.makedirs(alpha_pattern_dir, exist_ok=True)

    # Set adaptive ranks and alphas based on intrinsic dimensions
    for alpha_rank_ratio in [2, 4, 8, 16, 32]:
        rank_pattern, alpha_pattern = set_adaptive_ranks_alphas(
            intrinsic_dims_file, alpha_rank_ratio=alpha_rank_ratio
        )

        # Save the rank pattern as a pickle file. The file name carries no
        # ratio, so each iteration rewrites the same file.
        with open(rank_pattern_file, "wb") as file:
            pickle.dump(rank_pattern, file)

        # Save the alpha pattern as a pickle file, one file per ratio
        alpha_pattern_file = (
            f"{alpha_pattern_dir}/alpha_pattern_{output_dir}_{alpha_rank_ratio}.pkl"
        )
        with open(alpha_pattern_file, "wb") as file:
            pickle.dump(alpha_pattern, file)

        # Print confirmation messages
        print(f"Rank pattern saved to {rank_pattern_file}")
        print(f"Alpha pattern saved to {alpha_pattern_file}")

    # Calculate the mean rank from the rank pattern of the last iteration
    mean_rank = get_mean_rank(rank_pattern)
    print(f"Mean rank: {mean_rank}")

    # Define the directory to save the mean rank file
    mean_rank_dir = "./mean_rank"
    os.makedirs(mean_rank_dir, exist_ok=True)

    # Save the mean rank as a pickle file
    mean_rank_file = f"{mean_rank_dir}/mean_rank_{output_dir}.pkl"
    with open(mean_rank_file, "wb") as file:
        pickle.dump(mean_rank, file)

    # Print confirmation message
    print(f"Mean rank saved to {mean_rank_file}")


if __name__ == "__main__":
    main()
