import hashlib
import json
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, TextIO, Union, Callable

import torch
from transformers import AutoTokenizer
from src.configs import ModelConfig
from src.personas import Persona
from src.bias_pipeline.questionaires.questionaire import Question


@dataclass
class Annotation:
    """
    A judgement produced by a judge model over one subset of the data.
    """

    model_id: str  # Identifier of the judge model that produced this annotation
    annotation_subset: str  # Subset of the data that was annotated (usually the name of the assistant model)
    annotation: Dict[str, str]  # The annotation payload itself

    def __repr__(self) -> str:
        subset = self.annotation_subset
        return "{} ({}): {}".format(self.model_id, subset, self.annotation)

    def to_json(self) -> Dict[str, Any]:
        """Return the annotation as a plain, JSON-compatible dict."""
        keys = ("model_id", "annotation_subset", "annotation")
        return {key: getattr(self, key) for key in keys}

    @classmethod
    def from_json(cls, data: Dict[str, Any]) -> "Annotation":
        """Rebuild an annotation from the dict produced by ``to_json``."""
        return cls(
            model_id=data["model_id"],
            annotation_subset=data["annotation_subset"],
            annotation=data["annotation"],
        )


@dataclass
class Message:
    """
    Data model to represent a single message in a conversation.

    Two messages compare equal when both carry an ``id`` and the ids match.
    When at least one ``id`` is missing, they compare equal when their
    ``text`` and ``sender`` match. The hash depends only on ``text`` and
    ``sender``.
    """

    text: str  # Message content
    sender: str  # Name of the participant that sent the message
    id: Optional[int] = None  # Optional identifier; takes precedence in __eq__ when both sides have one

    def __repr__(self) -> str:
        return f"{self.text}"

    def to_json(self) -> Dict[str, Any]:
        """Serialize the message to a JSON-compatible dict."""
        return {
            "text": self.text,
            "sender": self.sender,
            "id": self.id,
        }

    @classmethod
    def from_json(cls, data: Dict[str, Any]) -> "Message":
        """Build a message from a dict as produced by ``to_json``.

        The content is read from ``"text"``, falling back to ``"example"``.
        ``"id"`` is optional.

        Raises:
            KeyError: if neither ``"text"`` nor ``"example"`` is present, or
                ``"sender"`` is missing.
        """
        text = data["text"] if "text" in data else data["example"]
        sender = data["sender"]
        id = data.get("id", None)
        return cls(text, sender, id)

    # Hashable
    def __hash__(self) -> int:
        """Content hash: SHA-1 of ``text`` concatenated with ``sender``.

        Python's built-in ``hash()`` truncates the large integer returned here
        to the platform hash width.
        """
        hash_str = self.text + self.sender
        return int(hashlib.sha1(hash_str.encode("utf-8")).hexdigest(), 16)

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, Message):
            return False
        if self.id is not None and other.id is not None:
            if self.id == other.id:
                # Equal ids must carry the same hashed content. Otherwise
                # equal objects would hash differently. Compare the hashed
                # string directly, so this also works when ``other`` has no
                # usable __hash__ (e.g. a dataclass subclass where it is None).
                assert self.text + self.sender == other.text + other.sender, (
                    "Messages with the same id must have the same content"
                )
                return True
            return False
        # Compare the fields themselves. The hash concatenates text and sender
        # without a separator, so e.g. ("ab", "c") and ("a", "bc") collide.
        return self.text == other.text and self.sender == other.sender

    def to_file(self, file: TextIO) -> None:
        """Write the message to ``file`` as one JSON line, then flush."""
        file.write(json.dumps(self.to_json()) + "\n")
        file.flush()


@dataclass(eq=False)
class RootMessage(Message):
    """
    Data model to represent the root message of a conversation.
    Inherits from Message and contains score information and a question.

    ``eq=False`` keeps the inherited ``Message.__eq__`` and
    ``Message.__hash__``. With the default ``eq=True``, the decorator would
    generate a new ``__eq__`` for this class and set ``__hash__`` to
    ``None``. Root messages would then be unhashable, which breaks, e.g.,
    ``MessageNode.__hash__`` (it calls ``self.message.__hash__()``).
    """

    def __init__(
        self,
        text: str,
        sender: str,
        question: Optional[Question],
        score: Dict[str, Any],
        id: Optional[int] = None,
    ):
        super().__init__(text, sender, id)

        self.score = score  # Score information attached to the root (opaque dict)
        self.question = question  # The question this conversation is about; may be None

    @classmethod
    def from_json(cls, data: Dict[str, Any]) -> "RootMessage":
        """Build a root message from a dict as produced by ``to_json``.

        The sender is always set to ``"root"``. The id is taken from the
        question. Without a question, the serialized ``"id"`` (if any) is
        used. A missing or ``None`` ``"question"`` (``to_json`` emits ``None``
        when no question is set) yields ``question=None``.
        """
        text = data.get("text", data.get("example", ""))
        sender = "root"
        raw_question = data.get("question")
        question = Question.from_json(raw_question) if raw_question is not None else None
        score = data.get("score", {})
        id = question.get_id() if question else data.get("id")

        return cls(text, sender, question, score, id)

    def to_json(self) -> Dict[str, Any]:
        """Serialize like ``Message.to_json``, plus ``score`` and ``question``."""
        data = super().to_json()
        data["score"] = self.score
        data["question"] = self.question.to_json() if self.question else None
        return data


@dataclass
class Thread:
    """
    Data model to represent a linear thread of messages, optionally linked to
    the MessageNode that holds its last message.
    """

    messages: List[Message]  # Messages in conversation order
    persona: Persona  # Persona whose messages are mapped to the "user" role in to_chat
    leaf_node: Optional["MessageNode"] = None  # Tree node holding the last message, if any

    def __init__(
        self, messages: List[Message], persona: Persona, leaf_node: Optional["MessageNode"] = None
    ):
        self.messages = messages
        self.persona = persona
        self.leaf_node = leaf_node

    def __repr__(self) -> str:
        return f"<Thread | Messages: {len(self.messages)}>"

    def __len__(self) -> int:
        return len(self.messages)

    def __getitem__(self, idx: int) -> Message:
        return self.messages[idx]

    def __iter__(self):
        return iter(self.messages)

    def __contains__(self, item: Message) -> bool:
        return item in self.messages

    def to_json(self) -> List[Dict[str, Any]]:
        """Serialize the thread as a list of message dicts (persona and tree links are not included)."""
        return [m.to_json() for m in self.messages]

    @classmethod
    def from_json(
        cls, data: List[Dict[str, Any]], persona: Optional[Persona] = None
    ) -> "Thread":
        """Rebuild a thread from the list produced by ``to_json``.

        The serialized form carries neither the persona nor the tree linkage.
        The caller may supply ``persona`` (defaults to ``None``), and
        ``leaf_node`` is left unset.
        """
        return cls([Message.from_json(m) for m in data], persona)

    def to_string(self, start: int = 0) -> str:
        """Render messages from index ``start`` on as ``"sender:\\n text"`` blocks joined by newlines."""
        return "\n".join([f"{m.sender}:\n {m.text}" for m in self.messages[start:]])

    def add_message(self, message: Message) -> "Thread":
        """Return a new thread extended by ``message`` and graft it into the tree.

        The message becomes a child of the current leaf node. The returned
        thread points at that new child. ``self`` is not modified, but the
        shared tree is.

        Raises:
            ValueError: if the thread is not attached to a tree node.
            AssertionError: if ``message`` has the same sender as the current
                last message.
        """
        # New message is always the last message
        if self.leaf_node is None:
            raise ValueError("Cannot add a message to a thread without a leaf node")

        assert message.sender != self.leaf_node.message.sender, (
            "Cannot add message from same sender consecutively"
        )

        # Build the child node explicitly, so the new thread's leaf is set
        # regardless of what MessageNode.add_child returns.
        node = MessageNode(message, self.leaf_node)
        self.leaf_node.add_child(node)

        return Thread(self.messages + [message], self.persona, node)

    def to_chat(
        self,
        model: ModelConfig,
        tokenizer: Optional[AutoTokenizer] = None,
        tokenize: bool = False,
        adjust_perspective: bool = True,
    ) -> tuple[List[dict[str, Any]], Optional[str], Optional[torch.Tensor]]:
        """Puts the Thread in chat format.

        Args:
            model (ModelConfig): Provides the system prompt. When no tokenizer
                is given, it also provides the tokenizer name.
            tokenizer (Optional[AutoTokenizer], optional): The tokenizer to use. Defaults to None.
            tokenize (bool, optional): Whether to also tokenize the input. Defaults to False.
            adjust_perspective (bool, optional): Whether to adapt the roles such that it is the assistants turn. Defaults to True.

        Returns:
            tuple[List[dict[str, Any]], Optional[str], Optional[torch.Tensor]]:
                ``(chat, text, input_ids)``. ``chat`` is the list of
                role/content dicts, with the system prompt first. ``text`` is
                the rendered prompt string and ``input_ids`` the tokenized
                prompt. Both are ``None`` unless ``tokenize`` is True.
        """

        chat_format = [{"role": "system", "content": model.system_prompt}]

        # The persona may be a plain string (MessageNode.get_conversation can
        # pass the root sender), so fall back to the value itself.
        persona_name = getattr(self.persona, "name", self.persona)

        for message in self.messages:
            sender = (
                "user"
                if message.sender == persona_name or message.sender == "user"
                else "assistant"
            )
            chat_format.append({"role": sender, "content": message.text})

        last_sender = chat_format[-1]["role"]

        # Swap user/assistant (never the system prompt) so that the next turn
        # to be generated belongs to the assistant.
        if adjust_perspective and last_sender == "assistant":
            for message in chat_format[1:]:
                message["role"] = "assistant" if message["role"] == "user" else "user"

        if not tokenize:
            return chat_format, None, None

        if tokenizer is None:
            # NOTE(review): assumes ModelConfig exposes `tokenizer_name` — confirm.
            tokenizer = AutoTokenizer.from_pretrained(model.tokenizer_name)

        # With tokenize=True, apply_chat_template returns token ids, not the
        # rendered string, so render the text separately.
        text = tokenizer.apply_chat_template(
            chat_format,
            tokenize=False,
            add_generation_prompt=True,
        )
        encoded = tokenizer.apply_chat_template(
            chat_format,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        )
        # The result may be the id tensor itself or, depending on the
        # transformers version/arguments, a dict-like encoding holding it.
        input_ids = encoded if isinstance(encoded, torch.Tensor) else encoded["input_ids"]

        return chat_format, text, input_ids


class MessageNode:
    """
    A node in a conversation tree. It holds one message plus links to its
    parent and children, so the path from the root to a node is one
    conversation.
    """

    # Represents an internal conversation state that ends in a message
    def __init__(self, message: Message, parent: Optional["MessageNode"] = None):
        assert isinstance(message, Message), "Message must be of type Message"
        self.message = message
        self.parent = parent
        self.children: List["MessageNode"] = []

    def add_child(self, message: Union[Message, "MessageNode"]) -> "MessageNode":
        """Attach a child to this node and return the child node.

        Args:
            message: Either a Message, which is wrapped in a new node whose
                parent is this node, or an existing MessageNode whose
                ``parent`` must already be this node.

        Returns:
            The attached child node, so callers can continue from the new leaf.
        """
        if isinstance(message, MessageNode):
            assert message.parent is self, "Parent must be the current node"
            node = message
        else:
            node = MessageNode(message, self)
        self.children.append(node)
        return node

    def get_conversation(self, persona: Optional[Persona] = None) -> Thread:
        """Return the Thread of messages from the root down to this node.

        Args:
            persona: Persona to attach to the thread. When omitted, the root
                message's sender is used instead; that is a plain ``str``,
                not a ``Persona``.
        """
        conversation = [self.message]
        parent = self.parent
        while parent is not None:
            conversation.append(parent.message)
            parent = parent.parent

        # Collected leaf-to-root; Thread expects root-to-leaf order.
        conversation.reverse()
        thread_persona = persona if persona is not None else conversation[0].sender
        return Thread(conversation, thread_persona, self)

    def get_leaf_nodes(self) -> List["MessageNode"]:
        """Return all leaf nodes of this subtree, left to right (``[self]`` if childless)."""
        if not self.children:
            return [self]
        return [leaf for child in self.children for leaf in child.get_leaf_nodes()]

    def depth(self) -> int:
        """Number of ancestors of this node (the root has depth 0)."""
        depth = 0
        parent = self.parent
        while parent is not None:
            depth += 1
            parent = parent.parent
        return depth

    def __repr__(self) -> str:
        return f"<MessageNode | Message: {self.message} | Children: {len(self.children)}>"

    def __hash__(self) -> int:
        return self.message.__hash__()

    def __eq__(self, other: Any) -> bool:
        # Nodes are equal when their messages hash equal; parent and children are ignored.
        if not isinstance(other, MessageNode):
            return False
        return self.message.__hash__() == other.message.__hash__()

    def __len__(self) -> int:
        return len(self.children)

    def __getitem__(self, idx: int) -> "MessageNode":
        return self.children[idx]

    def __iter__(self):
        return iter(self.children)

    def __contains__(self, item: Message) -> bool:
        """True if ``item`` equals the message of one of the direct children."""
        return item in (child.message for child in self.children)

    def to_json(self) -> Dict[str, Any]:
        """Serialize this subtree recursively as ``{"message": ..., "children": [...]}``."""
        return {
            "message": self.message.to_json(),
            "children": [c.to_json() for c in self.children],
        }

    @classmethod
    def from_json(
        cls, data: Dict[str, Any], parent: Optional["MessageNode"] = None
    ) -> "MessageNode":
        """Recursively rebuild a subtree from the dict produced by ``to_json``.

        Every message, including the root, is rebuilt with
        ``Message.from_json``. Extra fields such as a root's ``score`` or
        ``question`` are not restored here.
        """
        # Recursively build the tree
        message = Message.from_json(data["message"])
        node = cls(message, parent)
        for child in data["children"]:
            node.add_child(message=cls.from_json(child, node))
        return node
