import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig, AutoConfig, GenerationConfig, AutoModelForCausalLM, PreTrainedModel, CONFIG_MAPPING, MODEL_FOR_CAUSAL_LM_MAPPING
from transformers.modeling_outputs import CausalLMOutputWithPast
from os.path import isdir, exists, join
from dataclasses import dataclass
from typing import Optional
from safetensors.torch import load_file
from safetensors import safe_open
from collections import defaultdict
import json
from data.utils import IGNORE_INDEX

# Module-level logger; everything in this file logs under the name "model".
logger = logging.getLogger("model")

class DataSplitClassifier(nn.Module):
    """MLP head that maps per-token hidden states to per-token output logits.

    Architecture: ``Linear(input_dim, hidden_size) -> act``, followed by
    ``num_hidden_layers - 1`` repetitions of
    ``Linear(hidden_size, hidden_size) -> act``, followed by a bias-free
    ``Linear(hidden_size, output_dim)``.
    """

    # Maps the ``activation_str`` constructor argument to the module class to instantiate.
    activation_map = {
        "relu": nn.ReLU,
        "gelu": nn.GELU,
        "tanh": nn.Tanh,
        "id": nn.Identity
    }

    def __init__(self, input_dim, output_dim, hidden_size=100, num_hidden_layers=1, activation_str="id", bias=False):
        """Build the projection stack.

        Args:
            input_dim: Size of the last dimension of the incoming hidden states.
            output_dim: Number of logits produced per position.
            hidden_size: Width of every hidden layer.
            num_hidden_layers: Number of ``Linear -> activation`` stages before
                the output layer (the first stage is always present).
            activation_str: Key into ``activation_map``.
            bias: Whether the hidden ``Linear`` layers use a bias. The output
                layer never does.

        Raises:
            RuntimeError: If ``activation_str`` is unknown, or if more than one
                hidden layer is requested with the identity activation.
        """
        super().__init__()
        if activation_str not in DataSplitClassifier.activation_map:
            raise RuntimeError(f"Activation string {activation_str} not supported. Must be one of {DataSplitClassifier.activation_map.keys()}")
        self.activation_str = activation_str
        act = self.get_activation_cls()

        # Stacked linear layers without a non-linearity collapse to a single linear map.
        if num_hidden_layers > 1 and self.activation_str == "id":
            raise RuntimeError("Trying to set more than 1 hidden layer with identity activation is pointless")

        # Down-projection stage.
        layers = [nn.Linear(input_dim, hidden_size, bias=bias), act()]

        # Remaining hidden stages.
        for _ in range(num_hidden_layers - 1):
            layers.append(nn.Linear(hidden_size, hidden_size, bias=bias))
            layers.append(act())

        # Output layer (always bias-free).
        layers.append(nn.Linear(hidden_size, output_dim, bias=False))
        self.proj = nn.Sequential(*layers)

        logger.info(f"Initialized classifier head with {num_hidden_layers} hidden layers, hidden size {hidden_size}, and activation of type {self.get_activation_cls()}")

    def get_activation_cls(self):
        """Return the activation module class selected by ``self.activation_str``."""
        return DataSplitClassifier.activation_map[self.activation_str]

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the head independently to every position.

        Args:
            hidden_states: Tensor of shape (batch_size, seq_len, hidden_size).

        Returns:
            Tensor of shape (batch_size, seq_len, output_dim).
        """
        batch_size, seq_len, hidden_size = hidden_states.shape

        # Flatten batch and sequence, apply the stack, then restore the layout.
        # ``reshape`` (rather than ``view``) also accepts non-contiguous inputs
        # such as transposed or sliced tensors, copying only when required.
        out = self.proj(hidden_states.reshape(-1, hidden_size))
        return out.reshape(batch_size, seq_len, -1)

@dataclass
class T3CausalLMOutputWithPast(CausalLMOutputWithPast):
    """Causal-LM output extended with the guidance-head logits and the unguided base logits."""
    # Raw logits produced by the guidance head (set by ClassifierGuidedCausalLM.forward).
    classifier_logits: Optional[torch.FloatTensor] = None
    # Logits of the wrapped base model before guidance is applied.
    base_logits: Optional[torch.FloatTensor] = None

class ClassifierGuidedCausalLMConfig(PretrainedConfig):
    """Configuration of the classifier-guided wrapper.

    Stores the wrapper-specific settings and keeps a copy of every extra
    keyword argument (the base model's config fields) so that the base
    config can be rebuilt later with ``to_base``.
    """

    model_type = "classifier_guided_causal_lm"

    def __init__(self, guidance_kwargs=None, pooling="last", pool_temp=None, extraction_layer=-1, guidance_scale=1, base_temp=1, **kwargs):
        """Record wrapper settings and snapshot the base-config kwargs.

        Args:
            guidance_kwargs: Keyword arguments for the guidance head; a falsy
                value is stored as an empty dict.
            pooling: Name of the pooling function.
            pool_temp: Temperature used for attention pooling.
            extraction_layer: Index into the base model's hidden states.
            guidance_scale: Multiplier applied to the guidance term.
            base_temp: Divisor applied to the base log-probabilities.
            **kwargs: Base-model config fields, forwarded to ``PretrainedConfig``.
        """
        super().__init__(**kwargs)
        # Snapshot of the base-model fields, consumed by ``to_base``.
        self.base_config_dict = dict(kwargs)
        self.guidance = guidance_kwargs if guidance_kwargs else {}
        self.pooling, self.pool_temp = pooling, pool_temp
        self.extraction_layer = extraction_layer
        self.base_temp, self.guidance_scale = base_temp, guidance_scale

    def to_base(self):
        """Rebuild the base model's config from the stored snapshot."""
        fields = dict(self.base_config_dict)
        base_type = fields.pop("model_type")
        return AutoConfig.for_model(base_type, **fields)


def _last_pool(hidden_states):
    """Identity pooling: return ``hidden_states`` unchanged."""
    return hidden_states

def _mean_pool(hidden_states):
    """Running mean along the sequence axis.

    For input of shape (batch, seq_len, hidden_size), position ``t`` of the
    result is the average of positions ``0..t`` of the input.
    """
    _batch, n_positions, _width = hidden_states.shape
    # Number of positions contributing to each prefix: 1, 2, ..., seq_len.
    prefix_lengths = torch.arange(1, n_positions + 1, device=hidden_states.device)
    running_total = hidden_states.cumsum(dim=1)
    return running_total / prefix_lengths.view(1, -1, 1)

def _attn_pool(hidden_states, attn_weights):
    """Pool hidden states with head-averaged attention weights.

    ``attn_weights`` is averaged over dim 1 (the head axis), giving a
    (batch, seq_len, seq_len) matrix whose entry [b, i, j] weights key ``j``
    for query ``i``; that matrix is batch-multiplied with ``hidden_states``
    to give a (batch, seq_len, hidden_dim) result.
    """
    head_averaged = torch.mean(attn_weights, dim=1)
    pooled = torch.bmm(head_averaged, hidden_states)
    return pooled

class ClassifierGuidedCausalLM(PreTrainedModel):
    """
    Wraps a pretrained CausalLM, freezing all base model weights,
    and adds a DataSplitClassifier head (``guidance_head``) to produce a logit adjustment vector.

    Training: Only the classifier is updated, with loss based on retain/forget classification.
    Inference: Base model logits + adjustment logits are used for token prediction.
    """

    config_class = ClassifierGuidedCausalLMConfig
    pooling_fn_dict = {
        "last": _last_pool,
        "mean": _mean_pool,
        "attn": _attn_pool
    }

    def __init__(self, config: ClassifierGuidedCausalLMConfig, base_lm=None):
        """Wrap a base causal LM, freeze it, and attach the guidance head.

        Args:
            config: Wrapper configuration; supplies the guidance-head kwargs,
                pooling mode, pooling temperature, extraction layer, guidance
                scale and base temperature read below.
            base_lm: Already-instantiated base model. When ``None``, a model is
                built from ``config.to_base()`` via
                ``AutoModelForCausalLM.from_config``.
        """
        super().__init__(config)

        if base_lm is None:
            base_config = config.to_base()
            self.base_lm = AutoModelForCausalLM.from_config(base_config)
        else:
            self.base_lm = base_lm

        # Fall back to a default GenerationConfig when the base model does not expose one.
        if not hasattr(self.base_lm, "generation_config"):
            logger.warning(f"Could not find base model generation config, resorting to default")
            self.base_lm.generation_config = GenerationConfig()
        
        # The wrapper shares the base model's generation config object.
        self.generation_config = self.base_lm.generation_config
        
        # Freeze base model and set to eval
        self.base_lm.eval()
        for p in self.base_lm.parameters():
            p.requires_grad_(False)

        # Copy so that unpacking below cannot be affected by later edits to config.guidance.
        classifier_args = config.guidance.copy()

        # The head is placed on the device/dtype of the first base-model parameter.
        base_device = next(self.base_lm.parameters()).device
        base_dtype = next(self.base_lm.parameters()).dtype

        self.guidance_head = DataSplitClassifier(
            input_dim=self.config.hidden_size,
            output_dim=self.config.vocab_size,
            **classifier_args
        ).to(device=base_device, dtype=base_dtype)

        # NOTE(review): validation via assert is skipped under `python -O`; the dict lookup below would then raise KeyError instead.
        assert config.pooling in ClassifierGuidedCausalLM.pooling_fn_dict.keys(), f"Pooling function string {config.pooling} not recognized"
        self.pooling_fn = ClassifierGuidedCausalLM.pooling_fn_dict[config.pooling]
        self.pooling_fn_name = config.pooling
        self.pool_temp = config.pool_temp

        # Hidden states are always requested; attention maps only for attention pooling.
        self.base_lm.config.output_hidden_states=True
        self.base_lm.config.output_attentions = self.pooling_fn_name == "attn"

        self.extraction_layer = config.extraction_layer
        self.guidance_scale = config.guidance_scale
        self.base_temp = config.base_temp

        # Mirror the base model's attention implementation onto the wrapper config.
        # NOTE(review): relies on the private `_autoset_attn_implementation` hook of the
        # base model class — confirm it exists in the pinned transformers version.
        base_lm_attn_implementation = self.base_lm.__class__._autoset_attn_implementation(
            self.base_lm.config,
            torch_dtype=self.base_lm.config.torch_dtype,
            device_map=None
        )._attn_implementation
        self.config.attn_implementation = base_lm_attn_implementation

        # NOTE(review): lm_loss is not referenced anywhere else in the visible code.
        self.lm_loss = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
        # NOTE(review): the log line below assumes the base model exposes a `.model` submodule.
        logger.info(
            f"Initialized a ClassifierGuidedCausalLM model:\n"
            f"base_lm: {self.base_lm.model.config._name_or_path}\n"
            f"pooling: {config.pooling}\n"
            f"pool_temp: {config.pool_temp}\n"
            f"extraction layer: {self.extraction_layer}\n"
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs):
        """Load a saved ClassifierGuidedCausalLM from a local directory.

        The directory must contain ``config.json`` plus weights in one of
        these layouts, tried in order: ``model.safetensors``; sharded
        safetensors described by ``model.safetensors.index.json``;
        ``pytorch_model.bin``.

        Args:
            pretrained_model_name_or_path: Path to the saved model directory.
            *args: Forwarded to ``AutoModelForCausalLM.from_pretrained`` when
                instantiating the base model.
            **kwargs: Forwarded likewise; any ``config`` entry is dropped
                because the base config is rebuilt from the saved wrapper config.

        Returns:
            The wrapper with base-model and guidance-head weights loaded.

        Raises:
            AssertionError: If the path is not a directory.
            FileNotFoundError: If no weight file can be found.
            Exception: Any error raised while reading or parsing ``config.json``
                is logged and re-raised unchanged.
        """
        assert isdir(pretrained_model_name_or_path), "Tried to load ClassifierGuidedCausalLM but did not pass a valid path to saved model"

        # ---- Configuration -------------------------------------------------
        config_path = join(pretrained_model_name_or_path, "config.json")
        try:
            with open(config_path, "r") as f:
                config_dict = json.load(f)
            logger.info(f"Found saved config at {config_path}")

            guidance_kwargs = config_dict.pop("guidance")
            base_kwargs = config_dict.pop("base_config_dict")
            config = ClassifierGuidedCausalLMConfig(
                guidance_kwargs=guidance_kwargs,
                pooling=config_dict["pooling"],
                pool_temp=config_dict["pool_temp"],
                extraction_layer=config_dict["extraction_layer"],
                guidance_scale=config_dict["guidance_scale"],
                # Older checkpoints may predate base_temp.
                base_temp=config_dict.get("base_temp", 1.0),
                **base_kwargs
            )
        except Exception as e:
            # Without a config nothing below can work; re-raise instead of
            # failing later with an unrelated unbound-variable error.
            logger.error(f"Failed to load config.json from {pretrained_model_name_or_path} due to error {e}")
            raise

        # ---- Weights -------------------------------------------------------
        try:
            model_path_single = join(pretrained_model_name_or_path, "model.safetensors")
            if exists(model_path_single):
                logger.info(f"Loading model from single file {model_path_single}")
                state_dict = load_file(model_path_single)
            else:
                logger.info(f"Couldn't find single model.safetensors file at path {pretrained_model_name_or_path}. Attempting to look for sharded files.")
                index_file = join(pretrained_model_name_or_path, "model.safetensors.index.json")
                if not exists(index_file):
                    raise FileNotFoundError(f"No safetensors file found in {pretrained_model_name_or_path}")
                with open(index_file, "r") as f:
                    index = json.load(f)

                # Group tensor names by shard so each shard file is opened once.
                keys_by_shard = defaultdict(list)
                for key, fname in index["weight_map"].items():
                    keys_by_shard[fname].append(key)

                state_dict = {}
                for fname, shard_keys in keys_by_shard.items():
                    shard_path = join(pretrained_model_name_or_path, fname)
                    with safe_open(shard_path, framework="pt", device="cpu") as shard:
                        for key in shard_keys:
                            state_dict[key] = shard.get_tensor(key)
        except Exception as e:
            logger.error(f"Failed to load safetensors, trying pytorch bin files. Error {e}")
            bin_path = join(pretrained_model_name_or_path, "pytorch_model.bin")
            if not exists(bin_path):
                raise FileNotFoundError(
                    f"No safetensors or pytorch_model.bin weights found in {pretrained_model_name_or_path}"
                ) from e
            # SECURITY NOTE: torch.load unpickles the file; only load
            # checkpoints from trusted sources.
            state_dict = torch.load(bin_path, map_location="cpu")

        # ---- Base model ----------------------------------------------------
        base_config = config.to_base()
        kwargs.pop("config", None)

        base_lm = AutoModelForCausalLM.from_pretrained(
            base_config._name_or_path,
            *args,
            config=base_config,
            **kwargs
        )

        # Strip only the leading prefix; str.replace would also rewrite any
        # later occurrence of the same substring inside a key.
        base_prefix = "base_lm."
        base_lm_state_dict = {
            k[len(base_prefix):]: v
            for k, v in state_dict.items()
            if k.startswith(base_prefix)
        }
        incomp = base_lm.load_state_dict(base_lm_state_dict, strict=False)
        if incomp.missing_keys or incomp.unexpected_keys:
            logger.warning(
                f"Some weights of the base_lm could not be loaded.\n"
                f"Missing keys: {incomp.missing_keys}\n"
                f"Unexpected keys: {incomp.unexpected_keys}"
            )

            if incomp.missing_keys:
                # Report missing keys that share storage with another
                # parameter, since those are restored by tie_weights().
                current_state = base_lm.state_dict()
                named_params = list(base_lm.named_parameters())
                for missing_key in incomp.missing_keys:
                    missing_weight = current_state[missing_key]
                    for n, p in named_params:
                        if n != missing_key and p.data_ptr() == missing_weight.data_ptr():
                            logger.info(f"Missing key {missing_key} is tied to {n}. If {n} is loaded this will be fixed by tie_weights().")

        base_lm.tie_weights()

        # ---- Wrapper and guidance head -------------------------------------
        model = cls(config, base_lm=base_lm)
        head_prefix = "guidance_head."
        guidance_state_dict = {
            k[len(head_prefix):]: v
            for k, v in state_dict.items()
            if k.startswith(head_prefix)
        }
        model.guidance_head.load_state_dict(guidance_state_dict)
        return model
    
    @classmethod
    def from_pretrained_base(cls, pretrained_model_name_or_path: str, *args, guidance_kwargs=None, pooling="last", pool_temp=None, extraction_layer=-1, guidance_scale=1, base_temp=1, **kwargs):
        """Load a pretrained base causal LM and wrap it with a fresh guidance head.

        Args:
            pretrained_model_name_or_path: Identifier or path of the base model.
            *args: Forwarded to ``AutoModelForCausalLM.from_pretrained``.
            guidance_kwargs: Keyword arguments for the guidance head.
            pooling: Name of the pooling function.
            pool_temp: Temperature used for attention pooling.
            extraction_layer: Index into the base model's hidden states.
            guidance_scale: Multiplier applied to the guidance term.
            base_temp: Divisor applied to the base log-probabilities.
            **kwargs: Forwarded to ``AutoModelForCausalLM.from_pretrained``.

        Returns:
            A new wrapper around the loaded base model.
        """
        # Load the base model through the auto class. Calling
        # ``super().from_pretrained`` here would be bound to ``cls`` and so
        # would try to build this wrapper class, not the base LM, from the
        # base checkpoint.
        base_lm = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
        return cls.from_pretained_base_obj(
            base_lm=base_lm,
            guidance_kwargs=guidance_kwargs,
            pooling=pooling,
            pool_temp=pool_temp,
            extraction_layer=extraction_layer,
            guidance_scale=guidance_scale,
            base_temp=base_temp
        )

    @classmethod
    def from_pretained_base_obj(cls, base_lm, guidance_kwargs=None, pooling="last", pool_temp=None, extraction_layer=-1, guidance_scale=1, base_temp=1):
        """Wrap an already-instantiated base model with a new guidance head.

        The wrapper config is built from the given guidance settings plus
        every field of ``base_lm.config.to_dict()``.
        """
        base_fields = base_lm.config.to_dict()
        wrapper_config = ClassifierGuidedCausalLMConfig(
            guidance_kwargs=guidance_kwargs,
            pooling=pooling,
            pool_temp=pool_temp,
            extraction_layer=extraction_layer,
            guidance_scale=guidance_scale,
            base_temp=base_temp,
            **base_fields
        )
        return cls(wrapper_config, base_lm=base_lm)

    def can_generate(self):
        """Report that this model supports generation (always ``True``)."""
        return True
    
    # Delegate prepare_inputs_for_generation so HF generate() works efficiently
    def prepare_inputs_for_generation(self, *args, **kwargs):
        """Delegate to the base model, then cast any cached keys/values to the model dtype.

        The target dtype is that of the last floating-point parameter of the
        base model.

        Raises:
            ValueError: If the base model has no floating-point parameter.
        """
        float_dtypes = [w.dtype for w in self.base_lm.parameters() if w.is_floating_point()]
        if not float_dtypes:
            raise ValueError("Could not determine model dtype from its parameters")
        target_dtype = float_dtypes[-1]

        model_inputs = self.base_lm.prepare_inputs_for_generation(*args, **kwargs)

        has_cache = "past_key_values" in model_inputs and model_inputs["past_key_values"] is not None
        if not has_cache:
            return model_inputs

        # Rebuild the cache as a tuple of (key, value) pairs in the model dtype.
        model_inputs["past_key_values"] = tuple(
            (entry[0].to(target_dtype), entry[1].to(target_dtype))
            for entry in model_inputs["past_key_values"]
        )
        return model_inputs

    def _guide_logits(self, base_logits, classifier_logits):
        """Combine base-model and guidance-head scores.

        Args:
            base_logits: (batch, seq_len, vocab) logits of the base model.
            classifier_logits: (batch, seq_len, vocab) logits of the guidance head.

        Returns:
            ``log_softmax(base_logits) / base_temp
            + guidance_scale * logsigmoid(classifier_logits)``.
        """
        # Normalise the base scores over the vocabulary axis.
        tempered_base = F.log_softmax(base_logits, dim=2) / self.base_temp

        # Consider clipping this for stability
        guidance_term = self.guidance_scale * F.logsigmoid(classifier_logits)

        return tempered_base + guidance_term

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """Run the base model and return guided, base and classifier logits.

        Any ``labels`` entry in ``kwargs`` is removed and not used; the
        returned ``loss`` is always ``None``. The guided logits are computed
        under ``torch.no_grad()``, while ``classifier_logits`` is returned as
        produced by the guidance head.
        """
        kwargs.pop("labels", None)
        use_attn_pooling = self.pooling_fn_name == "attn"
        kwargs["output_hidden_states"] = True
        kwargs["output_attentions"] = use_attn_pooling
        base_outputs = self.base_lm(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

        # Hidden states feeding the guidance head: (batch, seq_len, hidden_size).
        layer_states = base_outputs.hidden_states[self.extraction_layer]

        if use_attn_pooling:
            pooled = self.pooling_fn(layer_states, self.get_attn_weights(base_outputs))
        else:
            pooled = self.pooling_fn(layer_states)

        # (batch, seq_len, vocab)
        classifier_logits = self.guidance_head(pooled)
        with torch.no_grad():
            guided_logits = self._guide_logits(base_outputs.logits, classifier_logits)

        return T3CausalLMOutputWithPast(
            loss=None,
            logits=guided_logits,
            past_key_values=base_outputs.past_key_values,
            hidden_states=base_outputs.hidden_states,
            attentions=base_outputs.attentions,
            base_logits=base_outputs.logits,
            classifier_logits=classifier_logits,
        )
    
    def train(self, mode: bool = True):
        """Set the training mode of the wrapper and head; the base model stays in eval mode."""
        super().train(mode)
        head, backbone = self.guidance_head, self.base_lm
        head.train(mode)
        # The frozen base model is never switched to training mode.
        backbone.eval()
        return self

    def eval(self):
        """Put the wrapper, the guidance head and the base model into eval mode."""
        super().eval()
        for submodule in (self.guidance_head, self.base_lm):
            submodule.eval()
        return self


    def get_attn_weights(self, base_outputs):
        """Return attention weights for the extraction layer.

        Uses the attention map returned by the base model when available.
        Otherwise recomputes causal attention from the layer's ``q_proj`` and
        ``k_proj`` applied to the previous hidden states, softmaxed with
        temperature ``self.pool_temp`` (treated as 1.0 when ``None``).

        Args:
            base_outputs: Output of the base model; must carry ``hidden_states``.

        Returns:
            Tensor of shape (batch, num_heads, seq_len, seq_len).
        """
        # The base model may return no attentions at all, or None entries.
        attentions = getattr(base_outputs, "attentions", None)
        if attentions is not None and attentions[self.extraction_layer] is not None:
            return attentions[self.extraction_layer]

        # NOTE(review): hidden_states[i] and model.layers[i] are indexed with
        # the same value; this lines up for negative indices, but for
        # non-negative indices it may be off by one — confirm intended usage.
        layer = self.base_lm.model.layers[self.extraction_layer]
        attn = layer.self_attn

        # Hidden states entering the layer (output of the previous layer).
        prev_hidden_states = base_outputs.hidden_states[self.extraction_layer-1]

        # NOTE(review): this recomputation applies only q_proj/k_proj and a
        # causal mask; any normalisation or positional encoding the layer
        # applies around them, and the attention_mask, are not reproduced —
        # confirm this approximation is intended.
        q_proj = attn.q_proj(prev_hidden_states)
        k_proj = attn.k_proj(prev_hidden_states)

        # Detect FlashAttention fused Q case
        if q_proj.shape[-1] > k_proj.shape[-1]:
            # Keep only the first hidden_dim features so Q matches K.
            hidden_dim = k_proj.shape[-1]
            q_proj = q_proj[..., :hidden_dim]

        # Split the feature axis into heads.
        batch, seq_len, hidden_dim = k_proj.shape
        num_heads = attn.num_heads
        head_dim = hidden_dim // num_heads

        q_proj = q_proj.view(batch, seq_len, num_heads, head_dim).transpose(1, 2) # (b, h, seq-len, d)
        k_proj = k_proj.view(batch, seq_len, num_heads, head_dim).transpose(1, 2) # (b, h, seq-len, d)

        # Scaled dot-product scores with future positions masked out.
        attn_scores = torch.matmul(q_proj, k_proj.transpose(-1, -2)) / (head_dim ** 0.5)
        causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=attn_scores.device))
        attn_scores = attn_scores.masked_fill(causal_mask == 0, float('-inf'))

        # pool_temp defaults to None in the config; fall back to no tempering
        # instead of failing on a Tensor / None division.
        temperature = 1.0 if self.pool_temp is None else self.pool_temp

        # Softmax in float32 for stability, then cast back to the score dtype.
        return F.softmax(attn_scores.float()/temperature, dim=-1).to(attn_scores.dtype)

    def regularizer(self):
        """Sum of squared entries over all trainable guidance-head parameters.

        Returns the integer ``0`` when the head has no trainable parameter.
        """
        penalty = 0
        for weight in self.guidance_head.parameters():
            if weight.requires_grad:
                penalty = penalty + weight.pow(2).sum()
        return penalty


# Register the wrapper's config under its model type, and the model class under
# that config, in the transformers mappings imported above. This runs at import time.
CONFIG_MAPPING.register("classifier_guided_causal_lm", ClassifierGuidedCausalLMConfig)
MODEL_FOR_CAUSAL_LM_MAPPING.register(ClassifierGuidedCausalLMConfig, ClassifierGuidedCausalLM)

