import copy
import itertools
import logging
from collections import defaultdict
from typing import Any, Callable, Dict, List

import numpy as np
from structured_llmuq.utils import partitioned_iterator
import torch
import torch.nn.functional as F
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    StoppingCriteriaList,
)

from structured_llmuq.utils.postprocessing import StoppingCriteriaSub, truncate_output

# Checkpoints treated as chat/instruction-tuned models: CausalLM wraps prompts
# for these in the tokenizer's chat template (system + user message) before
# generating.
# NOTE(review): "gpt-4.1-mini-2025-04-14" does not look like a Hugging Face Hub
# checkpoint, so CausalLM presumably cannot load it; confirm whether it is listed
# here for use elsewhere.
INSTRUCT_MODELS = [
    "google/gemma-3-1b-it",
    "google/gemma-3-4b-it",
    "google/gemma-3-12b-it",
    "google/gemma-3-27b-it",
    "meta-llama/Llama-3.1-8B-Instruct",
    "Qwen/Qwen2.5-14B-Instruct",
    "gpt-4.1-mini-2025-04-14",
    "Qwen/Qwen3-4B-Instruct-2507",
    "Qwen/Qwen3-30B-A3B-Instruct-2507",
    "mistralai/Ministral-3-8B-Instruct-2512"
]

# Pretrained (non-instruction-tuned) checkpoints: CausalLM tokenizes prompts for
# these as raw text, without a chat template.
BASE_MODELS = [
    "google/gemma-3-1b-pt",
    "google/gemma-3-4b-pt",
    "google/gemma-3-12b-pt",
    "google/gemma-3-27b-pt",
    "meta-llama/Llama-3.1-8B",
    "Qwen/Qwen2.5-14B",
]


class CausalLM:
    def __init__(
        self,
        model_name: str,
        device_map: str,
        token: str,
        generation_config: Dict,
        activation_hooks: List[str] = [],
        attn_implementation: str = None,
    ):
        """
        Wrapper around the Huggingface AutoModelForCausalLM and AutoTokenizer classes for causal language modeling.
        Args:
            model_name (str): huggingface model name
            device_map (str): cpu or cuda
            token (str): access token to hugingface hub
            generation_config (Dict): generation config for the model
            activation_hooks (List[str], optional): List of activation hooks to be added. Defaults to [].
            attn_implementation (str, optional): The attention implementation to use EAGER / SDPA. Defaults to None.
        """

        load_in_8bit = False
        if model_name == "meta-llama/Llama-3.1-70B":
            load_in_8bit = True

        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype="auto",
            trust_remote_code=True,
            token=token,
            device_map=device_map,
            attn_implementation=attn_implementation,
            load_in_8bit=load_in_8bit,
        )

        self.model_type = None
        if model_name in INSTRUCT_MODELS:
            self.model_type = "instruct"
        elif model_name in BASE_MODELS:
            self.model_type = "base"
        else:
            raise ValueError(
                f"Model {model_name} not supported. Please use one of the following models: {INSTRUCT_MODELS + BASE_MODELS}"
            )

        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=True, token=token
        )
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.generation_config = self.model.generation_config
        for key, value in generation_config.items():
            if hasattr(self.generation_config, key):
                # check if value is none
                if value is None:
                    logging.warning(
                        f"Generation config {key} is None. Using default value instead {self.generation_config.__dict__[key]}."
                    )
                    continue

                setattr(self.generation_config, key, value)
            else:
                logging.warning(
                    f"Generation config {key} not found in model's generation config. Skipping."
                )
        self.generation_config.pad_token_id = self.tokenizer.eos_token_id
        self.generation_config.eos_token_id = self.tokenizer.eos_token_id

        self.activation_hooks = activation_hooks
        # Since they have vision their layer structure is different
        special_models = [
            "google/gemma-3-4b-pt",
            "google/gemma-3-12b-pt",
            "google/gemma-3-27b-pt",
            "google/gemma-3-4b-it",
            "google/gemma-3-12b-it",
            "google/gemma-3-27b-it",
        ]

        # Add hooks for activations if needed
        self.activations = defaultdict(dict)
        if activation_hooks is not None:
            layers = (
                self.model.model.language_model.layers
                if model_name in special_models
                else self.model.model.layers
            )
            for hook in activation_hooks:  # TODO: currently implemented such that in both the sampling and beam search generation hooks are activated (not ideal because it is redundant)
                if hook == "mlp":
                    logging.info(f"Adding MLP hooks")
                    # Register hooks for layers
                    # handles = []
                    for layer_idx in range(len(layers)):
                        handle = layers[layer_idx].mlp.register_forward_hook(
                            create_mlp_hook(self.activations, layer_idx)
                        )
                        # handles.append(handle)
                if hook == "mhsa":
                    logging.info(f"Adding MHSA hooks")
                    # Register hooks for layers
                    # handles = []
                    for layer_idx in range(len(layers)):
                        att = False
                        # check if attention hook also needed
                        if "attn" in activation_hooks:
                            logging.info(f"Adding Attention hooks")
                            # ensure attention implementation is eager https://github.com/huggingface/transformers/issues/33858
                            assert attn_implementation == "eager", (
                                "Attention hook requires eager attention implementation"
                            )
                            assert self.generation_config.output_attentions == True, (
                                "Attention hook requires output_attentions=True in generation_config"
                            )
                            att = True

                        handle = layers[layer_idx].self_attn.register_forward_hook(
                            create_mhsa_hook(self.activations, layer_idx, att)
                        )
                        # handles.append(handle)

                if hook == "ress":
                    logging.info(f"Adding Residual Stream hooks")
                    # Register hooks for layers
                    # handles = []
                    for layer_idx in range(len(layers)):
                        handle = layers[layer_idx].register_forward_hook(
                            create_residual_stream_hook(self.activations, layer_idx)
                        )
                        # handles.append(handle)

    def generate(
        self,
        prompt: str,
        stop_sequences: list[str] = [],
        max_num_parallel_generations: int | None = None,
    ) -> Dict[str, Any]:
        """
        Generates text from the model given a prompt and returns the log likelihood of the generated text.
        """

        if self.model.name_or_path in INSTRUCT_MODELS:
            messages = [
                {
                    "role": "system",
                    "content": "",
                },
                {"role": "user", "content": prompt},
            ]
            inputs = self.tokenizer.apply_chat_template(
                messages,
                tokenize=True,
                add_generation_prompt=True,
                return_tensors="pt",
                return_dict=True,
                enable_thinking=False,  # Switches between thinking and non-thinking modes. Default is True.
            ).to(self.model.device)

        elif self.model.name_or_path in BASE_MODELS:
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",  # , return_attention_mask=True
            ).to(self.model.device)
        else:
            raise ValueError(
                f"Model {self.model.name_or_path} not supported. Please use one of the following models: {INSTRUCT_MODELS + BASE_MODELS}"
            )

        cpu_rng_state = torch.get_rng_state()  # save states
        gpu_rng_state = torch.cuda.get_rng_state_all()

        stopping_criteria = StoppingCriteriaList(
            [
                StoppingCriteriaSub(
                    stops=stop_sequences,
                    initial_length=len(inputs["input_ids"][0]),
                    match_on="text",
                    tokenizer=self.tokenizer,
                )
            ]
        )
        generation_config_dict = self.generation_config.to_dict()
        
        max_num_parallel_generations = max_num_parallel_generations or generation_config_dict["num_return_sequences"]

        result_dicts = []
        for num_return_sequences in partitioned_iterator(generation_config_dict["num_return_sequences"], max_num_parallel_generations):
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    **generation_config_dict
                    | dict(num_return_sequences=num_return_sequences),
                    stopping_criteria=stopping_criteria,
                )  # returns DecoderOutput: transformers.generation.GenerateDecoderOnlyOutput
            # logging.info(f"Multinomial Sampling done")
            tokens = (
                outputs.sequences.detach()
            )  # Shape: (batch_size, generated_length), torch.LongTensor
            scores = outputs.scores  # tuple(torch.FloatTensor) with length max_new_tokens, shape: (batch_size, vocab_size)

            beam_indices = None
            if self.generation_config.num_beams > 1:
                beam_indices = (
                    outputs.beam_indices
                )  # Shape: (batch_size, generated_length), torch.LongTensor

            log_likelihood = (
                self.model.compute_transition_scores(
                    sequences=tokens,
                    scores=scores,
                    normalize_logits=True,
                    beam_indices=beam_indices,
                )
                .cpu()
                .numpy()
            )
            tokens = [t[len(inputs["input_ids"][0]) :].cpu() for t in tokens]
            tokens_decoded_generated = self.tokenizer.batch_decode(
                tokens, skip_special_tokens=True
            )

            # Truncate Output with stop criteria if stop sqeuences provided
            if stop_sequences is None:
                log_likelihood_truncated = log_likelihood
                tokens_decoded_generated_truncated = tokens_decoded_generated
                tokens_truncated = tokens
            else:
                log_likelihood_truncated = []
                tokens_decoded_generated_truncated = []
                tokens_truncated = []
                for t, l in zip(tokens_decoded_generated, log_likelihood):
                    out = truncate_output(t, l, self.tokenizer, stop_sequences)
                    tokens_decoded_generated_truncated.append(out["text"])
                    log_likelihood_truncated.append(out["token_scores"])
                    tokens_truncated.append(out["text_encoded"])

            result_dict = {
                "log_likelihood": log_likelihood,
                "log_likelihood_truncated": log_likelihood_truncated,
                "tokens": tokens,
                "tokens_truncated": tokens_truncated,
                "tokens_decoded_generated": tokens_decoded_generated,
                "tokens_decoded_generated_truncated": tokens_decoded_generated_truncated,
            }

            if self.generation_config.output_hidden_states is True:
                result_dict["hidden_states"] = outputs.hidden_states

            if self.activation_hooks:
                result_dict["activations"] = copy.deepcopy(self.activations)
            
            result_dicts.append(result_dict)

        torch.set_rng_state(cpu_rng_state)  # restore states
        torch.cuda.set_rng_state_all(gpu_rng_state)

        # Collate results
        result =  {
            "log_likelihood": [
                d["log_likelihood"] for d in result_dicts
            ],
            "log_likelihood_truncated" : sum((d["log_likelihood_truncated"] for d in result_dicts), []),
            "tokens": list(
                itertools.chain.from_iterable(d["tokens"] for d in result_dicts)
            ),
            "tokens_truncated": list(
                itertools.chain.from_iterable(d["tokens_truncated"] for d in result_dicts)
            ),
            "tokens_decoded_generated": list(
                itertools.chain.from_iterable(
                    d["tokens_decoded_generated"] for d in result_dicts
                )
            ),
            "tokens_decoded_generated_truncated": list(
                itertools.chain.from_iterable(
                    d["tokens_decoded_generated_truncated"] for d in result_dicts
                )
            ),  
        }
        return result
        
        

    def forward(
        self, prompt: str, completion: str, return_dist: bool = False
    ) -> Dict[str, Any]:
        """
        Given a prompt and a completion, calculates the logit & probs of completion and of first token of completion

        prompt: Question
        completion: completion to which we want to measure the probability
        """

        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",  # , return_attention_mask=True
        ).to(self.model.device)
        input_ids = inputs["input_ids"].to(self.model.device)

        encoding_kwargs = {}

        completion_tokens = self.tokenizer.encode(
            completion, return_tensors="pt", add_special_tokens=False,
        ).to(self.model.device)

        concat_input_ids = torch.cat([input_ids, completion_tokens], dim=1)

        with torch.no_grad():
            outputs = self.model(
                input_ids=concat_input_ids,
            )

        logits = outputs.logits
        question_len = input_ids.shape[1]
        completion_logits = logits[:, question_len - 1 : -1, :]

        # Compute softmax to get probabilities.
        completion_probs_tensor = F.softmax(completion_logits, dim=-1)
        first_token_dist = completion_probs_tensor[:, 0].clone().detach()
        # Unsqueeze decoy_tokens to match dimensions: from [1, decoy_length] to [1, decoy_length, 1]
        completion_tokens_expanded = completion_tokens.unsqueeze(-1)
        # Gather the probability assigned to the actual decoy tokens.
        token_probs = completion_probs_tensor.gather(
            dim=-1, index=completion_tokens_expanded
        )
        token_probs = token_probs.squeeze(-1)  # Now shape: [1, decoy_length]
        # Compute the overall probability as the product of probabilities in log space.
        sentence_prob = token_probs.log().mean().exp().cpu().item()
        token_logit = (
            completion_logits[:, 0, completion_tokens[0, 0]]
            .clone()
            .detach()
            .cpu()
            .item()
        )  # First token logit
        token_prob = (
            completion_probs_tensor[:, 0, completion_tokens[0, 0]]
            .clone()
            .detach()
            .cpu()
            .item()
        )  # First prob logit

        result_dict = {
            "sentence_prob": sentence_prob,
            "token_prob": token_prob,
            "token_logit": token_logit,
            "first_token_dist": first_token_dist.cpu() if return_dist else None,
        }

        if self.generation_config.output_hidden_states == True:
            result_dict["hidden_states"] = outputs.hidden_states

        if self.activation_hooks:
            result_dict["activations"] = copy.deepcopy(self.activations)

        return result_dict


def create_pre_projection_hook(
    activations: Dict[str, Any], layer_idx: int, last_token: int
) -> Callable:
    """
    Build a forward hook that records the *input* fed into a module.

    From the module's first positional input (assumed (batch, seq, dim) —
    TODO confirm), the first ``last_token`` positions of the first batch
    element are copied to CPU and stored in
    ``activations["v_projection"][layer_idx]``. The module output is returned
    unchanged.

    Args:
      activations: store the hook writes into; ``activations["v_projection"]``
        must exist or be auto-created (e.g. a ``defaultdict(dict)``).
      layer_idx: key under which this layer's input is stored.
      last_token: number of leading sequence positions to keep.
    """

    def _record_input(module, inputs, output):
        first_input = inputs[0]
        prefix = first_input[0, :last_token, :]
        activations["v_projection"][layer_idx] = prefix.detach().cpu().clone()
        return output

    return _record_input


def create_attention_flow_hook(
    activations: Dict[str, Any], layer_idx: int, last_token: int
) -> Callable:
    """
    Build a forward hook that records how one query position attends over the prefix.

    The module output must unpack into exactly two values
    ``(hidden_states, attention_weights)``. For the first batch element the row
    ``[:, last_token - 1, :last_token]`` of the weights is stored in
    ``activations["attn_flow"][layer_idx]``. This is presumably per head, assuming
    weights shaped (batch, heads, query, key) — confirm. The module output is
    returned unchanged.

    Args:
      activations: store the hook writes into.
      layer_idx: key under which this layer's attention row is stored.
      last_token: the query position is ``last_token - 1``; keys ``:last_token`` are kept.
    """
    query_position = last_token - 1

    def _record_flow(module, inputs, output):
        _hidden, weights = output
        row = weights[0][:, query_position, :last_token]
        activations["attn_flow"][layer_idx] = row.detach().cpu().clone()
        return output

    return _record_flow


def create_mlp_hook(activations: Dict[str, Any], layer_idx: int) -> Callable:
    """
    Build a forward hook for the MLP of a transformer layer.

    Only calls whose output has more than one sequence position (the prompt
    pass rather than a single-token decoding step) are recorded. The vector at
    the last position of the first batch element is stored in
    ``activations["mlp"][layer_idx]``. The hook returns None, so the module
    output is left untouched.
    """

    def _save_last_mlp_output(module, inputs, output):
        if output.shape[1] <= 1:
            return  # single-position call: keep the prompt-pass activation
        last_position = output[0, -1, :]  # (hidden_size,), assuming (batch, seq, hidden) output
        activations["mlp"][layer_idx] = last_position.detach().cpu().clone()

    return _save_last_mlp_output


def create_mhsa_hook(
    activations: Dict[str, Any], layer_idx: int, att: bool
) -> Callable:
    """
    Build a forward hook for the multi-head self-attention module of a transformer layer.

    The module output must unpack into exactly ``(hidden_states, attention_weights)``.
    Only calls with more than one sequence position (the prompt pass, not
    single-token decoding steps) are recorded:

    - ``activations["mhsa"][layer_idx]``: attention output at the last position of
      the first batch element, shape (hidden_size,);
    - ``activations["attn"][layer_idx]`` (only when ``att`` is True): attention weights
      of the first batch element, presumably (num_heads, seq_len, seq_len). They must
      not be None, which requires eager attention.

    Stored tensors are detached, moved to CPU and cloned, so they neither alias nor
    keep alive the module's full output tensors. The hook returns None, so the
    module output is unchanged.
    """

    def mhsa_hook(module, inputs, output):
        hidden_states, self_attn_weights = output
        if hidden_states.shape[1] <= 1:
            return  # single-token decoding step: keep the prompt-pass activations
        # Clone like the sibling hooks: on CPU ``.cpu()`` is a no-op, so without the
        # clone the stored slice would be a view pinning the whole output tensor.
        activations["mhsa"][layer_idx] = hidden_states[0, -1, :].detach().cpu().clone()
        if att:
            activations["attn"][layer_idx] = self_attn_weights[0].detach().cpu().clone()

    return mhsa_hook


def create_residual_stream_hook(
    activations: Dict[str, Any], layer_idx: int
) -> Callable:
    """
    Build a forward hook for the residual stream, i.e. the output of a whole decoder layer.

    Only calls with more than one sequence position (the prompt pass, not
    single-token decoding steps) are recorded. The hidden state at the last
    position of the first batch element, shape (hidden_size,), is stored in
    ``activations["ress"][layer_idx]``. The hook returns None, so the layer
    output is unchanged.
    """

    def residual_stream_hook(module, inputs, output):
        # Accept both a tuple ``(hidden_states, ...)`` and a bare tensor as the
        # layer output. NOTE(review): which one a layer returns depends on the
        # transformers version / model — confirm for the pinned version. Unpacking
        # a bare tensor with ``x, *_ = output`` would silently split it along the
        # batch dimension instead.
        hidden_states = output[0] if isinstance(output, (tuple, list)) else output
        if hidden_states.shape[1] > 1:
            activations["ress"][layer_idx] = (
                hidden_states[0, -1, :].detach().cpu().clone()
            )

    return residual_stream_hook
