import sys
import os

# Add parent directory to path for imports
sys.path.append(
    os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
)

import wandb
import gc
import torch
from transformers import AutoTokenizer
from datasets import load_dataset
from nnsight import LanguageModel
from itertools import combinations
import numpy as np
from dataset_utils_multilangevol import load_dataset as load_self_dataset
import argparse


def calculate_sads_score(cosine_similarities, language_columns):
    """
    Calculate the SADS score from cross-language similarity matrices.

    SADS = Average(Sij | lang(si) ≠ lang(sj), semantics(si) = semantics(sj)) -
           Average(Sij | lang(si) ≠ lang(sj), semantics(si) ≠ semantics(sj))

    Row ``k`` of every language's activations is assumed to encode the same
    sentence. So in each cross-language matrix the diagonal holds the
    same-semantics similarities and the off-diagonal entries hold the
    different-semantics ones. Both averages are pooled over all entries of
    all pairs, not averaged per pair first.

    Parameters:
    cosine_similarities (dict): Maps ``(lang_i, lang_j)`` to a square
        similarity matrix. A pair stored in the opposite order,
        ``(lang_j, lang_i)``, is also accepted: transposing a matrix does not
        change its diagonal or off-diagonal sums.
    language_columns (list): Languages to include. Every unordered pair of
        entries from this list is used.

    Returns:
    float: The SADS score.

    Raises:
    KeyError: If a language pair is missing in both orders.
    ValueError: If a matrix is not square, or there is no data to average
        (fewer than two languages, or only 1x1 matrices).
    """
    same_semantics_sim_sum = 0.0
    same_semantics_count = 0
    different_semantics_sim_sum = 0.0
    different_semantics_count = 0

    for lang_i, lang_j in combinations(language_columns, 2):
        # combinations() never pairs an element with itself, so this only
        # triggers when language_columns contains duplicate entries.
        if lang_i == lang_j:
            continue

        if (lang_i, lang_j) in cosine_similarities:
            sim_matrix = cosine_similarities[(lang_i, lang_j)]
        elif (lang_j, lang_i) in cosine_similarities:
            # S(j, i) == S(i, j).T; diagonal and off-diagonal sums are
            # invariant under transposition, so the stored matrix is usable.
            sim_matrix = cosine_similarities[(lang_j, lang_i)]
        else:
            raise KeyError((lang_i, lang_j))

        # Validate dimensions
        n, m = sim_matrix.shape
        if n != m:
            raise ValueError(
                f"Dimension mismatch for {lang_i}-{lang_j} pair: {n}x{m}. "
                "This suggests misalignment in language vectors."
            )

        # Diagonal elements: same semantics, different languages.
        diag_elements = torch.diag(sim_matrix)
        same_semantics_sim_sum += torch.sum(diag_elements).item()
        same_semantics_count += n

        # Off-diagonal elements: different semantics, different languages.
        mask = ~torch.eye(n, dtype=torch.bool, device=sim_matrix.device)
        off_diag = sim_matrix[mask]
        different_semantics_sim_sum += torch.sum(off_diag).item()
        different_semantics_count += n * n - n

    # Check if we have data to analyze
    if same_semantics_count == 0 or different_semantics_count == 0:
        raise ValueError("No valid pairs found for SADS calculation")

    # Pooled averages, weighted by element count across all pairs.
    avg_same = same_semantics_sim_sum / same_semantics_count
    avg_diff = different_semantics_sim_sum / different_semantics_count

    return avg_same - avg_diff


def calculate_lrds_score(cosine_similarities, language_columns):
    """
    Calculate the Language Representation Difference Score (LRDS).

    LRDS = Average(Sij | lang(si) = lang(sj), semantics(si) ≠ semantics(sj)) -
           Average(Sij | lang(si) ≠ lang(sj), semantics(si) ≠ semantics(sj))

    Only off-diagonal entries (different semantics) are used. Matrices keyed
    by ``(lang, lang)`` feed the first average, all other matrices feed the
    second. Each average is pooled over all entries of all matching
    matrices. The sums are accumulated in float64 instead of collecting
    every element into a Python list, which keeps memory O(1) in the number
    of similarities.

    Parameters:
    cosine_similarities (dict): Dictionary mapping language pairs to square
        similarity matrices. Ordered pairs of ``language_columns`` that are
        missing from the dict are skipped.
    language_columns (list): List of language identifiers.

    Returns:
    float: The LRDS score.

    Raises:
    ValueError: If a matrix is not square, or if either the same-language or
        the cross-language group has no off-diagonal entries.
    """
    same_lang_sum = 0.0
    same_lang_count = 0
    diff_lang_sum = 0.0
    diff_lang_count = 0

    for lang_i in language_columns:
        for lang_j in language_columns:
            if (lang_i, lang_j) not in cosine_similarities:
                continue

            sim_matrix = cosine_similarities[(lang_i, lang_j)]
            n, m = sim_matrix.shape

            # Validate dimensions
            if n != m:
                raise ValueError(
                    f"Dimension mismatch for {lang_i}-{lang_j} pair: {n}x{m}. "
                    "This suggests misalignment in language vectors."
                )

            # Off-diagonal elements: different semantics.
            mask = ~torch.eye(n, dtype=torch.bool, device=sim_matrix.device)
            off_diag = sim_matrix[mask]
            block_sum = off_diag.double().sum().item()
            block_count = off_diag.numel()

            if lang_i == lang_j:  # Same language
                same_lang_sum += block_sum
                same_lang_count += block_count
            else:  # Different languages
                diff_lang_sum += block_sum
                diff_lang_count += block_count

    # Ensure we have the necessary data for calculation
    if same_lang_count == 0:
        raise ValueError(
            "Could not compute same-language, different-semantics similarities. "
            "Ensure that same-language similarity matrices are provided."
        )

    if diff_lang_count == 0:
        raise ValueError(
            "Could not compute different-language, different-semantics similarities. "
            "Ensure that cross-language similarity matrices are provided."
        )

    avg_same_lang_diff_sem = same_lang_sum / same_lang_count
    avg_diff_lang_diff_sem = diff_lang_sum / diff_lang_count

    return avg_same_lang_diff_sem - avg_diff_lang_diff_sem


def compute_similarities_by_language_pairs(activations_by_language, language_columns):
    """
    Build similarity matrices for every language paired with itself and with
    each later language in ``language_columns``.

    For positions ``i <= j`` the result maps ``(lang_i, lang_j)`` to
    ``activations_by_language[lang_i] @ activations_by_language[lang_j]``
    transposed on dims 0 and 1. No normalization is done here. The matrices
    are cosine similarities only if the caller already L2-normalized each row.

    Parameters:
    activations_by_language (dict): Language -> 2-D tensor of per-sample rows.
    language_columns (list): Ordered languages; the order fixes which of
        ``(a, b)`` / ``(b, a)`` is stored.

    Returns:
    dict: ``(lang_i, lang_j)`` -> similarity matrix, in upper-triangular
    insertion order.
    """
    pair_similarities = {}
    for start, left_lang in enumerate(language_columns):
        left_rows = activations_by_language[left_lang]
        for right_lang in language_columns[start:]:
            right_rows = activations_by_language[right_lang]
            pair_similarities[(left_lang, right_lang)] = left_rows @ right_rows.transpose(0, 1)
    return pair_similarities


def _l2_normalize(tensor):
    """Scale ``tensor`` to unit L2 norm over all of its elements; zero tensors stay zero."""
    norm = torch.norm(tensor, p=2)
    # Plain tensor / norm would turn an all-zero activation into NaNs
    # (0 / 0), which then poisons every downstream similarity sum.
    return torch.where(norm > 0, tensor / norm, tensor)


def process_revision_activations(
    activations_by_language,
    language_columns,
    tokenizer,
    dataset,
    return_middle_layer=True,
    return_all_layers=False,
):
    """
    Turn raw per-layer activations into L2-normalized feature vectors.

    Each sample tensor in ``activations_by_language[language]`` is unbound
    along dim 0 (one slice per layer), and each slice has its dim 0 squeezed.
    From those per-layer outputs three views are built:

    * all layers concatenated along the last dim, normalized as a whole
      (always computed);
    * the middle layer (index ``num_layers // 2``), normalized
      (only if ``return_middle_layer``);
    * every layer separately, normalized (only if ``return_all_layers``).

    Parameters:
    activations_by_language (dict): Language -> list of per-sample tensors.
        Presumably each has shape (num_layers, 1, hidden), as built by
        ``compute_sads_lrds``. TODO: confirm for other callers.
    language_columns (list): Languages to process.
    tokenizer: Unused; kept for backward compatibility.
    dataset: Unused; kept for backward compatibility.
    return_middle_layer (bool): Whether to build the middle-layer view.
    return_all_layers (bool): Whether to build the per-layer view.

    Returns:
    tuple: ``(concatenated, middle_layer, all_layers)``.
        ``concatenated`` maps language -> samples stacked along a new dim 0.
        ``middle_layer`` maps language -> stacked middle-layer vectors; it is
        an empty dict when disabled.
        ``all_layers`` maps language -> list with one stacked tensor per
        layer; it is an empty dict when disabled.
    """
    concatenated_by_lang = {}
    middle_layer_by_lang = {}
    all_layers_by_lang = {}
    sample_debug = True  # print shapes for the very first sample only

    for language_col in language_columns:
        concatenated_samples = []
        middle_layer_samples = []
        per_layer_samples = []

        for activations in activations_by_language[language_col]:
            layer_outputs_tuple = torch.unbind(activations, dim=0)
            layer_outputs = [layer.squeeze(0) for layer in layer_outputs_tuple]

            if sample_debug:
                print(
                    f"The sample shape of layer_outputs: {layer_outputs[0].shape}. layer_outputs_tuple_len: {len(layer_outputs_tuple)}"
                )

            if return_all_layers:
                per_layer_samples.append([_l2_normalize(layer) for layer in layer_outputs])

            if return_middle_layer:
                middle_layer = layer_outputs[len(layer_outputs) // 2]
                middle_layer_samples.append(_l2_normalize(middle_layer))

            concatenated_tensor = torch.cat(layer_outputs, dim=-1)
            if sample_debug:
                print(
                    f"The sample shape of concatenated_tensor: {concatenated_tensor.shape}"
                )
                sample_debug = False

            concatenated_samples.append(_l2_normalize(concatenated_tensor))

        concatenated_by_lang[language_col] = torch.stack(concatenated_samples)

        if return_middle_layer:
            middle_layer_by_lang[language_col] = torch.stack(middle_layer_samples)

        if return_all_layers:
            # Transpose sample-major lists into one (num_samples, ...) stack per layer.
            all_layers_by_lang[language_col] = [
                torch.stack(layer_samples) for layer_samples in zip(*per_layer_samples)
            ]

    return (
        concatenated_by_lang,
        middle_layer_by_lang,
        all_layers_by_lang,
    )


def load_local_checkpoint(checkpoint_dir, tokenizer):
    """
    Load a model from a local checkpoint directory as an nnsight ``LanguageModel``.

    ``checkpoint_dir`` is passed through ``os.path.expanduser`` first. The
    underlying Hugging Face loading does not expand a leading ``~`` itself,
    and callers in this file build ``~``-prefixed paths. Paths without a
    leading ``~`` (including hub repo ids) are passed through unchanged.

    Parameters:
    checkpoint_dir (str): Checkpoint directory path.
    tokenizer: Tokenizer to attach to the model.

    Returns:
    LanguageModel: Model dispatched with ``device_map="auto"``.
    """
    print("Initializing model from config...")
    model = LanguageModel(
        os.path.expanduser(checkpoint_dir),
        tokenizer=tokenizer,
        dispatch=True,
        device_map="auto",
    )
    return model


def get_dataset_with_sufficient_length(
    dataset, tokenizer, min_tokens=505, split="test_sft", column="prompt"
):
    """
    Keep prompts that tokenize to at least ``min_tokens`` tokens.

    Each prompt in ``dataset[split][column]`` is stripped and tokenized
    without truncation. Prompts with at least ``min_tokens`` tokens are cut
    to their first ``min_tokens`` tokens and decoded back to text (special
    tokens skipped); shorter prompts are dropped. The number kept is printed
    and written to ``wandb.summary["valid_examples"]``, which presumably
    requires an active ``wandb.init`` run.

    NOTE: re-tokenizing the decoded text may not give back exactly
    ``min_tokens`` tokens.

    Parameters:
    dataset: Mapping of split name -> dataset with a text column.
    tokenizer: Hugging Face tokenizer.
    min_tokens (int): Minimum (and truncation) length in tokens.
    split (str): Dataset split to read.
    column (str): Text column to read.

    Returns:
    list[str]: The truncated prompts.
    """
    valid_examples = []
    for example in dataset[split][column]:
        text = example.strip()
        input_ids = tokenizer(text, return_tensors="pt", truncation=False).input_ids

        if input_ids.size(1) >= min_tokens:
            # Only keep what we need, to save memory during processing
            truncated_text = tokenizer.decode(
                input_ids[0, :min_tokens], skip_special_tokens=True
            )
            valid_examples.append(truncated_text)

    print(f"Found {len(valid_examples)} valid examples with sufficient length")
    wandb.summary["valid_examples"] = len(valid_examples)
    return valid_examples


def compute_metrics(
    model,
    dataset,
    bible_dataset,
    tokenizer,
    step,
    language_columns,
):
    """
    Compute SADS/LRDS scores for one checkpoint and log them to W&B.

    All metrics for ``step`` are sent in a single ``wandb.log`` call.

    Parameters:
    model: Model passed to ``compute_sads_lrds``.
    dataset: Unused; kept for backward compatibility.
    bible_dataset: Parallel multilingual dataset the scores are computed on.
    tokenizer: Tokenizer passed to ``compute_sads_lrds``.
    step (int): Training step, used as the W&B ``step``.
    language_columns (list | None): Languages to evaluate. When None, the
        columns of ``bible_dataset`` other than "synset_id" and "definition"
        are used. This matches the default ``compute_sads_lrds`` applies, so
        the metric names can still be built.
    """
    if language_columns is None:
        language_columns = [
            col
            for col in bible_dataset.columns
            if col not in ["synset_id", "definition"]
        ]

    (
        total_lrds_score,
        middle_layer_sads_score,
        total_sads_score,
        pairwise_sads_scores,
        pairwise_lrds_scores,
    ) = compute_sads_lrds(model, bible_dataset, tokenizer, step, language_columns)

    lang_suffix = "_" + "_".join(language_columns)

    metrics = {
        f"sads_score_middle_layer{lang_suffix}": middle_layer_sads_score,
        f"lrds_score{lang_suffix}": total_lrds_score,
        f"sads_score{lang_suffix}": total_sads_score,
    }
    print(pairwise_sads_scores)
    for (lang1, lang2), score in pairwise_sads_scores.items():
        metrics[f"sads_score_{lang1}_{lang2}"] = score
    for (lang1, lang2), score in pairwise_lrds_scores.items():
        metrics[f"lrds_score_{lang1}_{lang2}"] = score

    # One call per step keeps all metrics in the same history row.
    wandb.log(metrics, step=step)


def compute_sads_lrds(model, dataset, tokenizer, step, language_columns):
    """
    Extract MLP activations for every language and compute SADS/LRDS scores.

    For each language column and each row of ``dataset``, the row's input(s)
    are run through ``model`` inside an nnsight trace. The rows are read with
    ``len`` and ``.iloc``, so ``dataset`` is presumably a pandas DataFrame
    (TODO confirm). For every ``gpt_neox`` layer, the output of ``mlp.act`` is
    averaged over dim 1, and the per-layer results are stacked into one tensor
    per row. These tensors are normalized by ``process_revision_activations``
    and compared across languages.

    Parameters:
    model: nnsight ``LanguageModel`` around a GPT-NeoX-style model (the code
        reads ``model.gpt_neox.layers[i].mlp.act``).
    dataset: Table with one column per language. Each cell is iterated as a
        collection of input strings.
    tokenizer: Used to tokenize prompts for a one-time debug print, and
        forwarded to ``process_revision_activations``. NOTE: its
        ``padding_side`` is set to "left" as a side effect.
    step: Unused in this function.
    language_columns (list | None): Languages to evaluate. When None, every
        dataset column except "synset_id" and "definition" is used.

    Returns:
    tuple: ``(total_lrds_score, middle_layer_sads_score, total_sads_score,
    pairwise_sads_scores, pairwise_lrds_scores)``. The pairwise dicts are
    keyed by pairs from ``combinations(language_columns, 2)`` that contain
    "eng", and are computed on the concatenated-layer similarities. The
    per-layer SADS sweep and the best layer are only printed, not returned.
    """
    # you might need to cast to nnsight model
    # NOTE: mutates the caller's tokenizer.
    tokenizer.padding_side = "left"
    model.eval()

    def access_mlp_activation(model):
        """Stack the dim-1 mean of every layer's ``mlp.act`` output, one layer per dim-0 slice."""
        mlp_activation = torch.stack(
            [
                # Mean over dim 1, presumably the token axis of a
                # (batch, seq, ffn) output (TODO confirm). With one invoke per
                # trace there is no padding, so an unmasked mean is used.
                torch.mean(
                    model.gpt_neox.layers[layer_idx].mlp.act.output, dim=1
                )
                for layer_idx in range(get_layer_num(model))
            ],
            dim=0,
        )
        return mlp_activation

    def get_layer_num(model):
        """Return the number of transformer layers in the GPT-NeoX stack."""
        return len(model.gpt_neox.layers)

    # Print the tokenization of the first prompt only.
    display_once = True
    if language_columns is None:
        language_columns = [
            col for col in dataset.columns if col not in ["synset_id", "definition"]
        ]

    # Dictionary to store activations by language
    activations_by_language = {}

    for language_col in language_columns:
        print(f"Processing language: {language_col}...")
        language_activations = []

        with torch.no_grad():
            # Process each row in the dataset
            total_rows = len(dataset)
            for idx in range(total_rows):
                row = dataset.iloc[idx]
                # One nnsight trace per dataset row.
                with model.trace() as tracer:
                    inpt_list = row[language_col]

                    row_activations = []
                    # Process each word in this language for the current row
                    for inpt in inpt_list:
                        if len(inpt_list) > 1:
                            print(
                                "The number of inpt is more than 1, your code is not optimized for this!"
                            )
                        prompt = f"{inpt}"
                        tokens = tokenizer.tokenize(prompt)
                        seq_len = len(tokens)  # currently unused
                        if display_once:
                            print(f"Input: {inpt}, Tokens: {tokens}")
                            display_once = False
                        with tracer.invoke(prompt) as invoker:
                            # Collect activations from each layer
                            layer_outputs = access_mlp_activation(model)
                            # Overwritten on every iteration: for a row with
                            # several inputs only the last one is kept.
                            row_activations = layer_outputs
                    # NOTE(review): if inpt_list is empty, row_activations is
                    # still a list and .detach() raises AttributeError.
                    # .save() presumably asks nnsight to keep this value
                    # after the trace exits; .value reads it back below.
                    language_activations.append(row_activations.detach().save())

            # Read the saved values back out once all traces have run.
            language_activations_list = [act.value for act in language_activations]
            # Store the results for this language
            activations_by_language[language_col] = language_activations_list

    # Normalize activations into three views: concatenated layers (full
    # SADS), middle layer, and every layer separately (layer sweep below).
    (
        full_activations_by_language,
        middle_layer_activations_by_language,
        all_layers_activations_by_language,
    ) = process_revision_activations(
        activations_by_language,
        language_columns,
        tokenizer,
        dataset,
        return_middle_layer=True,
        return_all_layers=True,
    )

    # Compute similarities for full SADS (concatenated layers)
    full_cosine_similarities = compute_similarities_by_language_pairs(
        full_activations_by_language, language_columns
    )
    del full_activations_by_language

    # Compute similarities for middle layer SADS
    middle_cosine_similarities = compute_similarities_by_language_pairs(
        middle_layer_activations_by_language, language_columns
    )
    del middle_layer_activations_by_language

    # Move CUDA tensors to CPU into new dicts so the GPU originals can be freed.
    cpu_full_similarities = {}
    for key, tensor in full_cosine_similarities.items():
        if isinstance(tensor, torch.Tensor) and tensor.is_cuda:
            cpu_full_similarities[key] = tensor.cpu()
        else:
            cpu_full_similarities[key] = tensor

    cpu_middle_similarities = {}
    for key, tensor in middle_cosine_similarities.items():
        if isinstance(tensor, torch.Tensor) and tensor.is_cuda:
            cpu_middle_similarities[key] = tensor.cpu()
        else:
            cpu_middle_similarities[key] = tensor

    # Clean up the original GPU tensors
    del full_cosine_similarities, middle_cosine_similarities
    torch.cuda.empty_cache()
    gc.collect()

    # Calculate full SADS score (concatenated layers)
    total_sads_score = calculate_sads_score(cpu_full_similarities, language_columns)

    # Pairwise SADS only for language pairs that include English ("eng").
    pairwise_sads_scores = {
        lang_pair: calculate_sads_score(cpu_full_similarities, list(lang_pair))
        for lang_pair in combinations(language_columns, 2)
        if "eng" in lang_pair
    }

    # Calculate middle layer SADS score
    middle_layer_sads_score = calculate_sads_score(
        cpu_middle_similarities, language_columns
    )

    # SADS for every individual layer (diagnostic: printed, not returned)
    layer_sads_scores = []
    num_layers = len(all_layers_activations_by_language[language_columns[0]])

    for layer_idx in range(num_layers):
        layer_activations = {
            lang: all_layers_activations_by_language[lang][layer_idx]
            for lang in language_columns
        }
        layer_similarities = compute_similarities_by_language_pairs(
            layer_activations, language_columns
        )

        # Move to CPU
        cpu_layer_similarities = {}
        for key, tensor in layer_similarities.items():
            cpu_layer_similarities[key] = tensor.cpu() if tensor.is_cuda else tensor

        layer_sads = calculate_sads_score(cpu_layer_similarities, language_columns)
        layer_sads_scores.append(layer_sads)

        del layer_similarities, cpu_layer_similarities

    # Clean up all layers activations
    del all_layers_activations_by_language

    # Layer with the highest SADS (first one on ties, per np.argmax)
    best_layer_idx = np.argmax(layer_sads_scores)
    best_layer_score = layer_sads_scores[best_layer_idx]

    print(f"Layer SADS scores: {[f'{score:.4f}' for score in layer_sads_scores]}")
    print(f"Best layer: {best_layer_idx} with SADS score: {best_layer_score:.4f}")

    print(f"Full SADS score: {total_sads_score}")
    print(f"Middle layer SADS score: {middle_layer_sads_score}")

    # LRDS scores, total and English-pairwise, on concatenated-layer similarities
    total_lrds_score = calculate_lrds_score(cpu_full_similarities, language_columns)
    pairwise_lrds_scores = {
        lang_pair: calculate_lrds_score(cpu_full_similarities, list(lang_pair))
        for lang_pair in combinations(language_columns, 2)
        if "eng" in lang_pair
    }
    return (
        total_lrds_score,
        middle_layer_sads_score,
        total_sads_score,
        pairwise_sads_scores,
        pairwise_lrds_scores,
    )


def main():
    """
    Evaluate SADS/LRDS metrics over a range of local training checkpoints.

    Parses the command-line arguments, initializes the W&B run, and loads the
    tokenizer and evaluation data once. It then loads the checkpoint at every
    ``step_interval``-th step in ``[first_step, last_step)`` (``last_step`` is
    exclusive) and logs its metrics.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_type", type=str, required=True)
    parser.add_argument("--first_step", type=int, default=100)
    parser.add_argument("--last_step", type=int, default=20000)  # exclusive
    parser.add_argument("--step_interval", type=int, default=100)
    parser.add_argument("--wandb_id", type=str, default=None)
    parser.add_argument("--resume", type=str, default=None)
    args = parser.parse_args()

    wandb.init(
        project="pythia_replicate_all_benchmark",
        name=args.model_type,
        id=args.wandb_id,
        resume=args.resume,
    )

    tokenizer_name = "EleutherAI/pythia-160m"
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # NOTE(review): this filtered UltraChat data is forwarded to
    # compute_metrics, which does not use it. Only the
    # "valid_examples" W&B summary depends on it.
    dataset = load_dataset("HuggingFaceH4/ultrachat_200k")
    dataset = get_dataset_with_sufficient_length(dataset, tokenizer)

    data_type = "bible"  # alternative: "word"
    language_columns = ["eng", "arb", "spa", "cmn", "fra"]  # this was japanese
    # Same languages as language_columns, in the same order.
    list_of_lang_bible = ["en", "ar", "es", "zh", "fr"]
    bible_dataset = load_self_dataset(
        dataset_type=data_type, list_of_lang=list_of_lang_bible, column=language_columns
    )

    for step in range(args.first_step, args.last_step, args.step_interval):
        checkpoint_dir = f"~pythia_replicate/hf_output/{args.model_type}/step={step}"
        model = load_local_checkpoint(checkpoint_dir, tokenizer)
        model.eval()
        compute_metrics(
            model,
            dataset,
            bible_dataset,
            tokenizer,
            step,
            language_columns,
        )
        # Release the checkpoint before loading the next one.
        del model
        gc.collect()
        torch.cuda.empty_cache()


if __name__ == "__main__":
    main()
