import hashlib
import json
import os
import re
from abc import ABC, abstractmethod
from datetime import datetime
from typing import Callable, Iterator

from fschat.conversation_game import Conversation

def generate_hash(text: str) -> str:
    """Return the first four hex digits of the MD5 digest of *text*.

    The text is UTF-8 encoded before hashing. Four hex digits give only a
    16-bit tag, so this is a short deterministic label, not a security hash.
    """
    digest = hashlib.md5(text.encode())
    return digest.hexdigest()[0:4]

def question_header_in_output_stream(s):
    """Return True when *s* contains a ``question <number>:`` header anywhere.

    Matching is case-insensitive because *s* is lower-cased first, so
    ``"Question 3:"`` and ``"QUESTION 3:"`` both count.
    """
    header = re.search(r'question \d+:', s.lower())
    return header is not None

def guess_in_output_stream(s):
    """Return True when *s* contains the phrase ``my guess of the word is:``.

    The comparison is case-insensitive because *s* is lower-cased first. The
    phrase has no regex metacharacters, so a plain substring test suffices.
    """
    marker = "my guess of the word is:"
    return marker in s.lower()


class BaseGame(ABC):
    """Shared state and LLM-turn plumbing for the conversational games.

    Subclasses must implement :meth:`is_llm_giving_answer`. Fields such as
    ``game_name``, ``game_rule``, ``system_prompt`` and ``first_user_message``
    start out empty here and are presumably filled in by subclasses
    (TODO confirm against the concrete games).
    """

    def __init__(self, max_round: int, save_path: str) -> None:
        self.max_round = max_round
        # JSON file that post_game_data_collection appends finished games to.
        self.save_path = save_path

        self.model_name = None
        self.game_name = ""
        self.game_rule = ""
        self.game_start = False
        self.show_message_history = False
        self.generate_next_llm_query = False
        self.next_llm_query_type = None
        # One of "ONGOING", "MODEL_WIN", "MODEL_LOSE", "MAX_ROUND_REACHED".
        self.game_status = "ONGOING"
        # Number of completed assistant turns (incremented by generation_response).
        self.round = 0

        self.game_session_id = None
        self.user_rating = None

        self.system_prompt = ""
        self.first_user_message = ""
        # Stored as "ground_truth_secret" for every game except Akinator.
        self.secret_system_message = None

    def initialize_game(self, conversation: Conversation) -> None:
        """Seed *conversation* with the game's opening user message."""
        conversation.append_message(conversation.roles[0], self.first_user_message)

    def _response_prefix(self, response_type: str) -> str:
        """Return the text that opens an assistant turn of kind *response_type*.

        Raises:
            NotImplementedError: if *response_type* is not a known kind.
        """
        if response_type == "question":
            return f"Question {self.round + 1}:"
        if response_type == "answer":
            return " "
        if response_type == "taboo_guess":
            return "my guess of the word is:"
        raise NotImplementedError(f"response type: {response_type} is not implemented.")

    def generation_response(
        self,
        type: str,
        stream_iter_fn: Callable,
        conversation: Conversation,
        model_name: str,
        model_api_info: dict,
        temperature: float = 0.0,
        top_p: float = 1.0,
        max_new_tokens: int = 1024,
        state=None,
        use_recommended_config: bool = False,
    ) -> Iterator[str]:
        """Stream one assistant turn from the model and record it in *conversation*.

        This is a generator. It yields incremental text chunks as they arrive.
        Only after the stream is exhausted does it write the full reply into
        the conversation and advance ``self.round``. A caller that stops early
        leaves a placeholder assistant message and an unchanged round counter.

        Args:
            type: ``"question"``, ``"answer"`` or ``"taboo_guess"``; selects the
                prefix that opens the assistant turn.
            stream_iter_fn: called with the conversation, model name, API info
                and sampling keywords. It must return an iterable of dicts with
                ``"error_code"`` and ``"text"`` keys. Each ``"text"`` is treated
                as the cumulative reply so far.
            conversation: conversation to append the assistant turn to.
            model_name: model identifier; ``"mistral"`` in it changes how the
                prefix is supplied.
            model_api_info: API settings passed through to *stream_iter_fn*.
                May hold a ``"recommended_config"`` dict.
            temperature, top_p, max_new_tokens, state: passed through to
                *stream_iter_fn*.
            use_recommended_config: if true and *model_api_info* has a
                ``"recommended_config"``, take temperature/top_p from it
                (defaulting to 0.0 / 1.0 when a key is absent).

        Yields:
            The newly generated text since the previous chunk (stripped).

        Raises:
            NotImplementedError: for an unknown *type* (on first iteration).
            AssertionError: if a streamed chunk has a non-zero ``error_code``.
        """
        if use_recommended_config:
            recommended_config = model_api_info.get("recommended_config")
            if recommended_config is not None:
                temperature = recommended_config.get("temperature", 0.0)
                top_p = recommended_config.get("top_p", 1.0)
        print(model_name)
        prefix = self._response_prefix(type)

        # Only mistral gets the prefix seeded into the assistant turn. Without
        # it, mistral does not follow the expected format. Other APIs get an
        # empty placeholder and rely on the system prompt to emit the header.
        # TODO: seed the prefix for other APIs too, without repeating it in the
        # displayed output.
        conversation.append_message(
            conversation.roles[1], prefix if "mistral" in model_name else None
        )

        stream_iter = stream_iter_fn(
            conversation,
            model_name,
            model_api_info,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            state=state,
        )

        # Starts empty so an empty stream records "" instead of crashing.
        output = ""
        for data in stream_iter:
            assert data["error_code"] == 0
            current = data["text"].strip()
            # Each chunk carries the whole reply so far; emit only the new tail.
            yield current[len(output):]
            output = current

        # Most APIs do not echo the prefix back, so add it to the stored history
        # unless the model already produced the header itself.
        if type == "question" and not question_header_in_output_stream(output):
            output = prefix + " " + output
        elif type == "taboo_guess" and not guess_in_output_stream(output):
            output = prefix + " " + output
        conversation.update_last_message(output)

        self.round += 1

    def update_conversation_with_user_choice(
        self, conversation: Conversation, user_choice: str
    ) -> None:
        """Append *user_choice* as a user turn (``conversation.roles[0]``)."""
        conversation.append_message(conversation.roles[0], user_choice)

    @abstractmethod
    def is_llm_giving_answer(self, conversation: Conversation) -> bool:
        """Return True if the model's last turn is a final answer (game-specific)."""

    def is_llm_triggering_termination(self, conversation: Conversation) -> bool:
        """Return True if the model's last turn should end the game.

        The base implementation never triggers termination; subclasses override.
        """
        return False

    def is_llm_illegal_input(self, input_text: str) -> bool:
        """Return True if *input_text* violates the game's rules.

        The base implementation accepts everything; subclasses override.
        """
        return False

    def set_end_game_status(self, user_response: str) -> None:
        """Set ``game_status`` to a terminal status.

        Raises:
            ValueError: if *user_response* is not ``"MODEL_WIN"``,
                ``"MODEL_LOSE"`` or ``"MAX_ROUND_REACHED"``.
        """
        if user_response not in ("MODEL_WIN", "MODEL_LOSE", "MAX_ROUND_REACHED"):
            raise ValueError("Invalid user_last_answer")
        self.game_status = user_response

    def _load_existing_records(self) -> list:
        """Return the records already stored at ``self.save_path``.

        A missing file yields ``[]``. An unreadable file also yields ``[]`` and
        redirects ``self.save_path`` to ``output_<hash>.json``, where the hash
        comes from the old file's base name. The unreadable file is therefore
        left untouched.
        """
        try:
            with open(self.save_path, "r", encoding="utf-8") as f:
                return json.load(f)
        except FileNotFoundError:
            return []
        except (json.JSONDecodeError, UnicodeDecodeError):
            base_name = os.path.basename(self.save_path).split(".")[0]
            self.save_path = f"output_{generate_hash(base_name)}.json"
            return []

    def _build_record(
        self, conversation: Conversation, game_session_id: str, user_note: str
    ) -> dict:
        """Return ``conversation.dict()`` augmented with this game's metadata."""
        record = conversation.dict()
        record["date"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        # system_prompt_index is never set in __init__; tolerate its absence.
        record["system_prompt_index"] = getattr(self, "system_prompt_index", None)

        # An explicitly passed session id wins; otherwise use the stored one.
        record["game_session_id"] = game_session_id or self.game_session_id
        record["user_rating"] = self.user_rating

        record["game_name"] = self.game_name
        if self.game_name == "Bluffing":
            record["system_question"] = self.system_question
        record["game_rule"] = self.game_rule
        record["game_status"] = self.game_status
        record["round"] = self.round

        # The user's note is stored under "correct_answer".
        if user_note != "":
            record["correct_answer"] = user_note

        if self.game_name != "Akinator":
            record["ground_truth_secret"] = self.secret_system_message
        return record

    def post_game_data_collection(
        self, conversation: Conversation, game_session_id: str = "", user_note: str = ""
    ) -> None:
        """Append this game's record to the JSON list stored at ``self.save_path``.

        Saving is best-effort. Write or serialization failures are printed
        rather than raised. If the existing file cannot be parsed, the data is
        written to a new ``output_<hash>.json`` file instead.

        Args:
            conversation: the finished conversation to save.
            game_session_id: session id to record; falls back to
                ``self.game_session_id`` when empty.
            user_note: optional note, stored as ``"correct_answer"`` when non-empty.
        """
        existing_data = self._load_existing_records()
        existing_data.append(self._build_record(conversation, game_session_id, user_note))

        print("Saving data to: ", self.save_path)
        try:
            # Serialize before opening so a bad record cannot truncate the file.
            payload = json.dumps(existing_data, indent=4, ensure_ascii=False)
            with open(self.save_path, "w", encoding="utf-8") as f:
                f.write(payload)
        except (OSError, TypeError, ValueError) as err:
            print(f"Failed to save data to {self.save_path}: {err}")

    def reach_max_round(self) -> bool:
        """Return True once ``self.round`` has reached ``self.max_round``."""
        return self.round >= self.max_round
