import torch
import torch.nn as nn
from transformers.models.llama.modeling_llama import LlamaForSequenceClassification
from sklearn.tree import DecisionTreeClassifier
import os
import pickle
import json
from huggingface_hub import hf_hub_download
from typing import List, Dict, Union, Optional
import numpy as np

def convert_to_chat_format(prompt, response=None):
    """Convert a raw prompt (and optional response) into a list of chat messages.

    HelpSteer2 prompts may encode a multi-turn conversation by joining turns
    with the special token ``<extra_id_1>``. The segment before the first
    separator is the opening user message. Every later segment starts with a
    role line followed by a newline and the message content. A role of
    ``Assistant`` maps to ``"assistant"``; any other role maps to ``"user"``.

    Args:
        prompt: Raw prompt text, optionally containing ``<extra_id_1>`` separators.
        response: Optional assistant reply appended as the final message.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts.
    """
    if "<extra_id_1>" not in prompt:
        conversation = [{"role": "user", "content": prompt}]
    else:
        first_turn, *later_turns = prompt.split("<extra_id_1>")
        conversation = [{"role": "user", "content": first_turn}]
        for turn in later_turns:
            # partition() yields empty content for a turn with no newline,
            # where split("\n", 1)[1] would raise IndexError.
            role, _, content = turn.partition("\n")
            conversation.append({
                # strip() so e.g. a CRLF role line "Assistant\r" is still recognized.
                "role": "assistant" if role.strip() == "Assistant" else "user",
                "content": content,
            })
    if response is not None:
        conversation.append({"role": "assistant", "content": response})
    return conversation

def process_conversation(conversation):
    """Return the conversation with trailing newlines stripped from every message.

    Args:
        conversation: Iterable of message dicts, each with a ``"content"`` string.

    Returns:
        A new list of message dicts. Each dict is a shallow copy of the input
        message (other keys preserved) whose ``"content"`` has trailing
        ``"\\n"`` characters removed.
    """
    # Build new dicts instead of assigning into the input ones, so callers'
    # message objects (e.g. a prompt list reused across calls) are not modified.
    return [
        {**message, "content": message["content"].rstrip("\n")}
        for message in conversation
    ]

from collections import namedtuple

# Lightweight output container returned by the reward model's forward():
# exposes the scores under a `.logits` attribute (presumably so callers that
# expect transformers-style outputs can read them the same way).
Result = namedtuple("Result", "logits")

class LlamaForDecisionTreeRewardModel(LlamaForSequenceClassification):
    """Llama sequence classifier scoring several attributes, plus a decision tree
    that maps per-attribute reward differences to a pairwise preference.

    ``forward`` returns a single attribute column (index 1, ``'correctness'``
    in the default attribute order) as ``Result.logits``. ``compare`` scores
    two responses on every attribute and lets the decision tree predict the
    preference.
    """

    def __init__(self, config):
        super().__init__(config)
        # Re-create the score head with a bias term.
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=True)
        # Decision tree over reward differences; set by load_decision_tree().
        self.tree = None
        # Default attribute names (from HelpSteer2); load_decision_tree() overwrites them.
        self.attributes = ['helpfulness', 'correctness', 'coherence', 'complexity', 'verbosity']
        print("Initialized LlamaForDecisionTreeRewardModel")

    def load_decision_tree(self, repo_id, filename="decision_tree.pkl"):
        """Download the decision tree and attribute names from a Hugging Face Hub repo.

        Args:
            repo_id: Hub repository containing ``filename`` and ``config.json``.
            filename: Name of the pickled ``DecisionTreeClassifier`` in the repo.

        Raises:
            AssertionError: If the unpickled object is not a ``DecisionTreeClassifier``.
            KeyError: If the repo's ``config.json`` has no ``label2id`` entry.
        """
        tree_path = hf_hub_download(repo_id=repo_id, filename=filename)
        with open(tree_path, "rb") as f:
            # SECURITY: unpickling runs arbitrary code embedded in the file, and
            # the isinstance check below happens only afterwards. Load trees
            # only from repositories you trust.
            tree = pickle.load(f)
        assert isinstance(tree, DecisionTreeClassifier), f"The tree is not a DecisionTreeClassifier. It is a {type(tree)}"
        # Assign only after validation so a bad file does not leave self.tree set.
        self.tree = tree

        config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
        with open(config_path, "r", encoding="utf-8") as f:
            hub_config = json.load(f)
        label2id_map = hub_config["label2id"]
        # Order attribute names by label id (presumably the score head's column order).
        self.attributes = sorted(label2id_map, key=label2id_map.get)

    def _last_token_embedding(self, input_ids, attention_mask=None, position_ids=None):
        """Return the final-layer hidden state at each row's last non-padding position."""
        # Call the parent forward directly: this class's forward() returns a
        # Result that carries no hidden states.
        hidden = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_hidden_states=True,
        ).hidden_states[-1]
        if attention_mask is None:
            return hidden[:, -1]
        # Highest index whose mask is non-zero. This equals seq_len - 1 for
        # unpadded or left-padded rows and skips trailing pads under right padding.
        positions = torch.arange(attention_mask.shape[-1], device=attention_mask.device)
        last_idx = (attention_mask.ne(0).long() * positions).amax(dim=-1)
        batch_idx = torch.arange(hidden.shape[0], device=hidden.device)
        return hidden[batch_idx, last_idx.to(hidden.device)]

    @torch.no_grad()
    def compare(self, prompt: Union[str, List[Dict[str, str]]], response_1: str, response_2: str, tokenizer, device):
        """Score two responses to the same prompt and predict which is preferred.

        Args:
            prompt: Raw prompt string (converted with ``convert_to_chat_format``),
                or a list of chat messages whose last message is from the user.
            response_1: First candidate assistant response.
            response_2: Second candidate assistant response.
            tokenizer: Tokenizer providing ``apply_chat_template``.
            device: Device the tokenized inputs are moved to.

        Returns:
            Dict with ``"preference"`` (the tree's prediction for
            ``rewards_2 - rewards_1``), ``"rewards"`` (the two responses'
            per-attribute rewards concatenated row-wise) and ``"attributes"``.

        Raises:
            ValueError: If ``prompt`` is neither a string nor a list.
        """
        assert self.tree is not None, "The decision tree is not loaded. Please call load_decision_tree(repo_id, filename) first."
        if isinstance(prompt, str):
            conversation = convert_to_chat_format(prompt)
        elif isinstance(prompt, list):
            # Copy each message so the caller's dicts are never modified.
            conversation = [dict(message) for message in prompt]
        else:
            raise ValueError(f"The prompt must be a string or a list of dictionaries, but got {type(prompt)}")
        assert len(conversation) >= 1, "The conversation must have at least one message (as prompt)"
        assert conversation[-1]["role"] == "user", "The last message in the conversation must be from the user"

        # Apply the linear head in float32 on CPU, separately from the model dtype.
        weight = self.score.weight.float().cpu().numpy()
        bias = self.score.bias.float().cpu().numpy()
        rewards = []
        for candidate in (response_1, response_2):
            messages = process_conversation(conversation + [{"role": "assistant", "content": candidate}])
            input_ids = tokenizer.apply_chat_template(messages, tokenize=True, return_tensors="pt").to(device)
            embedding = self._last_token_embedding(input_ids).float().cpu().numpy()
            rewards.append(embedding @ weight.T + bias)
        rewards_1, rewards_2 = rewards
        rewards_diff = rewards_2 - rewards_1
        return {
            "preference": self.tree.predict(rewards_diff)[0],
            "rewards": np.concatenate([rewards_1, rewards_2]),
            "attributes": self.attributes,
        }

    @torch.no_grad()
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ):
        """Score already-tokenized inputs.

        Args:
            input_ids: Token ids passed to the parent model.
            attention_mask: Optional mask. When given, each row is scored at
                its last non-zero mask position rather than the final column.
            position_ids: Optional position ids passed to the parent model.

        Returns:
            ``Result`` whose ``logits`` is column 1 of the score head's output
            (``'correctness'`` in the default attribute order), shape ``(batch, 1)``.
        """
        embedding = self._last_token_embedding(input_ids, attention_mask, position_ids)
        rewards = self.score(embedding)
        return Result(logits=rewards[:, 1:2])
