"""Implement hate speach models."""

import torch
from decision.xp.model.base import BaseModel, PretrainedMixin, register_model
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForCausalLM,
)
from torch.nn import Identity
import torch.nn.functional as F


def getattr_nested(obj, attr_name):
    """Resolve a dotted attribute path such as ``"a.b.c"`` starting at *obj*.

    Each dot-separated segment is looked up with ``getattr`` in turn, so a
    missing segment (including an empty one) raises ``AttributeError``.
    """
    head, separator, rest = attr_name.partition(".")
    target = getattr(obj, head)
    if not separator:
        return target
    return getattr_nested(target, rest)


def setattr_nested(obj, attr_name, value):
    """Assign *value* to the attribute at dotted path *attr_name* under *obj*.

    Every segment but the last is resolved with ``getattr``; the last one is
    the attribute that gets set on the object reached that way.
    """
    parent_path, separator, leaf = attr_name.rpartition(".")
    parent = obj
    if separator:
        # Walk down to the object that owns the final attribute.
        for segment in parent_path.split("."):
            parent = getattr(parent, segment)
    setattr(parent, leaf, value)


class BaseSequenceClassification(PretrainedMixin, BaseModel):
    """Hugging Face sequence classifier split into a body and a final layer.

    Subclasses set ``model_name`` (the checkpoint id passed to
    ``from_pretrained``) and may override ``clf_layer_name``. On first use the
    whole model is loaded, the module at the dotted path ``clf_layer_name`` is
    detached as ``remaining_model``, and that attribute is replaced by
    ``Identity`` so ``truncated_model`` passes through whatever the wrapped
    model feeds into that layer.

    NOTE(review): ``forward_whole`` / ``forward_both`` call
    ``self.get_latent_space``, which is not defined in this class; it is
    presumably provided by a base class or by subclasses -- confirm.
    """

    # Dotted attribute path of the layer split off as ``remaining_model``.
    clf_layer_name = "classifier"

    def __init__(self):
        # Populated lazily on first access through the properties below.
        self._truncated_model = None
        self._remaining_model = None
        self._tokenizer = None

    @property
    def tokenizer(self):
        """Tokenizer for ``model_name``, loaded on first access and cached."""
        if self._tokenizer is None:
            self._tokenizer = self.load_tokenizer()
        return self._tokenizer

    def _load_split(self):
        # Both halves come from the same loaded model, so they are cached
        # together regardless of which property triggered the load.
        self._truncated_model, self._remaining_model = self.load_split_model()

    @property
    def truncated_model(self):
        """Whole model with its ``clf_layer_name`` layer replaced by ``Identity``."""
        if self._truncated_model is None:
            self._load_split()
        return self._truncated_model

    @property
    def remaining_model(self):
        """The layer that was split off at ``clf_layer_name``."""
        if self._remaining_model is None:
            self._load_split()
        return self._remaining_model

    @classmethod
    def load_whole_model(cls):
        """Load the unmodified sequence-classification model for ``model_name``."""
        return AutoModelForSequenceClassification.from_pretrained(cls.model_name)

    @classmethod
    def load_tokenizer(cls):
        """Load the tokenizer matching ``model_name``."""
        return AutoTokenizer.from_pretrained(cls.model_name)

    @classmethod
    def load_split_model(cls):
        """Load the model and cut it at ``clf_layer_name``.

        Returns:
            ``(truncated_model, remaining_model)``: the loaded model with the
            layer swapped for ``Identity``, and the original layer itself.
        """
        truncated_model = cls.load_whole_model()
        remaining_model = getattr_nested(truncated_model, cls.clf_layer_name)
        setattr_nested(truncated_model, cls.clf_layer_name, Identity())
        return truncated_model, remaining_model

    def process(self, examples: list) -> dict:
        """Tokenize *examples* into padded PyTorch tensors (max 512 tokens)."""
        return self.tokenizer(
            examples, truncation=True, max_length=512, padding=True, return_tensors="pt"
        )

    @staticmethod
    def get_logits(outputs):
        """Return *outputs* unchanged; subclasses may extract logits here."""
        return outputs

    def eval(self):
        """Put both halves of the model in evaluation mode."""
        self.truncated_model.eval()
        self.remaining_model.eval()

    def train(self):
        """Put both halves of the model in training mode."""
        self.truncated_model.train()
        self.remaining_model.train()

    def to(self, device):
        """Move both halves to *device* and return ``self`` for chaining."""
        self.truncated_model.to(device)
        self.remaining_model.to(device)
        return self

    def forward_truncated(self, **kwargs):
        """Run the truncated model on tokenizer outputs passed as keywords."""
        return self.truncated_model(**kwargs)

    @staticmethod
    def _forward_remaining(remaining_model, latent_space):
        # Static hook so subclasses can adapt the latent tensor (e.g. reshape
        # it) before it reaches the split-off layer.
        return remaining_model(latent_space)

    def forward_remaining(self, latent_space):
        """Apply the split-off layer to *latent_space*."""
        return self._forward_remaining(self.remaining_model, latent_space)

    def forward_both(self, **kwargs):
        """Run the full pipeline and return ``(truncated_output, remaining_output)``."""
        output_truncated = self.forward_truncated(**kwargs)
        latent_space = self.get_latent_space(output_truncated)
        output_remaining = self.forward_remaining(latent_space)
        return output_truncated, output_remaining

    def forward_whole(self, **kwargs):
        """Run the full pipeline and return only the split-off layer's output."""
        return self.forward_both(**kwargs)[1]

    def get_w(self):
        """Return the split-off layer's ``weight`` as a NumPy array.

        Assumes ``remaining_model`` exposes a ``weight`` attribute (e.g. an
        ``nn.Linear``). ``.cpu()`` is required because ``Tensor.numpy()``
        rejects tensors on other devices, which is the case after ``to("cuda")``.
        """
        return self.remaining_model.weight.detach().cpu().numpy()


@register_model("cnerg1", "CNERG tamil")
class CNERG1(BaseSequenceClassification):
    model_name = "Hate-speech-CNERG/tamil-codemixed-abusive-MuRIL"


@register_model("cnerg2", "CNERG en MuRIL")
class CNERG2(BaseSequenceClassification):
    model_name = "Hate-speech-CNERG/english-abusive-MuRIL"


@register_model("cnerg3", "CNERG en mono")
class CNERG3(BaseSequenceClassification):
    model_name = "Hate-speech-CNERG/dehatebert-mono-english"


@register_model("cnerg4", "CNERG Hatexplain")
class CNERG4(BaseSequenceClassification):
    model_name = "Hate-speech-CNERG/bert-base-uncased-hatexplain"

    def get_probabilities(self, outputs):
        # The model output class is hate/normal/offensive
        # so we merge the probabilities of offensive and hate
        # as the probability of the positive class
        logits = self.get_logits(outputs)
        probas = F.softmax(logits, dim=-1)
        probas_normal = probas[..., 1:2]
        probas_hate = probas[..., 0:1] + probas[..., 2:3]
        probas_truncated = torch.cat([probas_normal, probas_hate], dim=-1)
        return probas_truncated

    def get_y_pred(self, outputs):
        return self.get_probabilities(outputs).argmax(-1)


@register_model("cnerg5", "CNERG portuguese")
class CNERG5(BaseSequenceClassification):
    model_name = "Hate-speech-CNERG/dehatebert-mono-portugese"


@register_model("fb_roberta1", "FB Roberta 1")
class FBRoberta1(BaseSequenceClassification):
    model_name = "facebook/roberta-hate-speech-dynabench-r4-target"
    clf_layer_name = "classifier"

    @staticmethod
    def get_latent_space(truncated_outputs):
        latent_space = truncated_outputs.logits
        return latent_space[:, 0, :]

    @staticmethod
    def _forward_remaining(remaining_model, latent_space):
        # Since we already extracted the class token in get_latent_space,
        # we need to add a dimension to the latent space to make it compatible
        # with the remaining model
        latent_space = latent_space.unsqueeze(1)
        return BaseSequenceClassification._forward_remaining(
            remaining_model, latent_space
        )

    def get_w(self):
        return self.remaining_model.dense.weight.detach().numpy()


@register_model("fb_roberta2", "FB Roberta")
class FBRoberta2(BaseSequenceClassification):
    model_name = "facebook/roberta-hate-speech-dynabench-r4-target"
    clf_layer_name = "classifier.out_proj"


@register_model("mistral_instruct", "Mistral-7B-Instruct-v0.2")
class Mistral(BaseSequenceClassification):
    model_name = "mistralai/Mistral-7B-Instruct-v0.2"
    clf_layer_name = "lm_head"

    def __init__(self):
        super().__init__()
        self.yes_token_id = self.tokenizer.convert_tokens_to_ids("yes")
        self.no_token_id = self.tokenizer.convert_tokens_to_ids("no")
        # Tokenizer as no pad_token defined
        self.tokenizer.pad_token = self.tokenizer.eos_token

    @classmethod
    def load_whole_model(cls):
        return AutoModelForCausalLM.from_pretrained(cls.model_name)

    def process(self, examples: list) -> dict:
        def get_prompt(s):
            return f"You are a hate speech detector. Given an input sentence you have to decide whether it is hate speech or not. You should only answer with one word: yes or no. The sentence to classify is: '{s}'. Is the previous sentence hate speech?"

        prompts = [get_prompt(s) for s in examples]
        return self.tokenizer(prompts, padding=True, return_tensors="pt")

    @staticmethod
    def get_latent_space(truncated_outputs):
        latent_space = truncated_outputs.logits
        return latent_space[:, -1, :]

    # @staticmethod
    # def get_logits(outputs):
    #     # Select the last token logit for next token prediction
    #     return outputs
    #     # logits = outputs.logits[:, -1, :]
    #     # truncated_logits = logits[:, [self.no_token_id, self.yes_token_id]]
    #     # return truncated_logits

    def get_probabilities(self, outputs):
        logits = self.get_logits(outputs)
        truncated_logits = logits[:, [self.no_token_id, self.yes_token_id]]
        return F.softmax(truncated_logits, dim=-1)

    def get_y_pred(self, outputs):
        return self.get_probabilities(outputs).argmax(-1)
