import os
import io
import random
import re
import time
from collections import defaultdict
from typing import Any, Dict, Generator, List, Literal, Optional

import numpy as np
import base64
from gymnasium import spaces
from loguru import logger

from llama_index.core.schema import ImageNode
from llama_index.core.llms import ChatMessage, ChatResponse
from rich import print
from PIL import Image
from llama_index.core.base.llms.types import CompletionResponse

from .config import (
    INDEX_TO_MOVE,
    META_INSTRUCTIONS,
    META_INSTRUCTIONS_WITH_LOWER,
    MOVES,
    NB_FRAME_WAIT,
    X_SIZE,
    Y_SIZE,
)
from .socket_config import SocketConfig
from .observer import detect_position_from_color
from .llm import get_client, get_client_multimodal
import abc


class Robot(metaclass=abc.ABCMeta):
    """
    Base class for an LLM-driven Street Fighter III agent.

    Subclasses implement `call_llm` (query a language model and return a
    stream of responses) and `observe` (ingest the latest environment state).
    The base class turns the model's bullet-point answer into a queue of
    button presses (`next_steps`) that `act` consumes one frame at a time.
    """

    observations: List[Optional[Dict[str, Any]]]  # memory
    next_steps: List[int]  # action plan (queue of button indices)
    actions: dict  # actions of the agents during a step of the game
    # actions of the agents during the previous steps of the game
    previous_actions: Dict[str, List[int]]
    reward: float  # reward of the agent

    action_space: spaces.Space
    character: Optional[str] = None  # character name
    side: int  # side of the stage where playing: 0 = left, 1 = right
    current_direction: Literal["Left", "Right"]  # current direction facing
    sleepy: Optional[bool] = False  # if the robot is sleepy (never acts)
    only_punch: Optional[bool] = False  # if the robot only punches
    temperature: float = 0.7  # temperature of the language model

    model: str  # model of the robot

    super_bar_own: int
    player_nb: int  # player number

    def __init__(
        self,
        action_space: spaces.Space,
        character: str,
        side: int,
        character_color: list,
        ennemy_color: list,
        sleepy: bool = False,
        only_punch: bool = False,
        temperature: float = 0.7,
        model: str = "mistral:mistral-large-latest",
        player_nb: int = 0,  # 0 means not specified
        delay: float = 0.0,
        local_model=None,
        tokenizer=None,
        serving_method=None,
        socket_config: Optional[SocketConfig] = None,
    ):
        """
        Store the robot configuration and initialise its memory.

        Args:
            action_space: gymnasium action space of the agent.
            character: character name, inserted into the LLM prompt.
            side: 0 = left side of the stage (initially facing right),
                1 = right side (initially facing left).
            character_color: color used to locate our character in a frame.
            ennemy_color: color used to locate the opponent in a frame.
            sleepy: if True, `act` always returns 0 (no move).
            only_punch: if True, `act` queues a Hadouken on every call.
            temperature: sampling temperature for the language model.
            model: model identifier passed to the LLM client factory.
            player_nb: player number (1 or 2); 0 means not specified.
            delay: stored as-is; not used in this class.
            local_model / tokenizer / serving_method / socket_config: used by
                subclasses that run a local model or talk to a socket server.
        """
        self.action_space = action_space
        self.character = character
        # Initial facing direction; observe() updates it from detected positions.
        if side == 0:
            self.current_direction = "Right"
        elif side == 1:
            self.current_direction = "Left"

        self.observations = []
        self.next_steps = []
        self.character_color = character_color
        self.ennemy_color = ennemy_color
        self.side = side
        self.sleepy = sleepy
        self.only_punch = only_punch
        self.temperature = temperature
        self.model = model
        self.previous_actions = defaultdict(list)
        self.actions = {}
        # Neutral reward until the first observe() call, so prompt builders
        # that read self.reward never hit an AttributeError.
        self.reward = 0.0
        self.player_nb = player_nb
        self.delay = delay
        self.local_model = local_model
        self.tokenizer = tokenizer
        self.serving_method = serving_method
        self.socket_config = socket_config

    def act(self) -> int:
        """
        At each game frame, execute (pop) the first action of the plan.

        An action is an integer from 0 to 18, where 0 is no action.
        See the MOVES dictionary for the mapping of actions to moves.

        Returns:
            The button index to press this frame (0 when there is no plan or
            the robot is sleepy).
        """
        if not self.next_steps:
            return 0  # No move

        if self.sleepy:
            return 0

        if self.only_punch:
            # Do a Hadouken.
            # NOTE(review): this appends 4 buttons while popping only 1 per
            # frame, so the queue keeps growing and plan() never replans.
            # Confirm this is intended.
            if self.current_direction == "Right":
                self.next_steps.extend(
                    [
                        MOVES["Down"],
                        MOVES["Right+Down"],
                        MOVES["Right"],
                        MOVES["High Punch"],
                    ]
                )
            elif self.current_direction == "Left":
                self.next_steps.extend(
                    [
                        MOVES["Down"],
                        MOVES["Down+Left"],
                        MOVES["Left"],
                        MOVES["High Punch"],
                    ]
                )

        return self.next_steps.pop(0)

    def plan(self) -> None:
        """
        Plan the next steps by asking the LLM for a list of moves.

        In SF3, moves are based on combos, which are lists of actions that
        must be executed in sequence. Each move name returned by the LLM is
        expanded into its button sequence for the current facing direction,
        followed by NB_FRAME_WAIT idle frames.

        Moves of Ken
        https://www.eventhubs.com/guides/2009/may/11/ken-street-fighter-3-third-strike-character-guide/

        Moves of Ryu
        https://www.eventhubs.com/guides/2008/may/09/ryu-street-fighter-3-third-strike-character-guide/
        """

        # If we already have a next step, we don't need to plan
        if len(self.next_steps) > 0:
            return
        # Call the LLM to get the next steps
        start = time.time()
        next_steps_from_llm = self.get_moves_from_llm()
        end = time.time()
        logger.debug(f"Time to get moves from LLM: {end - start}s")
        direction = self.current_direction.lower()
        next_buttons_to_press = [
            button
            for combo in next_steps_from_llm
            for button in META_INSTRUCTIONS_WITH_LOWER[combo][direction]
            # We add a wait time after each combo
            + [0] * NB_FRAME_WAIT
        ]
        self.next_steps.extend(next_buttons_to_press)

    def _parse_moves(self, llm_response: str) -> List[str]:
        """
        Extract the recognised move names from a bullet-point LLM answer.

        Args:
            llm_response: full text of the model answer ("- <move>" lines).

        Returns:
            Lower-cased move names that are keys of META_INSTRUCTIONS_WITH_LOWER,
            in order of appearance.
        """
        valid_moves: List[str] = []
        invalid_moves: List[str] = []

        # The response is a bullet point list of moves. Use regex
        for move in re.findall(r"- ([\w ]+)", llm_response):
            cleaned_move_name = move.strip().lower()
            if cleaned_move_name in META_INSTRUCTIONS_WITH_LOWER:
                if self.player_nb == 1:
                    print(f"[red] Player {self.player_nb} move: {cleaned_move_name}")
                    logger.info(f"[red] Player {self.player_nb} move: {cleaned_move_name}")
                elif self.player_nb == 2:
                    print(f"[green] Player {self.player_nb} move: {cleaned_move_name}")
                    logger.info(f"[green] Player {self.player_nb} move: {cleaned_move_name}")
                valid_moves.append(cleaned_move_name)
            else:
                logger.debug(f"Invalid completion: {move}")
                logger.info(f"Invalid completion: {move}")
                logger.debug(f"Cleaned move name: {cleaned_move_name}")
                invalid_moves.append(move)

        if len(invalid_moves) > 1:
            logger.warning(f"Many invalid moves: {invalid_moves}")
            logger.info(f"Many invalid moves: {invalid_moves}")

        return valid_moves

    def get_moves_from_llm(
        self,
        max_attempts: int = 1,
    ) -> List[str]:
        """
        Get a list of moves from the language model.

        Args:
            max_attempts: how many times to query the model when an answer
                contains no valid move. Defaults to 1 (a single query); plan()
                will simply ask again on its next call.

        Returns:
            Lower-cased move names usable as keys of
            META_INSTRUCTIONS_WITH_LOWER; may be empty.
        """
        # If we are in the test environment, we don't want to call the LLM
        if os.getenv("DISABLE_LLM", "False") == "True":
            # Return a random move *name*: plan() looks every entry up in
            # META_INSTRUCTIONS_WITH_LOWER, so a raw button index would raise
            # a KeyError there.
            logger.debug("DISABLE_LLM is True, returning a random move")
            return [random.choice(list(META_INSTRUCTIONS_WITH_LOWER.keys()))]

        valid_moves: List[str] = []
        for _ in range(max(1, max_attempts)):
            llm_stream = self.call_llm()

            # Consume the whole stream, then parse once. Re-parsing the
            # accumulated text after every chunk logged each move many times
            # and flagged partially-streamed names as invalid, while the final
            # result was the same as parsing the complete answer.
            # `delta` may be None on some chunks.
            llm_response = "".join(r.delta or "" for r in llm_stream)
            valid_moves = self._parse_moves(llm_response)
            if valid_moves:
                break

        logger.debug(f"Next moves: {valid_moves}")
        return valid_moves

    @abc.abstractmethod
    def call_llm(
        self,
        max_tokens: int = 100,
        top_p: float = 1.0,
    ) -> (
        Generator[ChatResponse, None, None] | Generator[CompletionResponse, None, None]
    ):
        """
        Make an API call to the language model.

        Edit this method to change the behavior of the robot!

        This should return a streaming response: an iterable of ChatResponse
        (or CompletionResponse) objects exposing a `delta` text attribute.
        Look into Llamaindex and make sure streaming is on.
        """
        raise NotImplementedError("call_llm method must be implemented")

    @abc.abstractmethod
    def observe(self, observation: dict, actions: dict, reward: float):
        """
        The robot will observe the environment by calling this method.

        The latest observations are at the end of the list.
        """
        # By default, we don't observe anything.
        pass


class TextRobot(Robot):
    """Robot that describes the game state to a text-only chat LLM."""

    def observe(self, observation: dict, actions: dict, reward: float):
        """
        Record an observation, the agents' actions and the reward.

        Adds "character_position" and "ennemy_position" keys to `observation`
        (mutated in place), keeps the last 10 observations and the last 10
        non-zero actions per agent, and updates the facing direction.
        """

        # detect the position of characters and ennemy based on color
        observation["character_position"] = detect_position_from_color(
            observation, self.character_color
        )
        observation["ennemy_position"] = detect_position_from_color(
            observation, self.ennemy_color
        )

        self.observations.append(observation)
        # we delete the oldest observation if we have more than 10 observations
        if len(self.observations) > 10:
            self.observations.pop(0)

        self.reward = reward

        # Only non-zero (non-idle) actions are remembered.
        if actions.get("agent_0") is not None and actions.get("agent_0") != 0:
            self.previous_actions["agent_0"].append(actions["agent_0"])
        if actions.get("agent_1") is not None and actions.get("agent_1") != 0:
            self.previous_actions["agent_1"].append(actions["agent_1"])

        for key in actions:
            if len(self.previous_actions[key]) > 10:
                self.previous_actions[key].pop(0)

        # Keep track of the current direction by checking the position of the character
        # and the ennemy
        character_position = observation.get("character_position")
        ennemy_position = observation.get("ennemy_position")
        if (
            character_position is not None
            and ennemy_position is not None
            and len(character_position) == 2
            and len(ennemy_position) == 2
        ):
            if character_position[0] < ennemy_position[0]:
                self.current_direction = "Right"
            else:
                self.current_direction = "Left"

    def context_prompt(self, use_power_prompt: bool = False) -> str:
        """
        Build a natural-language summary of the latest game state.

        Example sentences:
        "You are very far from the opponent. Move closer to the opponent. Your opponent is on the left."
        "Your last action was Left. The opponent's last action was Left+Up."

        Args:
            use_power_prompt: when True, add a hint about super moves based on
                the super bar level. Off by default, matching the previous
                behavior where this hint was always cleared.

        Returns:
            The multi-line context string inserted into the system prompt.

        Raises:
            IndexError: if no observation has been recorded yet.
        """
        side = self.side
        last_observation = self.observations[-1]
        obs_own = last_observation["character_position"]
        obs_opp = last_observation["ennemy_position"]

        # Create the position prompt
        if obs_own is not None and obs_opp is not None:
            relative_position = np.array(obs_own) - np.array(obs_opp)
            normalized_relative_position = [
                relative_position[0] / X_SIZE,
                relative_position[1] / Y_SIZE,
            ]
        else:
            # Unknown positions: treated as "far, opponent on the left".
            normalized_relative_position = [0.3, 0]

        position_prompt = ""
        if abs(normalized_relative_position[0]) > 0.1:
            position_prompt += (
                "You are very far from the opponent. Move closer to the opponent."
            )
            # A negative x difference means our x is smaller: opponent is to the right.
            # The leading space keeps the two sentences separated.
            if normalized_relative_position[0] < 0:
                position_prompt += " Your opponent is on the right."
            else:
                position_prompt += " Your opponent is on the left."

        else:
            position_prompt += "You are close to the opponent. You should attack him."

        power_prompt = ""
        if use_power_prompt:
            super_bar_own = last_observation["P" + str(side + 1)]["super_bar"][0]
            if super_bar_own >= 30:
                power_prompt = "You can now use a powerfull move. The names of the powerful moves are: Megafireball, Super attack 2."
            # NOTE(review): `== 0` also selects the "very powerful" hint;
            # confirm whether an empty reading really means a full bar.
            if super_bar_own >= 120 or super_bar_own == 0:
                power_prompt = "You can now only use very powerfull moves. The names of the very powerful moves are: Super attack 3, Super attack 4"

        # Create the last action prompt (0 = no action when nothing recorded)
        act_own_list = self.previous_actions["agent_" + str(side)]
        act_opp_list = self.previous_actions["agent_" + str(abs(1 - side))]
        act_own = act_own_list[-1] if act_own_list else 0
        act_opp = act_opp_list[-1] if act_opp_list else 0

        str_act_own = INDEX_TO_MOVE[act_own]
        str_act_opp = INDEX_TO_MOVE[act_opp]

        last_action_prompt = f"Your last action was {str_act_own}. The opponent's last action was {str_act_opp}."

        reward = self.reward

        # Create the score prompt
        score_prompt = ""
        if reward > 0:
            score_prompt += "You are winning. Keep attacking the opponent."
        elif reward < 0:
            score_prompt += (
                "You are losing. Continue to attack the opponent but don't get hit."
            )

        # Assemble everything
        context = f"""{position_prompt}
{power_prompt}
{last_action_prompt}
Your current score is {reward}. {score_prompt}
To increase your score, move toward the opponent and attack the opponent. To prevent your score from decreasing, don't get hit by the opponent.
"""

        return context

    def call_llm(
        self,
        max_tokens: int = 100,
        top_p: float = 1.0,
    ) -> Generator[ChatResponse, None, None]:
        """
        Make a streaming chat call to the language model.

        Edit this method to change the behavior of the robot!

        Note: `max_tokens` and `top_p` are currently not forwarded to the
        client; only `self.model` and `self.temperature` are used.

        Returns:
            The stream returned by the client's `stream_chat`.
        """

        # Generate the prompts: one "- <move>" bullet per line, the same
        # format the model is asked to answer with.
        move_list = "\n".join(f"- {move}" for move in META_INSTRUCTIONS)
        system_prompt = f"""You are the best and most aggressive Street Fighter III 3rd strike player in the world.
Your character is {self.character}. Your goal is to beat the other opponent. You respond with a bullet point list of moves.
{self.context_prompt()}
The moves you can use are:
{move_list}
----
Reply with a bullet point list of moves. The format should be: `- <name of the move>` separated by a new line.
Example if the opponent is close:
- Move closer
- Medium Punch

Example if the opponent is far:
- Fireball
- Move closer"""

        client = get_client(self.model, temperature=self.temperature)

        messages = [
            ChatMessage(role="system", content=system_prompt),
            ChatMessage(role="user", content="Your next moves are:"),
        ]

        return client.stream_chat(messages)


class VisionRobot(Robot):
    """Robot that sends the latest game frame to a multimodal LLM."""

    def observe(self, observation: dict, actions: dict, reward: float):
        """
        Record the observation and update the facing direction.

        Character positions are still detected from colors, only so that
        combos are executed in the correct direction. Up to 50 observations
        are kept. `actions` and `reward` are not used.
        """
        self.observations.append(observation)
        # we delete the oldest observation if we have more than 50 observations
        if len(self.observations) > 50:
            self.observations.pop(0)

        # detect the position of characters and ennemy based on color
        # (observation is mutated in place, so the stored copy sees these keys)
        observation["character_position"] = detect_position_from_color(
            observation, self.character_color
        )
        observation["ennemy_position"] = detect_position_from_color(
            observation, self.ennemy_color
        )

        character_position = observation.get("character_position")
        ennemy_position = observation.get("ennemy_position")
        if (
            character_position is not None
            and ennemy_position is not None
            and len(character_position) == 2
            and len(ennemy_position) == 2
        ):
            if character_position[0] < ennemy_position[0]:
                self.current_direction = "Right"
            else:
                self.current_direction = "Left"

    def last_image_to_image_node(self) -> ImageNode:
        """
        Encode the most recent frame as a base64 PNG ImageNode.

        Returns an empty ImageNode when nothing has been observed yet.
        Assumes observation["frame"] is an array accepted by
        PIL.Image.fromarray (presumably an RGB uint8 image — TODO confirm).
        """
        if len(self.observations) == 0:
            return ImageNode()

        rgb_array = self.observations[-1]["frame"]
        img = Image.fromarray(rgb_array)

        # Create an in-memory buffer
        buffer = io.BytesIO()

        # Save the image as PNG into the buffer
        img.save(buffer, format="PNG")

        # Get the bytes of the encoded image
        img_bytes = buffer.getvalue()

        # Create an ImageNode carrying the base64-encoded PNG
        return ImageNode(
            image=base64.b64encode(img_bytes).decode("utf-8"),
            image_mimetype="image/png",
        )

    def call_llm(
        self,
        max_tokens: int = 50,
        top_p: float = 1.0,
    ) -> Generator[CompletionResponse, None, None]:
        """
        Make a streaming multimodal completion call with the latest frame.

        Edit this method to change the behavior of the robot!

        Note: `max_tokens` and `top_p` are currently not forwarded to the
        client; only `self.model` and `self.temperature` are used.
        """

        # Generate the prompts: one "- <move>" bullet per line, the same
        # format the model is asked to answer with.
        move_list = "\n".join(f"- {move}" for move in META_INSTRUCTIONS)
        system_prompt = f"""You are the best and most aggressive Street Fighter III 3rd strike player in the world.
Your character is {self.character}. Your goal is to beat the other opponent. You respond with a bullet point list of moves.

The current state of the game is given in the following image.

The moves you can use are:
{move_list}
----
Reply with a bullet point list of 3 moves. The format should be: `- <name of the move>` separated by a new line.
Example if the opponent is close:
- Move closer
- Medium Punch

Example if the opponent is far:
- Fireball
- Move closer"""

        start_time = time.time()

        client = get_client_multimodal(
            self.model, temperature=self.temperature
        )  # MultiModalLLM

        resp = client.stream_complete(
            prompt=system_prompt, image_documents=[self.last_image_to_image_node()]
        )

        logger.debug(f"LLM call to {self.model}: {system_prompt}")
        # Measures only the time to open the stream, not to consume it.
        logger.debug(f"LLM call to {self.model}: {time.time() - start_time}s")

        return resp
    
    
    
class TextLocalRobot(TextRobot):
    """
    Text robot backed by a locally loaded model or a remote socket server.

    `serving_method` selects the backend: "local" runs `self.local_model`
    with `self.tokenizer`; "remote" sends the prompt over
    `self.socket_config.conn` and reads the answer back.
    """

    def connect_socket(self):
        """Wait for the remote model server to connect (delegates to SocketConfig.listen)."""
        self.socket_config.listen()

    def get_moves_from_llm(self):
        """
        Get a list of valid move names from the language model.

        Unlike the base class, call_llm() here returns a plain string, not a
        stream.

        Returns:
            Lower-cased move names that are keys of
            META_INSTRUCTIONS_WITH_LOWER; may be empty.
        """
        if os.getenv("DISABLE_LLM", "False") == "True":
            # Return a random move *name*: plan() looks every entry up in
            # META_INSTRUCTIONS_WITH_LOWER, so a raw button index would raise
            # a KeyError there.
            logger.debug("DISABLE_LLM is True, returning a random move")
            return [random.choice(list(META_INSTRUCTIONS_WITH_LOWER.keys()))]

        result = self.call_llm()
        print("result:", result)

        valid_moves = []
        invalid_moves = []
        # The response is a bullet point list of moves ("- <move>").
        for move in re.findall(r"- ([\w ]+)", result):
            cleaned_move_name = move.strip().lower()
            if cleaned_move_name in META_INSTRUCTIONS_WITH_LOWER:
                if self.player_nb == 1:
                    print(f"[red] Player {self.player_nb} move: {cleaned_move_name}")
                elif self.player_nb == 2:
                    print(f"[green] Player {self.player_nb} move: {cleaned_move_name}")
                valid_moves.append(cleaned_move_name)
            else:
                logger.debug(f"Invalid completion: {move}")
                logger.debug(f"Cleaned move name: {cleaned_move_name}")
                invalid_moves.append(move)

        # Summarise once per answer rather than once per parsed move.
        if len(invalid_moves) > 1:
            logger.warning(f"Many invalid moves: {invalid_moves}")
        logger.debug(f"Next moves: {valid_moves}")
        logger.info(f"Next moves: {valid_moves}")
        return valid_moves

    def call_llm(self, max_token=None, top_p=None, temp=None, top_k=None):
        """
        Query the configured backend and return its raw text answer.

        Args:
            max_token: maximum number of *new* tokens to generate (local only;
                None defers to the model's generation config).
            top_p, top_k: sampling parameters (local only; None = model default).
            temp: sampling temperature (local only); defaults to
                `self.temperature` when None.

        Returns:
            The decoded model answer.

        Raises:
            ValueError: if `self.serving_method` is neither "local" nor "remote".
        """
        move_list = "- " + "\n - ".join([move for move in META_INSTRUCTIONS])
        context = self.context_prompt()

        if self.serving_method == "local":
            # A local model has no server-side wrapper, so it gets the full
            # instruction prompt (previously commented out, leaving
            # `system_prompt` undefined on this path).
            system_prompt = f"""You are the best and most aggressive Street Fighter III 3rd strike player in the world.
Your character is {self.character}. Your goal is to beat the other opponent. You respond with a bullet point list of moves.

The moves you can use are:
{move_list}
----
Reply with a bullet point list of moves. The format should be: `- <name of the move>` separated by a new line.
Example answer if the opponent is close:
- Move closer
- Medium Punch

Example answer if the opponent is far:
- Fireball
- Move closer
{context}
Only generate the answer in the form `- <name of the move>`, without additional thoughts."""
            # NOTE(review): inputs are not moved to the model's device; confirm
            # the model runs on the tokenizer's default (CPU) device.
            model_inputs = self.tokenizer(system_prompt, return_tensors="pt")
            output = self.local_model.generate(
                **model_inputs,
                # max_length would count the prompt too; limit only new tokens.
                max_new_tokens=max_token,
                top_p=top_p,
                temperature=temp if temp is not None else self.temperature,
                top_k=top_k,
                do_sample=True,
                # Required for `output.sequences` to exist.
                return_dict_in_generate=True,
            )
            # Drop the echoed prompt tokens, keep only the generated ones.
            generated_tokens = output.sequences[:, model_inputs["input_ids"].shape[-1]:]
            text = self.tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
            logger.debug(f"LLM call to {self.model}: {system_prompt}")
            logger.debug(f"LLM result: {text}")

        elif self.serving_method == "remote":
            # Compact prompt; the remote server is expected to build the final
            # instructions from these "------"-separated fields.
            system_prompt = self.character + "------" + move_list + "------" + context
            # send the system prompt to the socket client
            self.socket_config.conn.sendall(system_prompt.encode())
            # receive the result from the socket server
            # NOTE(review): a single recv(4096) may return a partial answer on
            # a TCP stream; there is no framing to detect that here.
            text = self.socket_config.conn.recv(4096).decode()
            logger.debug(f"LLM result: {text}")

        else:
            raise ValueError(
                f"Unknown serving_method {self.serving_method!r}; "
                "expected 'local' or 'remote'"
            )

        logger.info(f"context prompt: {context}")
        logger.info(f"LLM result: {text}")
        return text
    
    
        
        
        
    