# Inspired by https://github.com/andyzoujm/representation-engineering/blob/main/repe/rep_readers.py#L30
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Dict, Optional, Tuple

import numpy as np
import torch
from jaxtyping import Float
from sklearn.decomposition import PCA

from code_demeanor.logger import logger
from code_demeanor.utils import project_onto_direction, recenter


class InputSourceType(Enum):
    """Kind of model activations a RepReader is fitted on.

    Readers may dispatch on this value to choose a fitting routine.
    """

    # Per-layer hidden-state activations.
    HIDDEN_STATES = "hidden_states"
    # Attention values; the example at the bottom of this module feeds the
    # attention weights of a single head.
    ATTENTION_VALUES = "attention_values"


class RepReaderType(Enum):
    """Enum for RepReader types, i.e. the strategy used to derive concept directions."""

    # Principal component analysis (see PCARepReader in this module).
    PCA = "pca"
    # Cluster-mean strategy. NOTE(review): no implementation is visible in this
    # module — confirm where it lives.
    CLUSTER_MEAN = "cluster_mean"


@dataclass
class RepReaderInfo:
    """Serializable snapshot of a fitted RepReader.

    Holds the per-layer direction arrays, the per-layer training means used
    for recentering, the stored component count and, optionally, attention
    directions. Persisted with ``torch.save`` as a plain dict of these fields.
    """

    directions: Dict[int, Float[torch.Tensor, "n_components hidden_size"]]
    H_train_means: Dict[int, Float[torch.Tensor, "1 hidden_size"]]
    n_components: int
    attention_directions: Optional[
        Dict[int, Float[torch.Tensor, "n_components hidden_size"]]
    ] = None

    def save(self, path: str):
        """Write every field of this record to ``path`` as a dict via ``torch.save``."""
        payload = self.__dict__
        torch.save(payload, path)
        logger.info("RepReaderInfo saved to", path=path)

    @classmethod
    def load(cls, path: str) -> "RepReaderInfo":
        """Rebuild a RepReaderInfo from a file previously written by :meth:`save`.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
        """
        if os.path.exists(path):
            # NOTE(review): weights_only=False lets torch.load unpickle arbitrary
            # objects — only load files from trusted sources.
            stored_fields = torch.load(path, weights_only=False)
            return cls(**stored_fields)
        raise FileNotFoundError(f"RepReaderInfo file {path} not found.")


class RepReader(ABC):
    """Abstract base class that identifies and stores concept directions.

    Subclasses implement :meth:`get_rep_directions` to compute, for each hidden
    layer, an array of concept directions (e.g. via PCA). The returned mapping
    is not stored automatically: callers assign it to ``self.directions``
    before calling :meth:`transform`.

    Subclasses are expected to set ``self.n_components`` and (optionally)
    ``self.H_train_means`` while fitting; both are read by :meth:`transform`
    and :meth:`save_rep_reader_info`.

    RepReader instances are used by RepReaderPipeline to get concept scores.
    Directions can be used for downstream interventions.
    """

    needs_hiddens = True

    @abstractmethod
    def __init__(
        self, input_source: InputSourceType = InputSourceType.HIDDEN_STATES
    ) -> None:
        """Initialise an unfitted reader.

        Args:
            input_source: Which kind of activations the reader is fitted on.
        """
        self.direction_method = None
        # self.directions are accessible via directions[layer][component_index]
        self.directions = None
        self.attention_directions = None
        self.input_source = input_source

    @abstractmethod
    def get_rep_directions(self, hidden_states, hidden_layers, **kwargs):
        """Get concept directions for each hidden layer of the model

        Args:
            hidden_states: Hidden states of the model on the training data (per layer)
            hidden_layers: Layers to consider

        Returns:
            directions: A dict mapping layers to direction arrays (n_components, hidden_size)
        """
        pass

    def transform(self, hidden_states, hidden_layers, component_index, device):
        """Project hidden states, recentering with ``self.H_train_means`` if available.

        A reader that never set ``H_train_means`` projects without recentering
        instead of raising ``AttributeError``.
        """
        return self._transform_core(
            hidden_states,
            hidden_layers,
            component_index,
            device,
            getattr(self, "H_train_means", None),
        )

    def transform_with_custom_means(
        self, hidden_states, hidden_layers, component_index, device, H_train_means
    ):
        """Project hidden states, recentering with explicitly passed ``H_train_means``."""
        return self._transform_core(
            hidden_states, hidden_layers, component_index, device, H_train_means
        )

    def _transform_core(
        self, hidden_states, hidden_layers, component_index, device, H_train_means
    ):
        """Shared transform logic: optionally recenter, then project per layer.

        Args:
            hidden_states: Mapping layer -> activations for that layer.
            hidden_layers: Layers to transform; each must be a key of both
                ``hidden_states`` and ``self.directions``.
            component_index: Number of leading directions to project onto —
                the slice ``directions[layer][:component_index]`` is used, so
                0 selects no directions. Must be smaller than ``self.n_components``.
                NOTE(review): the name suggests an index but it acts as a count.
            device: Forwarded to ``recenter`` and ``project_onto_direction``.
            H_train_means: Optional mapping layer -> mean; layers present in it
                are recentered before projection.

        Returns:
            Dict mapping each layer to the projection as a NumPy array.

        Raises:
            ValueError: If no directions / component count have been set yet.
        """
        if self.directions is None or getattr(self, "n_components", None) is None:
            raise ValueError(
                "RepReader has not been fitted yet. Call get_rep_directions first."
            )
        assert component_index < self.n_components
        transformed_hidden_states = {}

        for layer in hidden_layers:
            layer_hidden_states = hidden_states[layer]

            if H_train_means is not None and layer in H_train_means:
                layer_hidden_states = recenter(
                    layer_hidden_states, mean=H_train_means[layer], device=device
                )

            layer_directions = self.directions[layer]
            if len(layer_directions) < component_index:
                # n_components is one value for the whole reader, but each layer
                # may have kept a different number of directions; the slice below
                # would then silently return fewer than requested.
                logger.warning(
                    f"Layer {layer} has only {len(layer_directions)} directions "
                    f"but {component_index} were requested. "
                    "Projecting onto all available directions."
                )

            H_transformed = project_onto_direction(
                layer_hidden_states, layer_directions[:component_index].T, device
            )
            transformed_hidden_states[layer] = H_transformed.cpu().numpy()

        return transformed_hidden_states

    def save_rep_reader_info(self, path: str) -> RepReaderInfo:
        """Save the RepReader's directions, hidden state means, and number of components.

        Raises:
            ValueError: If directions, means or the component count are missing.
        """
        if (
            not self.directions
            or not getattr(self, "H_train_means", None)
            or getattr(self, "n_components", None) is None
        ):
            raise ValueError(
                "RepReader has not been fitted yet. Call get_rep_directions first."
            )

        data = RepReaderInfo(
            directions=self.directions,
            attention_directions=self.attention_directions,
            H_train_means=self.H_train_means,
            n_components=self.n_components,
        )
        data.save(path)
        return data


class PCARepReader(RepReader):
    """Automated PCA RepReader that uses PCA to find directions for each layer.

    Contrary to the original PCARepReader from the Representation Engineering
    paper, this one does not require the number of components to be specified.
    For each layer it keeps the smallest number of leading components whose
    cumulative explained variance ratio reaches ``explained_variance_threshold``,
    optionally capped at ``max_components``.

    NOTE: ``self.n_components`` holds the count chosen for the *last* fitted
    layer; earlier layers may have kept a different number of directions.
    ``get_rep_directions`` returns the directions but does not store them on
    ``self.directions``.
    """

    def __init__(
        self,
        explained_variance_threshold=0.90,
        max_components: Optional[int] = None,
        input_source: InputSourceType = InputSourceType.HIDDEN_STATES,
    ):
        """
        Args:
            explained_variance_threshold: Target cumulative explained variance
                ratio per layer. If it is never reached (e.g. 1.0 with floating
                point round-off), all components are kept.
            max_components: Optional upper bound on the components kept per layer.
            input_source: Selects the fitting routine used by
                ``get_rep_directions``; defaults to hidden states as before.
        """
        super().__init__(input_source=input_source)
        self.explained_variance_threshold = explained_variance_threshold
        self.n_components = None  # Will be set during PCA fitting
        self.H_train_means: Dict[int, Float[torch.Tensor, "1 hidden_size"]] = {}
        self.max_components = max_components

    def get_rep_directions(self, hidden_states, hidden_layers, **kwargs):
        """Dispatch to the attention or hidden-state fit based on ``input_source``.

        Raises:
            ValueError: If ``self.input_source`` is not a supported type.
        """
        if self.input_source == InputSourceType.ATTENTION_VALUES:
            return self.get_rep_attention_directions(
                hidden_states, hidden_layers, **kwargs
            )
        elif self.input_source == InputSourceType.HIDDEN_STATES:
            return self.get_rep_layer_directions(hidden_states, hidden_layers, **kwargs)
        raise ValueError(
            f"Unsupported input source type: {self.input_source}. "
            "Supported types are: HIDDEN_STATES, ATTENTION_VALUES."
        )

    def get_rep_attention_directions(
        self, attention_values, attention_layers, **kwargs
    ):
        """Get concept directions for attention values of the model on the training data.

        Args:
            attention_values: Mapping layer -> attention values for that layer.
            attention_layers: Layers to fit.
            **kwargs: ``random_state`` (default 42) is forwarded to PCA.

        Returns:
            Dict mapping each layer to its kept PCA components.
        """
        random_state = kwargs.get("random_state", 42)
        directions = {}
        for layer in attention_layers:
            directions[layer] = self._fit_layer_directions(
                attention_values[layer],
                layer,
                "Performing Automated PCA on attention values",
                random_state,
            )
        return directions

    def get_rep_layer_directions(self, hidden_states, hidden_layers, **kwargs):
        """Get concept directions for hidden states of the model on the training data.

        Args:
            hidden_states: Mapping layer -> (n_examples, hidden_size) activations.
            hidden_layers: Layers to fit.
            **kwargs: ``random_state`` (default 42) is forwarded to PCA.

        Returns:
            Dict mapping each layer to its kept PCA components.
        """
        random_state = kwargs.get("random_state", 42)
        directions = {}
        for layer in hidden_layers:
            directions[layer] = self._fit_layer_directions(
                hidden_states[layer], layer, "Performing Automated PCA", random_state
            )
        return directions

    def _fit_layer_directions(self, train, layer, log_message, random_state):
        """Center one layer's data, fit PCA and return the components to keep.

        Side effects: stores the layer mean in ``self.H_train_means[layer]`` and
        the kept count in ``self.n_components``.
        """
        train_mean = train.mean(axis=0, keepdims=True)
        self.H_train_means[layer] = train_mean
        centered = np.vstack(recenter(train, mean=train_mean).cpu())

        pca_model = PCA(
            whiten=False,
            svd_solver="full",
            random_state=random_state,
        ).fit(centered)

        n_components = self._n_components_for_threshold(
            pca_model.explained_variance_ratio_, self.explained_variance_threshold
        )
        if self.max_components is not None and n_components > self.max_components:
            logger.warning(
                f"Number of components {n_components} exceeds max_components {self.max_components}. "
                "Limiting to max_components."
            )
            n_components = self.max_components
        self.n_components = n_components

        logger.info(
            log_message,
            layer=layer,
            n_components=n_components,
        )
        return pca_model.components_[:n_components]

    @staticmethod
    def _n_components_for_threshold(explained_variance_ratio, threshold) -> int:
        """Return the smallest k whose first k ratios sum to at least ``threshold``.

        Falls back to keeping every component when the threshold is never
        reached; ``np.argmax`` on an all-False mask would return 0 and keep
        only a single component.
        """
        cumulative_variance = np.cumsum(explained_variance_ratio)
        reached = np.flatnonzero(cumulative_variance >= threshold)
        if reached.size == 0:
            return int(cumulative_variance.size)
        return int(reached[0]) + 1


if __name__ == "__main__":
    # Example usage: fit PCA directions on random hidden states, project them,
    # round-trip the fitted state through RepReaderInfo, then repeat the fit on
    # GPT-2 attention weights.
    torch.manual_seed(0)  # make the random example reproducible
    pca_reader = PCARepReader(explained_variance_threshold=0.90, max_components=10)

    hidden_states = {
        0: torch.randn(100, 768),  # Layer 0 Shape: (100 examples, hidden_size)
        1: torch.randn(100, 768),  # Layer 1 Shape: (100 examples, hidden_size)
    }

    hidden_layers = [0, 1]

    # Get PCA directions. get_rep_directions only returns them, so store them
    # on the reader before transforming.
    directions = pca_reader.get_rep_directions(hidden_states, hidden_layers)
    pca_reader.directions = directions
    # Transform hidden states using the PCA directions
    transformed_states = pca_reader.transform(
        hidden_states, [0], component_index=4, device="cpu"
    )
    logger.info("Transformed hidden states", transformed_states=transformed_states)
    logger.info("Storing RepReader information")
    rep_reader_info_path = "rep_reader_info.pth"
    pca_reader.save_rep_reader_info(rep_reader_info_path)
    try:
        # Load the saved RepReader information
        loaded_data = RepReaderInfo.load(rep_reader_info_path)

        logger.info("Loaded RepReader information", loaded_data=loaded_data)

        pca_reader.transform_with_custom_means(
            hidden_states,
            [0],
            component_index=4,
            device="cpu",
            H_train_means=loaded_data.H_train_means,
        )
    finally:
        # Remove the saved file even if loading or transforming failed.
        os.remove(rep_reader_info_path)
        logger.info("Temporary file removed")

    # How to access Attention Values
    from transformers import AutoModel, AutoTokenizer

    model_name = "gpt2"
    model = AutoModel.from_pretrained(model_name, output_attentions=True)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token  # Set pad_token to eos_token

    input_text = ["This is a test input for the model.", "Another input for attention."]
    inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)
    # Inference only: no need to build an autograd graph.
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)

    per_layer_attentions: Tuple[
        Float[torch.Tensor, "batch_size num_heads seq_length seq_length"], ...
    ] = outputs.attentions  # This will be a tuple of attention tensors per layer
    # Example: first layer (index 0), head 3 -> (batch_size, seq_length, seq_length)
    attention_layer0_head3 = (
        per_layer_attentions[0][:, 3, :, :].detach().cpu().numpy()
    )  # Convert to numpy for PCA
    attention_values = {
        0: attention_layer0_head3
    }  # Simulating a single layer with attention values
    pca_reader = PCARepReader(explained_variance_threshold=0.90, max_components=10)
    directions = pca_reader.get_rep_directions(attention_values, [0])
    pca_reader.directions = directions
    # The attention matrices here are small, so the fit may keep fewer than
    # five components; clamp so that component_index < n_components holds.
    attention_component_index = min(4, pca_reader.n_components - 1)
    transformed_attention = pca_reader.transform(
        attention_values,
        [0],
        component_index=attention_component_index,
        device="cpu",
    )
    logger.info(
        "Transformed attention values", transformed_attention=transformed_attention
    )
