import hashlib
import json
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, TextIO, Callable

from src.configs import ModelConfig
from src.personas import Persona
from src.bias_pipeline.questionaires.questionaire import Question
from src.bias_pipeline.scoring import extract_scores_from_annotation
from src.bias_pipeline.data_types.data_types import (
    Annotation,
    Message,
    MessageNode,
    RootMessage,
    Thread,
)


@dataclass
class Conversation:
    """
    Data model to represent a conversation consisting of a tree of messages between a fixed persona and an assistant.

    Equality and hashing are both derived from the SHA-1 digest of the string form of
    ``to_json()``, so two conversations compare equal exactly when they serialize identically.
    """

    persona: Persona
    model: ModelConfig  # Assistant model configuration
    RootMessage: RootMessage  # Message the conversation starts from
    messages: List[MessageNode]  # Top-level message nodes of the conversation

    def to_json(self) -> Dict[str, Any]:
        """Serialize the conversation by delegating to each component's ``to_json``."""
        return {
            "persona": self.persona.to_json(),
            "model": self.model.to_json(),
            "RootMessage": self.RootMessage.to_json(),
            "messages": [node.to_json() for node in self.messages],
        }

    @classmethod
    def from_json(cls, data: Dict[str, Any]) -> "Conversation":
        """Rebuild a conversation from the dictionary layout produced by ``to_json``."""
        persona = Persona.from_json(data["persona"])
        model = ModelConfig.from_json(data["model"])
        messages = [MessageNode.from_json(entry) for entry in data["messages"]]
        root_message = RootMessage.from_json(data["RootMessage"])

        return cls(persona, model, root_message, messages)

    @staticmethod
    def _sender_of(message: MessageNode) -> str:
        """Return the sender of a node, falling back to the sender of its wrapped message.

        The fallback is only looked up when the node itself has no ``sender`` attribute,
        so nodes that expose ``sender`` directly do not need a ``message`` attribute.
        """
        try:
            return message.sender
        except AttributeError:
            return message.message.sender

    def _format_messages(self) -> str:
        """Render every top-level message as a ``"<sender>: <message>"`` line."""
        return "".join(
            f"{self._sender_of(message)}: {message!s}\n" for message in self.messages
        )

    def __repr__(self):
        """Return a multi-line, human-readable dump of the conversation."""
        header = f"<Conversation {self.persona.name} | Model: {self.model.name} | Comments: {len(self.messages)}>\n"
        # The separator is as wide as the header line including its trailing newline.
        separator = "#" * len(header)
        return (
            header
            + separator
            + "\n"
            + f"Root message: {self.RootMessage}\n\n"
            + separator
            + "\n\n"
            + self._format_messages()
        )

    def format_conversation(self) -> str:
        """Return the persona name followed by one line per top-level message."""
        return f"Persona: {self.persona.name}\n" + self._format_messages()

    def __len__(self) -> int:
        return len(self.messages)

    def __iter__(self):
        return iter(self.messages)

    def __add__(self, other: "Conversation") -> "Conversation":
        """Concatenate the messages of two conversations that share persona, model and root.

        Raises:
            AssertionError: If the personas or models differ, or if ``other`` has a truthy
                root message that differs from this conversation's root message.
        """
        assert self.persona == other.persona, "Personas must be the same"
        assert self.model == other.model, "Models must be the same"
        assert self.RootMessage == other.RootMessage or not other.RootMessage, (
            "Root messages must be the same"
        )

        return Conversation(
            persona=self.persona,
            model=self.model,
            # The root message is a required field; the combined conversation keeps ours.
            RootMessage=self.RootMessage,
            messages=self.messages + other.messages,
        )

    def __hash__(self) -> int:
        return int(hashlib.sha1(str(self.to_json()).encode("utf-8")).hexdigest(), 16)

    def __eq__(self, value):
        if not isinstance(value, Conversation):
            return False
        return self.__hash__() == value.__hash__()

    def get_leaf_nodes(self) -> List[MessageNode]:
        """Collect the leaf nodes reported by every top-level message node."""
        return [leaf for message in self.messages for leaf in message.get_leaf_nodes()]

    def get_threads(self) -> List[Thread]:
        """Puts the conversation into individual threads that can be used for chat generation.

        One thread is built per leaf node via ``get_conversation()``, and this
        conversation's persona is assigned to each of them.

        Returns:
            List[Thread]: List of threads
        """
        threads = [leaf.get_conversation() for leaf in self.get_leaf_nodes()]
        for thread in threads:
            thread.persona = self.persona
        return threads


@dataclass
class ConversationBatch:
    """
    Data model to represent a batch of conversations.

    All conversations in a batch must share the same root message. Annotations are stored
    as ``{turn: {subset_key: {annotation.model_id: Annotation}}}``.
    """

    root_message: RootMessage
    conversations: List[Conversation]
    # turn index -> subset key -> annotation.model_id -> annotation
    annotations: Optional[Dict[int, Dict[str, Dict[str, Annotation]]]] = None
    var_attributes: Optional[Dict[str, Any]] = None
    num_turns: int = 0

    def __init__(
        self,
        root_message: RootMessage,
        conversations: List[Conversation],
        var_attributes: Optional[Dict[str, Any]] = None,
        annotations: Optional[Dict[int, Dict[str, Dict[str, Annotation]]]] = None,
        num_turns: int = 0,
    ):
        """Create a batch.

        Args:
            root_message: Root message of the batch.
            conversations: Conversations in the batch; may be empty.
            var_attributes: Optional free-form attributes attached to the batch.
            annotations: Optional existing annotations; ``None`` becomes an empty dict.
            num_turns: Current turn index of the batch.

        Raises:
            AssertionError: If the conversations do not all share one root message.
        """
        self.root_message = root_message
        self.conversations = conversations
        self.var_attributes = var_attributes
        self.annotations = annotations if annotations is not None else {}
        self.num_turns = num_turns

        # An empty batch has nothing to compare, so only check when there are conversations.
        if conversations:
            sample_root_message = conversations[0].RootMessage
            assert all(c.RootMessage == sample_root_message for c in conversations), (
                "All Conversations in a batch must have the same root message"
            )

    def annotate_curr_state(self, annotation: Annotation, subset_key: str = "all") -> None:
        """
        Annotate the current state of the conversation batch.

        The annotation is stored under the current turn (``num_turns``), the given subset
        key and ``annotation.model_id``. An existing entry for that triple is overwritten
        after printing a warning.

        Args:
            annotation: The annotation to add
            subset_key: A key to identify the subset of the conversation this annotation belongs to. Usually identifies the assistant model name.
        """
        current_msg_index = self.num_turns
        assert current_msg_index >= 0, "No messages in the conversation"

        if self.annotations is None:
            self.annotations = {}

        subset_annotations = self.annotations.setdefault(current_msg_index, {}).setdefault(
            subset_key, {}
        )

        if annotation.model_id in subset_annotations:
            print(
                f"Warning: Overwriting existing annotation for model {annotation.model_id} at turn {current_msg_index} with subset {subset_key}"
            )
        subset_annotations[annotation.model_id] = annotation

    def get_conversations(
        self, group_index: Optional[str] = None
    ) -> Dict[str, List[Conversation]] | List[Conversation]:
        """
        Returns all conversations in the batch.
        If group_index is provided ('persona' or 'model'), returns a dict mapping the name
        of that attribute to the conversations sharing it, in batch order.
        """
        if group_index is None:
            return self.conversations

        assert group_index in ["persona", "model"], (
            "group_index must be either 'persona' or 'model'"
        )

        grouped: Dict[str, List[Conversation]] = {}
        for conversation in self.conversations:
            # group_index names the attribute (persona / model) whose ``name`` is the key.
            group_name = getattr(conversation, group_index).name
            grouped.setdefault(group_name, []).append(conversation)
        return grouped

    def get_combined_model_names(self) -> str:
        """
        Returns a string of all distinct model names in the batch, sorted and joined by '#'.
        This is useful for identifying the batch based on the models used.
        """
        return "#".join(sorted({c.model.name for c in self.conversations}))

    def _mean_fitness(
        self,
        judge_annotations: Dict[str, Annotation],
        fitness_function: Callable[[Dict[str, float]], float],
    ) -> float:
        """Average ``fitness_function`` over every score entry of the given annotations.

        The result of ``extract_scores_from_annotation`` is treated as a mapping from score
        name to a sequence of values; entry ``i`` of every sequence forms one score dict.
        Returns 0.0 when no score could be extracted.
        """
        # The attribute filter only depends on the batch, so compute it once.
        bias_attributes = None
        if self.var_attributes and "type" in self.var_attributes:
            bias_attributes = self.var_attributes["type"]
            if isinstance(bias_attributes, str):
                bias_attributes = [bias_attributes]

        scores = []
        for model_id, annotation in judge_annotations.items():
            extracted_scores = extract_scores_from_annotation(annotation, bias_attributes)

            if not extracted_scores:
                print(f"Skipping annotation {model_id} - No scores extracted")
                continue

            # The first entry's length decides how many score dicts are built.
            num_entries = len(next(iter(extracted_scores.values())))
            if num_entries == 0:
                print(f"Skipping annotation {model_id} - No scores found")
                continue

            for idx in range(num_entries):
                scores.append(
                    fitness_function({key: values[idx] for key, values in extracted_scores.items()})
                )

        if not scores:
            return 0.0

        return sum(scores) / len(scores)

    def compute_current_fitness(
        self, fitness_function: Callable[[Dict[str, float]], float], model_individual: bool = False
    ) -> Dict[str, float] | float:
        """
        Computes the current fitness of the conversation batch based on the annotations.

        Args:
            fitness_function: Maps one dict of scores to a fitness value.
            model_individual: If True, compute one value per subset key of the current turn.
                Otherwise only the subset keyed by ``get_combined_model_names()`` is used.

        Returns:
            0.0 if there are no annotations for the current turn. Otherwise the average
            fitness as a float, or a dict of subset key -> average fitness when
            ``model_individual`` is True.

        Raises:
            AssertionError: If ``model_individual`` is False and the combined model names
                key is missing from the current turn's annotations.
        """
        if not self.annotations:
            return 0.0

        latest_annotations = self.annotations.get(self.num_turns, {})
        if not latest_annotations:
            print(f"No annotations found for turn {self.num_turns}")
            return 0.0

        if model_individual:
            return {
                model_key: self._mean_fitness(judge_annotations, fitness_function)
                for model_key, judge_annotations in latest_annotations.items()
            }

        combined_key = self.get_combined_model_names()
        assert combined_key in latest_annotations, (
            f"Combined model names {combined_key} not found in annotations for turn {self.num_turns}"
        )
        judge_annotations = latest_annotations[combined_key]
        num_model_annotations = len(judge_annotations)
        if num_model_annotations > 1:
            print(
                f"Multiple annotations found for turn {self.num_turns}: {num_model_annotations} models"
            )

        return self._mean_fitness(judge_annotations, fitness_function)

    def to_json(self) -> Dict[str, Any]:
        """Serialize the batch; ``annotations`` is ``None`` when there are none."""
        serialized_annotations = None
        if self.annotations:
            serialized_annotations = {
                turn: {
                    model_key: {
                        judge_model_key: annotation.to_json()
                        for judge_model_key, annotation in annotation_dict.items()
                    }
                    for model_key, annotation_dict in turn_annotations.items()
                }
                for turn, turn_annotations in self.annotations.items()
            }

        return {
            "root_message": self.root_message.to_json(),
            "conversations": [c.to_json() for c in self.conversations],
            "var_attributes": self.var_attributes,
            "annotations": serialized_annotations,
            "num_turns": self.num_turns,
        }

    def __repr__(self) -> str:
        return f"<ConversationBatch | Conversations: {len(self.conversations)}>"

    def format_conversations(self) -> str:
        """Return the initial query followed by every conversation, separated by '###'."""
        sections = [f"Initial Query: {self.root_message.question.example}\n\n"]
        sections.extend(
            "###" + "\n" + conversation.format_conversation() + "\n\n"
            for conversation in self.conversations
        )
        return "".join(sections)

    def get_state(self) -> Tuple[int, int]:
        """Get the current state of the conversation batch.

        Returns:
            A tuple ``(depth, curr_step)`` where ``depth`` is ``num_turns`` and ``curr_step``
            is 1 if annotations exist for that turn and 0 otherwise.

        Raises:
            AssertionError: If any thread of any conversation has a length other than
                ``num_turns``.
        """
        depth = self.num_turns
        # Every thread is expected to be exactly as long as the recorded turn count.
        depths = [
            len(thread)
            for conversation in self.conversations
            for thread in conversation.get_threads()
        ]
        if depths:
            assert all(d == depth for d in depths), "All messages must be at the same depth"

        has_annotations = self.annotations is not None and depth in self.annotations
        return depth, 1 if has_annotations else 0

    @classmethod
    def from_json(cls, data: Dict[str, Any]) -> "ConversationBatch":
        """Rebuild a batch from the dictionary layout produced by ``to_json``.

        Turn keys of ``annotations`` are converted back to ``int``. ``var_attributes`` and
        ``num_turns`` are optional and default to ``None`` and ``0``.
        """
        try:
            root_message = RootMessage.from_json(data["root_message"])
        except Exception:
            # NOTE(review): this fallback repeats the exact call that just failed, so it
            # presumably re-raises; the intended legacy parser is not visible here - confirm.
            replacement = RootMessage.from_json(data["root_message"])
            root_message = Question(
                superdomain="dummy", domain="dummy", topic="dummy", example=replacement.text
            )

        conversations = [Conversation.from_json(c) for c in data["conversations"]]

        annotations = None
        raw_annotations = data.get("annotations")
        if raw_annotations is not None:
            annotations = {}
            for turn_key, turn_data in raw_annotations.items():
                # JSON object keys are strings, so the turn index is converted back to int.
                annotations[int(turn_key)] = {
                    evaled_model_key: {
                        judge_model_key: Annotation.from_json(annotation_data)
                        for judge_model_key, annotation_data in model_annotation_data.items()
                    }
                    for evaled_model_key, model_annotation_data in turn_data.items()
                }

        var_attributes = data.get("var_attributes", None)
        num_turns = data.get("num_turns", 0)
        return cls(root_message, conversations, var_attributes, annotations, num_turns)

    def to_file(self, file: TextIO) -> None:
        """Write the batch as a single JSON line to ``file`` and flush it."""
        json.dump(self.to_json(), file)
        file.write("\n")
        file.flush()

    def __len__(self) -> int:
        return len(self.conversations)

    def __getitem__(self, idx: int) -> Conversation:
        return self.conversations[idx]

    def __iter__(self):
        return iter(self.conversations)

    def __add__(self, other: "ConversationBatch") -> "ConversationBatch":
        """Return a new batch holding the conversations of both operands.

        The result keeps this batch's root message, ``var_attributes`` and ``num_turns``.
        Annotations are not carried over. The constructor still checks that all combined
        conversations share one root message.
        """
        return ConversationBatch(
            self.root_message,
            self.conversations + other.conversations,
            var_attributes=self.var_attributes,
            num_turns=self.num_turns,
        )

    def get_id(self, include_conv_state: bool = False) -> str:
        """Return an identifier built from the root message id and the conversations.

        Without ``include_conv_state`` the id lists the persona ids; with it, the id also
        includes ``num_turns`` and the hash of every conversation.
        """
        if not include_conv_state:
            return f"{self.root_message.id}-{[c.persona.id for c in self.conversations]}"
        return f"{self.root_message.id}-{self.num_turns}-{[c.__hash__() for c in self.conversations]}"

    def __hash__(self) -> int:
        return int(hashlib.sha1(str(self.to_json()).encode("utf-8")).hexdigest(), 16)

    def __eq__(self, value):
        if not isinstance(value, ConversationBatch):
            return False
        return self.__hash__() == value.__hash__()


def load_conversations(file_path: str) -> List[ConversationBatch]:
    """Load conversation batches from a file containing one JSON document per line.

    Blank or whitespace-only lines are skipped.

    Args:
        file_path: Path of the file to read.

    Returns:
        The batches in file order; an empty list if the file holds no JSON lines.
    """
    batched_conversations = []

    with open(file_path, "r") as f:
        # Stream line by line instead of loading the whole file up front.
        for json_str in f:
            if not json_str.strip():
                continue
            batched_conversations.append(ConversationBatch.from_json(json.loads(json_str)))

    return batched_conversations
