# Inspired by https://github.com/andyzoujm/representation-engineering/blob/main/
import os
from typing import List, Optional, Union

import numpy as np
import torch
from jaxtyping import Float

from code_demeanor.logger import logger
from code_demeanor.reading.readers import (
    PCARepReader,
    RepReader,
    RepReaderType,
    project_onto_direction,
)
from code_demeanor.utils import get_device, send_to_device, set_seed
from transformers import AutoModel, AutoTokenizer, Pipeline, pipeline
from transformers.pipelines import PIPELINE_REGISTRY


def direction_finder_factory(direction_method: str, **kwargs) -> RepReader:
    """Construct the RepReader implementation selected by ``direction_method``.

    Args:
        direction_method: value compared against ``RepReaderType`` members; only
            ``RepReaderType.PCA`` is currently recognised.
        **kwargs: forwarded unchanged to the chosen RepReader's constructor.

    Returns:
        A freshly constructed RepReader instance.

    Raises:
        ValueError: if ``direction_method`` does not name a supported method.
    """
    if direction_method == RepReaderType.PCA:
        reader_cls = PCARepReader
    else:
        raise ValueError(f"Unknown direction method: {direction_method}")
    return reader_cls(**kwargs)


class RepReadingPipeline(Pipeline):
    """Pipeline that reads per-layer hidden states ("representations") from a model.

    A plain call returns, per input batch, a dict mapping each requested layer index
    to the hidden state at position ``rep_token``. When a ``rep_reader`` is passed,
    those hidden states are fed through ``rep_reader.transform`` instead.
    ``get_directions`` uses this pipeline to fit a ``RepReader`` on training inputs.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _get_hidden_states(
        self,
        outputs,
        rep_token: Union[str, int] = -1,
        hidden_layers: Union[List[int], int] = -1,
        which_hidden_states: Optional[str] = None,
    ):
        """Select the hidden state at position ``rep_token`` for each requested layer.

        Args:
            outputs: model output produced with ``output_hidden_states=True``.
            rep_token: sequence position to read (indexes dimension 1 of each
                layer's hidden-state tensor).
            hidden_layers: a single layer index or a list of layer indices.
            which_hidden_states: for encoder-decoder outputs, ``"encoder"`` or
                ``"decoder"`` -- selects which stack of hidden states to read.

        Returns:
            dict mapping each layer index to a detached tensor (``(batch, hidden)``
            when ``rep_token`` is an int). bfloat16 tensors are upcast to float32.

        Raises:
            ValueError: if the output is encoder-decoder and ``which_hidden_states``
                is not ``"encoder"`` or ``"decoder"``.
        """
        if not isinstance(hidden_layers, list):
            hidden_layers = [hidden_layers]

        if hasattr(outputs, "encoder_hidden_states") and hasattr(
            outputs, "decoder_hidden_states"
        ):
            # Encoder-decoder models expose two stacks of hidden states; the caller
            # must say which one to read.
            if which_hidden_states not in ("encoder", "decoder"):
                raise ValueError(
                    "which_hidden_states must be 'encoder' or 'decoder' for "
                    f"encoder-decoder models, got {which_hidden_states!r}"
                )
            all_hidden_states = outputs[f"{which_hidden_states}_hidden_states"]
        else:
            all_hidden_states = outputs["hidden_states"]

        hidden_states_layers = {}
        for layer in hidden_layers:
            layer_states = all_hidden_states[layer][:, rep_token, :].detach()
            # numpy has no bfloat16 dtype; upcast so downstream np.vstack works.
            if layer_states.dtype == torch.bfloat16:
                layer_states = layer_states.float()
            hidden_states_layers[layer] = layer_states

        return hidden_states_layers

    def _sanitize_parameters(
        self,
        rep_reader: RepReader = None,
        rep_token: Union[str, int] = -1,
        hidden_layers: Union[List[int], int] = -1,
        component_index: int = 0,
        which_hidden_states: Optional[str] = None,
        **tokenizer_kwargs,
    ):
        """Split call-time kwargs into preprocess / forward / postprocess params.

        Any kwarg not named here is treated as a tokenizer kwarg and forwarded to
        ``preprocess``. ``hidden_layers`` is normalised to a list, and when a
        ``rep_reader`` is given it must carry one direction per hidden layer.
        """
        preprocess_params = tokenizer_kwargs
        forward_params = {}
        postprocess_params = {}

        forward_params["rep_token"] = rep_token

        if not isinstance(hidden_layers, list):
            hidden_layers = [hidden_layers]

        assert rep_reader is None or len(rep_reader.directions) == len(
            hidden_layers
        ), f"expect total rep_reader directions ({len(rep_reader.directions)})== total hidden_layers ({len(hidden_layers)})"
        forward_params["rep_reader"] = rep_reader
        forward_params["hidden_layers"] = hidden_layers
        forward_params["component_index"] = component_index
        forward_params["which_hidden_states"] = which_hidden_states

        return preprocess_params, forward_params, postprocess_params

    def preprocess(
        self, inputs: Union[str, List[str], List[List[str]]], **tokenizer_kwargs
    ):
        """Turn raw inputs into model inputs.

        Uses the image processor when one is configured; otherwise tokenizes with
        ``return_tensors=self.framework`` plus any extra tokenizer kwargs.
        """
        if self.image_processor:
            return self.image_processor(
                inputs, add_end_of_utterance_token=False, return_tensors="pt"
            )
        return self.tokenizer(inputs, return_tensors=self.framework, **tokenizer_kwargs)

    def postprocess(self, outputs):
        """Return ``_forward`` results unchanged."""
        return outputs

    def _forward(
        self,
        model_inputs,
        rep_token,
        hidden_layers,
        rep_reader: RepReader = None,
        component_index=0,
        which_hidden_states=None,
        pad_token_id=None,
    ):
        """Run the model and return hidden states, optionally transformed by a RepReader.

        Args:
        - rep_token: sequence position whose hidden state is read.
        - hidden_layers (list): layer indices to read.
        - rep_reader (RepReader, optional): if given, its ``transform`` is applied to
                        the hidden states and that result is returned instead.
        - component_index (int): forwarded to ``rep_reader.transform``.
        - which_hidden_states (str): Specifies which part of the model (encoder, decoder, or both) to compute the hidden states from.
                        It's applicable only for encoder-decoder models. Valid values: 'encoder', 'decoder'.
        - pad_token_id: accepted for interface compatibility; currently unused.
        """
        with torch.no_grad():
            if hasattr(self.model, "encoder") and hasattr(self.model, "decoder"):
                # Encoder-decoder models need decoder inputs; feed one pad-token
                # prompt per batch row.
                decoder_start_token = [self.tokenizer.pad_token] * model_inputs[
                    "input_ids"
                ].size(0)
                decoder_input = self.tokenizer(
                    decoder_start_token, return_tensors="pt"
                ).input_ids
                # model_inputs were already moved to the pipeline's device before
                # _forward ran; the freshly tokenized tensor is on CPU and must follow.
                model_inputs["decoder_input_ids"] = decoder_input.to(
                    model_inputs["input_ids"].device
                )
            outputs = self.model(**model_inputs, output_hidden_states=True)
        hidden_states = self._get_hidden_states(
            outputs, rep_token, hidden_layers, which_hidden_states
        )

        if rep_reader is None:
            return hidden_states

        return rep_reader.transform(hidden_states, hidden_layers, component_index)

    def _batched_string_to_hiddens(
        self,
        train_inputs,
        rep_token,
        hidden_layers,
        batch_size,
        which_hidden_states,
        **tokenizer_args,
    ):
        """Run the pipeline over ``train_inputs`` and stack hidden states per layer.

        Returns:
            dict mapping each layer in ``hidden_layers`` to ``np.vstack`` of all rows
            collected for that layer across the pipeline's outputs.
        """
        hidden_states_outputs = self(
            train_inputs,
            rep_token=rep_token,
            hidden_layers=hidden_layers,
            batch_size=batch_size,
            rep_reader=None,
            which_hidden_states=which_hidden_states,
            **tokenizer_args,
        )
        hidden_states = {layer: [] for layer in hidden_layers}
        for hidden_states_batch in hidden_states_outputs:
            for layer in hidden_states_batch:
                hidden_states[layer].extend(hidden_states_batch[layer])
        return {k: np.vstack(v) for k, v in hidden_states.items()}

    def _validate_params(self, n_difference: int, direction_method: RepReaderType):
        """Check parameter combinations that ``get_directions`` does not support."""
        if direction_method == RepReaderType.CLUSTER_MEAN:
            assert n_difference == 1, "n_difference must be 1 for clustermean"

    def get_directions(
        self,
        train_inputs: Union[str, List[str], List[List[str]]],
        rep_token: Union[str, int] = -1,
        hidden_layers: Union[List[int], int] = -1,
        n_difference: int = 1,
        batch_size: int = 8,
        train_labels: List[int] = None,
        direction_method: str = RepReaderType.PCA,
        direction_finder_kwargs: Optional[dict] = None,
        which_hidden_states: Optional[str] = None,
        **tokenizer_args,
    ):
        """Train a RepReader on the training data.

        Args:
            train_inputs: inputs whose hidden states are collected.
            rep_token: sequence position whose hidden state is read.
            hidden_layers: a layer index or list of layer indices to fit directions for.
            n_difference: validated via ``_validate_params``; pairwise differencing
                of hidden states is currently not applied.
            batch_size: batch size to use when getting hidden states.
            train_labels: forwarded to ``get_rep_directions`` as ``train_choices``.
                Normally a list of ints; a list of 2-element lists is treated as
                per-pair labels (a warning is logged).
            direction_method: string specifying the RepReader strategy for finding directions.
            direction_finder_kwargs: kwargs to pass to RepReader constructor.
            which_hidden_states: ``"encoder"``/``"decoder"`` for encoder-decoder models.
            **tokenizer_args: forwarded to the tokenizer.

        Returns:
            The fitted RepReader, with any numpy-array directions cast to float32.
        """
        if direction_finder_kwargs is None:
            direction_finder_kwargs = {}

        if not isinstance(hidden_layers, list):
            assert isinstance(hidden_layers, int)
            hidden_layers = [hidden_layers]
        # Guard against an empty list before peeking at the first label.
        if (
            isinstance(train_labels, list)
            and train_labels
            and isinstance(train_labels[0], list)
        ):
            logger.warning(
                """We are expecting train_labels to be a list of integers, not a list of lists.
                           Otherwise, we assume that each entry in train_labels corresponds to a pair of inputs.
                           [[0, 1], [1, 0], ...]
                           """
            )
            assert (
                len(train_inputs) % 2 == 0
            ), "train_inputs must be even if train_labels is a list of lists"
            assert (
                len(train_labels[0]) == 2
            ), "train_labels must be a list of lists with length 2"
            # NOTE(review): this fires for any n_difference != 1 (including values
            # > 1), which contradicts the message text; confirm the intended rule.
            if n_difference != 1:
                logger.warning(
                    "n_difference should probably be more than 1 if train_labels is a list of lists"
                )

        self._validate_params(n_difference, direction_method)

        direction_finder: RepReader = direction_finder_factory(
            direction_method, **direction_finder_kwargs
        )

        # Raw hidden states for every training input, keyed by layer. Pairwise
        # differencing (n_difference) is currently disabled, so these are passed
        # to the RepReader as-is.
        hidden_states = self._batched_string_to_hiddens(
            train_inputs,
            rep_token,
            hidden_layers,
            batch_size,
            which_hidden_states,
            **tokenizer_args,
        )

        direction_finder.directions = direction_finder.get_rep_directions(
            hidden_states,
            hidden_layers,
            train_choices=train_labels,
        )
        # Normalise numpy directions to float32 (values are replaced, keys unchanged).
        for layer, direction in direction_finder.directions.items():
            if isinstance(direction, np.ndarray):
                direction_finder.directions[layer] = direction.astype(np.float32)

        return direction_finder


def repe_pipeline_registry():
    """Register ``RepReadingPipeline`` with transformers under ``"rep-reading"``.

    The pipeline is registered with ``AutoModel`` as its PyTorch model class, so
    ``pipeline("rep-reading", ...)`` can be used after this call.
    """
    task_name = "rep-reading"
    PIPELINE_REGISTRY.register_pipeline(
        task_name,
        pipeline_class=RepReadingPipeline,
        pt_model=AutoModel,
    )


if __name__ == "__main__":
    # Example usage: register the custom task, then build a GPT-2 rep-reading pipeline.
    set_seed(42)
    repe_pipeline_registry()
    logger.info("RepReadingPipeline registered successfully.")
    DEVICE = get_device()
    model = AutoModel.from_pretrained("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # GPT-2 ships without a pad token, which batched pipeline calls need
    # (get_directions defaults to batch_size=8). Left padding keeps the real last
    # token at position -1, the default rep_token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    rep_reading_pipeline = pipeline(
        "rep-reading", model=model, tokenizer=tokenizer, device=DEVICE
    )
