import argparse
import torch
import os
from nesim.utils.json_stuff import load_json_as_dict
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
from nesim.utils.grid_size import find_rectangle_dimensions
import matplotlib.colors as mcolors
from nesim.utils.figure.figure_1 import get_model_and_tokenizer, CategorySelectivity, obtain_hook_outputs
import matplotlib as mpl
from sklearn.model_selection import train_test_split
from einops import rearrange, reduce
from PIL import Image
import numpy as np

def hex_to_rgb(hex_color):
    """
    Turn a hex color string such as '#C0392B' into its red/green/blue channels.

    Leading '#' characters are stripped, then the first three two-character
    groups are parsed as base-16 integers. The result is a float32 NumPy
    array of shape (3,) holding (R, G, B) in the 0-255 range.
    """
    digits = hex_color.lstrip('#')
    channels = []
    # Offsets 0, 2 and 4 are the starts of the R, G and B hex pairs.
    for start in range(0, 6, 2):
        channels.append(int(digits[start:start + 2], 16))
    return np.array(channels, dtype=np.float32)

def apply_masks_with_colors(image, masks, colors, background = "white"):
    """
    Render a 2-D intensity map as an RGB image, tinting each masked region
    with its own color and fading toward the background where intensity is low.

    Args:
        image: 2-D NumPy array of intensities. It is divided by its maximum,
            so the brightest pixel gets the fully saturated region color.
        masks: iterable of 2-D arrays with the same height/width as `image`;
            non-zero entries mark the pixels belonging to that region.
        colors: iterable of hex color strings (e.g. '#C0392B'), paired with
            `masks` by position. Extra entries on either side are ignored
            because the two are zipped.
        background: one of "white", "black" or "gray"; the color used for
            unmasked pixels and the color that low intensities fade toward.

    Returns:
        A PIL image built from a (height, width, 3) uint8 array.

    Where masks overlap, the later mask in `masks` wins.
    """
    assert background in ["black", "white", "gray"]

    if background == "white":
        # Initialize a blank color image (3 channels) filled with white
        color_image = np.ones((image.shape[0], image.shape[1], 3), dtype=np.float32) * 255
        background_level = 255
    elif background == "black":
        color_image = np.zeros((image.shape[0], image.shape[1], 3), dtype=np.float32)
        background_level = 0
    elif background == "gray":
        color_image = np.ones((image.shape[0], image.shape[1], 3), dtype=np.float32) * 127.5
        background_level = 127.5

    # Normalize the intensity image by its maximum. An all-zero image would
    # otherwise divide 0 by 0 and fill the masked regions with NaN (which turn
    # into arbitrary values after the uint8 cast), so it is mapped to
    # "no intensity anywhere", i.e. every masked pixel shows the background.
    peak = image.max()
    if peak == 0:
        normalized_image = np.zeros(image.shape, dtype=np.float32)
    else:
        normalized_image = image / peak

    # Add a trailing channel axis once so it broadcasts against (R, G, B).
    intensity = normalized_image[:, :, None]

    for mask, hex_color in zip(masks, colors):
        # Convert hex color string to an (R, G, B) array in the 0-255 range
        rgb_color = hex_to_rgb(hex_color)
        # Blend linearly between the background level (intensity 0) and the
        # region color (intensity 1).
        color_region = ((rgb_color/255 * intensity))*255 + (1 - intensity) * background_level

        # Apply the mask: where mask is non-zero, use the colored region
        color_image = np.where(mask[:, :, None], color_region, color_image)

    # Convert back to uint8 for proper image saving/displaying
    color_image = np.clip(color_image, 0, 255).astype(np.uint8)

    return Image.fromarray(color_image)

# Use font type 42 for both PDF and PostScript output (see matplotlib's
# documentation of the `pdf.fonttype` / `ps.fonttype` rcParams).
mpl.rcParams.update({"pdf.fonttype": 42, "ps.fonttype": 42})

# Command-line interface. Checkpoints, dataset and output paths are mandatory;
# the hooked layer and the device have defaults.
parser = argparse.ArgumentParser(description="Generate assets for figure 1")
parser.add_argument(
    "--baseline-checkpoint-filename",
    type=str,
    required=True,
    help="Path to the baseline model checkpoint (state dict).",
)
parser.add_argument(
    "--topo-checkpoint-filename",
    type=str,
    required=True,
    help="Path to the topographic model checkpoint (state dict).",
)
parser.add_argument(
    "--dataset-filename",
    type=str,
    required=True,
    help="Path to a JSON file with top-level 'dataset' and 'colors' entries.",
)
parser.add_argument(
    "--baseline-output-filename",
    type=str,
    required=True,
    help="Where to save the image rendered for the baseline checkpoint.",
)
parser.add_argument(
    "--topo-output-filename",
    type=str,
    required=True,
    help="Where to save the image rendered for the topographic checkpoint.",
)
# `required=True` used to be set here as well, which made the default
# unreachable. The flag is optional now; passing it explicitly still works.
parser.add_argument(
    "--layer-name",
    type=str,
    default="transformer.h.10.mlp.c_fc",
    help="Name of the layer whose outputs are hooked (default: %(default)s).",
)
parser.add_argument(
    "--device",
    type=str,
    default="cuda:0",
    help="Torch device the model is moved to (default: %(default)s).",
)

args = parser.parse_args()

# Fail early, with the offending path in the message, if an input is missing.
# An explicit raise is used instead of `assert` because asserts are stripped
# when Python runs with -O.
for input_path in (
    args.topo_checkpoint_filename,
    args.baseline_checkpoint_filename,
    args.dataset_filename,
):
    if not os.path.exists(input_path):
        raise FileNotFoundError(f"Input file does not exist: {input_path}")

# Parse the dataset file once and pull both entries out of the same dict.
data = load_json_as_dict(args.dataset_filename)
dataset = data["dataset"]  # mapping: category name -> list of items
colors = data["colors"]  # values are hex color strings, one per category

# One GPT-Neo 125M model/tokenizer pair is created up front; each checkpoint's
# weights are loaded into this same model object later on.
model, tokenizer = get_model_and_tokenizer(name="EleutherAI/gpt-neo-125m")

# The two lists are paired by position: topographic first, baseline second.
checkpoint_filenames = [args.topo_checkpoint_filename, args.baseline_checkpoint_filename]
output_filenames = [args.topo_output_filename, args.baseline_output_filename]

# Halve every category's items with a fixed seed so the split is reproducible:
# the first half goes to train_dataset, the second half to val_dataset.
train_dataset = {}
val_dataset = {}
for category in dataset:
    halves = train_test_split(dataset[category], test_size=0.5, random_state=42)
    train_dataset[category], val_dataset[category] = halves

def _mean_abs_activation_per_neuron(activation_tensors, neuron_indices):
    """
    Mean absolute activation of the selected neurons over every token.

    `activation_tensors` is a list of tensors, each treated as
    (batch, sequence, emb); the sequence length may differ between tensors,
    so each one is flattened to (batch * sequence, emb) before concatenation.
    Returns a 1-D tensor with one value per entry of `neuron_indices`.
    """
    flattened = [
        rearrange(x, "batch sequence emb -> (batch sequence) emb")[:, neuron_indices]
        for x in activation_tensors
    ]
    return reduce(
        torch.cat(flattened, dim=0).abs(),
        "sequence emb -> emb",
        reduction="mean",
    )


# For each checkpoint:
#   1. propose one region (ROI) per category from the TRAIN half,
#   2. measure each ROI neuron's selectivity on the held-out VALIDATION half,
#   3. render the selectivity map, tinted per category, and save it.
for checkpoint_filename, output_filename in zip(checkpoint_filenames, output_filenames):
    model.load_state_dict(
        torch.load(checkpoint_filename, map_location=torch.device("cpu"), weights_only=True)
    )
    model.to(args.device)

    # --- 1. ROI proposal -------------------------------------------------
    # Use only the train half here. Proposing ROIs on the full dataset would
    # let the validation items below influence which neurons are selected.
    hook_dict = obtain_hook_outputs(
        model,
        layer_names=[args.layer_name],
        dataset=train_dataset,
        tokenizer=tokenizer,
        device=args.device,
    )

    c = CategorySelectivity(dataset=train_dataset, hook_outputs=hook_dict)

    # c.categories fixes the category order used for every per-category list
    # below, so ROI i, mask i and category_names[i] always refer to the same
    # category.
    category_names = list(c.categories)
    valid_neuron_indices = list(c.valid_neuron_indices)

    top_activating_categories_for_each_neuron = []

    for neuron_idx in tqdm(valid_neuron_indices):
        activations_for_idx = torch.zeros(len(category_names))

        for class_idx, target_class in enumerate(category_names):
            # Score of this neuron for target_class versus all other classes,
            # as computed by CategorySelectivity.konkle.
            activation = c.konkle(
                neuron_idx=neuron_idx,
                target_class=target_class,
                other_classes=[x for x in train_dataset if x != target_class],
                layer_name=args.layer_name,
            )
            activations_for_idx[class_idx] = activation

        top_activating_categories_for_each_neuron.append(torch.argmax(activations_for_idx).item())

    top_activating_categories_for_each_neuron = np.asarray(
        top_activating_categories_for_each_neuron, dtype=np.int64
    )
    # Entry k of the array above belongs to valid_neuron_indices[k], not to
    # neuron k, so positions are mapped back to real neuron indices.
    # NOTE(review): assumes the entries of c.valid_neuron_indices are
    # integer-like neuron indices - confirm against CategorySelectivity.
    valid_neuron_index_array = np.asarray(
        [int(idx) for idx in valid_neuron_indices], dtype=np.int64
    )

    # One entry per category, in category_names order. A category that is not
    # the top category of any neuron gets an empty array, which keeps this
    # list aligned with category_names.
    indices_for_each_proposed_category_region = [
        valid_neuron_index_array[top_activating_categories_for_each_neuron == category_idx]
        for category_idx in range(len(category_names))
    ]

    # --- 2. Validation on the second half of the dataset -----------------
    hook_dict_val = obtain_hook_outputs(
        model,
        layer_names=[args.layer_name],
        dataset=val_dataset,
        tokenizer=tokenizer,
        device=args.device,
    )

    # hook_dict_val[args.layer_name] maps category name -> list of tensors,
    # each handled below as (batch, seq, emb).
    val_outputs = hook_dict_val[args.layer_name]

    # Layer width taken from the hooked activations rather than hard-coded,
    # so a different --layer-name does not silently break the indexing.
    num_neurons = val_outputs[category_names[0]][0].shape[-1]

    # Selectivity per neuron; neurons outside every ROI stay at 0.
    all_alphas = torch.zeros(num_neurons)

    for category_idx, proposed_category_name in enumerate(category_names):
        proposed_roi_neuron_indices = indices_for_each_proposed_category_region[category_idx]

        if proposed_roi_neuron_indices.size == 0:
            # Nothing to score for a category without any ROI neuron.
            continue

        # Pool the tensors of every other category so the mean below is taken
        # over all of their tokens together.
        outputs_for_non_proposed_categories = []
        for other_category_name in category_names:
            if other_category_name != proposed_category_name:
                outputs_for_non_proposed_categories.extend(val_outputs[other_category_name])

        mean_absolute_activation_for_proposed_category = _mean_abs_activation_per_neuron(
            val_outputs[proposed_category_name], proposed_roi_neuron_indices
        )
        mean_absolute_activation_for_non_proposed_categories = _mean_abs_activation_per_neuron(
            outputs_for_non_proposed_categories, proposed_roi_neuron_indices
        )

        # alpha = mean |activation| on the proposed category
        #         / mean |activation| on all other categories
        # NOTE(review): a zero denominator yields inf/NaN here; nothing
        # downstream sanitises it.
        alphas_for_each_neuro_in_proposed_roi = (
            mean_absolute_activation_for_proposed_category
            / mean_absolute_activation_for_non_proposed_categories
        )

        assert alphas_for_each_neuro_in_proposed_roi.ndim == 1
        ## the selectivity values for each neuron must have the same number of elements as the number of neurons in the proposed roi
        assert alphas_for_each_neuro_in_proposed_roi.shape[0] == proposed_roi_neuron_indices.shape[0]

        all_alphas[proposed_roi_neuron_indices] = alphas_for_each_neuro_in_proposed_roi.cpu()

    # --- 3. Rendering ----------------------------------------------------
    # Lay the flat neuron axis out on a 2-D grid whose area equals the
    # number of neurons.
    size = find_rectangle_dimensions(area=num_neurons)
    values = all_alphas.reshape(size.height, size.width)

    # Colors are paired with masks by position.
    # NOTE(review): assumes the "colors" entries are stored in the same order
    # as the categories - confirm against the dataset file.
    category_colors_hex = list(colors.values())

    masks = []
    for proposed_roi_neuron_indices in indices_for_each_proposed_category_region:
        mask = np.zeros(num_neurons, dtype=np.float32)
        mask[proposed_roi_neuron_indices] = 1.
        masks.append(mask.reshape(size.height, size.width))

    image = apply_masks_with_colors(
        image=values.numpy(),
        masks=masks,
        colors=category_colors_hex,
    )
    image.save(output_filename)
    print(f"Saved: {output_filename}")