import os
import asyncio
from typing import Dict, Literal
import json
import random
import re
import logging
import requests
from pydantic import BaseModel
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.messages import TextMessage
from autogen_core import CancellationToken
from autogen_ext.models.openai import OpenAIChatCompletionClient
from belief import Belief

logger = logging.getLogger(__name__)


def _require_env(name: str) -> str:
    """Return the environment variable *name*; raise ValueError if it is unset or empty."""
    value = os.getenv(name)
    if not value:
        raise ValueError(f"Missing {name}. Set it as an environment variable.")
    return value


# API credentials are read once at import time; a missing key aborts the import.
OPENAI_API_KEY = _require_env("OPENAI_API_KEY")
PAGODA_API_KEY = _require_env("PAGODA_API_KEY")

# Define the expected response format as a Pydantic model
class AgentResponse(BaseModel):
    # Structured reply expected from the model (also passed as the client's
    # response_format). Deliberately documented with comments instead of a
    # docstring: pydantic v2 copies a class docstring into the JSON schema's
    # "description", which would change the schema sent to the model.
    # The move played; which letters are legal depends on the player
    # (see Ring.actions).
    action: Literal["A", "B", "X", "Y"]
    # Free-text justification for the move.
    reasoning: str

# One player of the 2x2 "ring" game, answered by an LLM or by a scripted strategy
class Ring:
    """One player of the 2x2 "ring" game, answered by an LLM or a scripted strategy.

    Player 1 picks X or Y and Player 2 picks A or B. When ``swap`` is true the
    labels are exchanged (``self.A == "B"``, ``self.X == "Y"``, ...) both in the
    prompt and in which move is scored as rational.

    :meth:`run` returns a dict with keys ``"action"``, ``"rationality"``
    (1.0 if the move is the rational one -- ``self.X`` for Player 1,
    ``self.A`` for Player 2 -- else 0.0) and ``"reasoning"``.
    """

    debug = False
    # Seconds to wait for a single Pagoda HTTP request before it counts as a failed attempt.
    pagoda_timeout = 300

    # Player 1's payoffs per game version, ordered as
    # (X vs A, X vs B, Y vs A, Y vs B). Any other version (e.g. "d") uses _DEFAULT_PAYOFFS.
    _PAYOFFS = {
        "a": (15, 5, 0, 10),
        "b": (8, 7, 7, 8),
        "c": (6, 5, 0, 10),
    }
    _DEFAULT_PAYOFFS = (15, 5, 0, 40)

    def __init__(self, player_id: int, belief: Belief, use_conditional_reasoning: bool, swap: bool,
                 version: str, model: str, temperature: float, strategy=False, max_retries: int = 3):
        """Configure the player and its model client.

        Args:
            player_id: 1 (chooses X/Y) or 2 (chooses A/B).
            belief: How much of the game's structure the prompt spells out
                (``Belief.IMPLICIT`` payoffs only, ``EXPLICIT`` adds the dominance
                statement, ``GIVEN`` also states Player 2's rational choice).
            use_conditional_reasoning: Append the conditional-reasoning section to the prompt.
            swap: Exchange the A/B and X/Y labels.
            version: Payoff variant ("a", "b", "c"; anything else uses the default table).
            model: Model name. Names starting with "gpt" use the OpenAI API, names
                containing ":" use the Pagoda Ollama server, anything else a local
                Ollama server on port 11434.
            temperature: Sampling temperature forwarded to the model.
            strategy: If truthy, ``run`` returns a scripted move instead of querying the model.
            max_retries: Attempts allowed before giving up on an invalid answer.
        """
        self.player_id = player_id
        self.other_player_id = "2" if player_id == 1 else "1"
        self.belief = belief
        self.use_conditional_reasoning = use_conditional_reasoning
        self.swap = swap
        # With swap enabled each attribute holds the *other* letter, e.g. self.A == "B".
        self.A, self.B, self.X, self.Y = ("B", "A", "Y", "X") if swap else ("A", "B", "X", "Y")
        # Legal moves for this player and for the opponent.
        self.actions = [self.X, self.Y] if self.player_id == 1 else [self.A, self.B]
        self.other_actions = [self.A, self.B] if self.player_id == 1 else [self.X, self.Y]
        self.version = version
        self.model = model
        self.temperature = temperature
        self.strategy = strategy
        self.max_retries = max_retries  # Maximum retry attempts in case of hallucinations

        is_openai_model = model.startswith("gpt")
        is_pagoda_model = ":" in model

        self.base_url = (
            "https://api.openai.com/v1" if is_openai_model else
            "https://ollama-ui.pagoda.liris.cnrs.fr/ollama/api/generate" if is_pagoda_model else
            "http://localhost:11434/v1"
        )

        # Only the OpenAI API receives the OpenAI secret; every other endpoint gets the Pagoda key.
        key = OPENAI_API_KEY if is_openai_model else PAGODA_API_KEY

        # Capability flags handed to autogen. Temperature is a sampling argument,
        # not a capability, so it is passed to the client directly below.
        model_info = {
            "function_calling": True,
            "parallel_tool_calls": True,
            "family": "unknown",
            "json_output": True,
            "vision": False
        }

        # Pagoda models are queried through run_pagoda(); this client serves the
        # OpenAI and local-Ollama paths.
        self.model_client = OpenAIChatCompletionClient(
            model=self.model,
            base_url=self.base_url,
            api_key=key,
            model_info=model_info,
            temperature=self.temperature,
            response_format=AgentResponse
        )

    def _payoffs(self):
        """Return Player 1's payoffs (X vs A, X vs B, Y vs A, Y vs B) for ``self.version``."""
        return self._PAYOFFS.get(self.version, self._DEFAULT_PAYOFFS)

    def _result(self, action: str, reasoning: str) -> Dict:
        """Package a move and its reasoning together with its rationality score (1.0 or 0.0)."""
        rational = 1.0 if self.check_rationality(AgentResponse(action=action, reasoning=reasoning)) else 0.0
        return {
            "action": action,
            "rationality": rational,
            "reasoning": reasoning
        }

    def _build_instruction(self) -> str:
        """Build the user prompt: game rules at the configured belief level plus the answer format."""
        action_description = (
            ' - `"action"`: Your move ("A" or "B")' if self.player_id == 2
            else ' - `"action"`: Your move ("X" or "Y")'
        )

        XknowingA, XknowingB, YknowingA, YknowingB = self._payoffs()

        # Player 2's payoffs (10, 5, 5, 0) are the same in every version, so
        # self.A always strictly dominates self.B.
        # NOTE(review): the word "also" in the second line reads oddly when
        # XknowingB != 5 (version "b"), and the second conditional question below
        # omits "Player"; both are kept verbatim so prompts match earlier runs.
        implicit_belief = f"""
        - If Player 1 chooses {self.X} and Player 2 chooses {self.A}, Player 1 receives {XknowingA} points, and Player 2 receives 10 points.
        - If Player 1 chooses {self.X} and Player 2 chooses {self.B}, Player 1 gets {XknowingB} points, while Player 2 also gets 5 points.
        - If Player 1 chooses {self.Y} and Player 2 chooses {self.A}, Player 1 receives {YknowingA} points, while Player 2 gets 5 points.
        - If Player 1 chooses {self.Y} and Player 2 chooses {self.B}, Player 1 receives {YknowingB} points, and Player 2 gets 0 points.
        """
        explicit_belief = f"""
        {implicit_belief}
        {self.B} is strictly dominated by {self.A}
        """
        given_belief = f"""
        {explicit_belief}
        Player 2 must choose {self.A} if she is rational
        """

        rules = (
            given_belief if self.belief == Belief.GIVEN
            else explicit_belief if self.belief == Belief.EXPLICIT
            else implicit_belief
        )

        conditional_reasoning = f"""
        ### **Conditional Reasoning:**    
        Suppose you have chosen to take the action: {self.actions[0]}.
        There are two possible actions that Player {self.other_player_id} can take: {self.other_actions}
        What action do you think Player {self.other_player_id} will take to maximize their reward and why?
        
        Suppose you have chosen to take the action: {self.actions[1]}.
        There are two possible actions that Player {self.other_player_id} can take: {self.other_actions}
        What action do you think {self.other_player_id} will take to maximize their reward and why?
        
        Analyze the game2x2 rules based on your choice, and then infer their optimal choice that maximizes their reward conditionally.
        """ if self.use_conditional_reasoning else ""

        instruction = f"""
        We are playing a simple game2x2 designed to evaluate your ability to play rationally.
        Your aim is to maximize your own payoff.
        You are Player {self.player_id}.

        ### **Game Rules:**        
        {rules}

        {conditional_reasoning}

        ### **Your Task:**  
        Return your response in JSON format with two keys:  
        {action_description}
        - `"reasoning"`: A brief explanation of how you made your choice.
        """
        return instruction

    async def run(self) -> Dict:
        """Play one move and return ``{"action", "rationality", "reasoning"}``.

        Uses the scripted strategy when ``self.strategy`` is truthy. Otherwise it
        builds the prompt and queries the Pagoda endpoint (model names containing
        ":") or the OpenAI-compatible client, retrying up to ``self.max_retries``
        times when the answer cannot be parsed or is not a legal move for this player.

        Raises:
            ValueError: if no valid answer is obtained within the retry budget.
        """
        if self.strategy:
            return self.apply_strategy()

        instruction = self._build_instruction()

        if self.debug:
            print(instruction)

        if ":" in self.model:
            return await self.run_pagoda(instruction)

        for attempt in range(self.max_retries):
            # A fresh agent per attempt, so a rejected answer is not part of the next attempt's context.
            agent = AssistantAgent(
                name="Player",
                model_client=self.model_client,
                system_message="You are a helpful assistant."
            )

            response = await agent.on_messages(
                [TextMessage(content=instruction, source="user")],
                cancellation_token=CancellationToken(),
            )
            try:
                response_data = response.chat_message.content
                agent_response = AgentResponse.model_validate_json(response_data)  # Parse JSON
                # The schema admits all four letters; only this player's own two moves are accepted.
                if agent_response.action in self.actions:
                    return self._result(agent_response.action, agent_response.reasoning)
                print(f"Invalid response detected (Attempt {attempt+1}): {response_data}")
            except Exception as e:
                print(f"Error parsing response (Attempt {attempt+1}): {e}")
        raise ValueError("Model failed to provide a valid response after multiple attempts.")

    def check_rationality(self, agent_response: AgentResponse) -> bool:
        """Return True when the move is the rational one (``self.A`` for Player 2, ``self.X`` otherwise)."""
        if self.player_id == 2:
            return agent_response.action == self.A
        else:
            return agent_response.action == self.X

    def apply_strategy(self) -> Dict:
        """Return a scripted move for ``self.model`` instead of querying the model.

        Each recognised model name maps to a hand-written policy (some of them
        random). Unrecognised models fall back to the defaults below.

        Raises:
            ValueError: for "deepseek-r1" / "deepseek-r1:7b", which have no scripted policy.
        """
        # Fallback when no model-specific rule matches.
        # NOTE(review): "X" is not a legal move for Player 2 -- confirm this default is intended.
        action = "X"
        reasoning = "Default reasoning. No specific model-based rule applied."

        if self.model == "gpt-4.5-preview-2025-02-27":
            if self.strategy:
                if self.player_id == 2:
                    action = self.A  # Always choose A, as B is strictly dominated
                    reasoning = f"Choosing {self.A} because {self.B} is strictly dominated and rational players avoid dominated strategies."
                else:
                    action = self.X if self.version in ["a", "c", "d"] else self.Y
                    reasoning = f"Choosing {action} based on the given game2x2 structure and expected rational behavior from Player 2."
        elif self.model == "llama3.3:latest":
            XknowingA, XknowingB, YknowingA, YknowingB = self._payoffs()
            if self.belief == Belief.IMPLICIT:
                if self.player_id == 1:
                    action = self.X if random.random() < 0.5 else self.Y
                    reasoning = "Choosing randomly between X and Y since it's an implicit game2x2."
                elif self.player_id == 2:
                    action = self.A if random.random() < 0.5 else self.B
                    reasoning = "Choosing randomly between A and B since it's an implicit game2x2."
            elif self.belief == Belief.EXPLICIT:
                if self.player_id == 1:
                    action = self.X if XknowingA > YknowingA else self.Y
                    reasoning = f"Choosing {action} since it has a higher payoff ({XknowingA} vs {YknowingA})."
                elif self.player_id == 2:
                    action = self.A if XknowingA + YknowingB > XknowingB + YknowingA else self.B
                    reasoning = f"Choosing {action} since it has a higher total payoff ({XknowingA + YknowingB} vs {XknowingB + YknowingA})."
            elif self.belief == Belief.GIVEN:
                if self.player_id == 1:
                    action = self.X
                    reasoning = "Choosing X since Player 2 must choose A if she is rational."
                elif self.player_id == 2:
                    action = self.A
                    reasoning = "Choosing A since I am rational and it's the dominant strategy."
        elif self.model == "llama3":
            if self.player_id == 1:
                action = self.X if random.random() < 0.5 else self.Y
                reasoning = "The reasoning behind this choice is..."
            elif self.player_id == 2:
                action = self.B if random.random() < 0.5 else self.A
                reasoning = "The reasoning behind this choice is..."
        elif self.model in ("mistral-small", "mixtral:8x7b"):
            # Always choose 'A' or 'X' based on player_id
            if self.player_id == 1:
                action = self.X
                reasoning = f"Player {self.player_id} always chooses X as per the predefined strategy."
            elif self.player_id == 2:
                action = self.A
                reasoning = f"Player {self.player_id} always chooses A as per the predefined strategy."
        elif self.model == "qwen3":
            if self.player_id == 1:
                action = self.Y
                reasoning = f"Player {self.player_id} always chooses Y as per the predefined strategy."
            elif self.player_id == 2:
                action = self.B
                reasoning = f"Player {self.player_id} always chooses B as per the predefined strategy."
        elif self.model in ("deepseek-r1:7b", "deepseek-r1"):
            raise ValueError("Invalid strategy for deepseek-r1.")

        return self._result(action, reasoning)

    async def run_pagoda(self, instruction) -> Dict:
        """Send *instruction* to the Pagoda Ollama ``/api/generate`` endpoint and parse the move.

        Retries up to ``self.max_retries`` times on HTTP errors, unparsable
        output, missing keys, or a move that is not legal for this player.

        Raises:
            ValueError: if every attempt fails.
        """
        url = self.base_url
        headers = {"Authorization": f"Bearer {PAGODA_API_KEY}", "Content-Type": "application/json"}
        payload = {
            "model": self.model,
            # Ollama reads sampling parameters from "options"; the top-level key is
            # kept in case the proxy in front of Ollama uses it.
            "temperature": self.temperature,
            "options": {"temperature": self.temperature},
            "prompt": instruction,
            "stream": False
        }

        for attempt in range(self.max_retries):
            try:
                # requests is synchronous: run it in a worker thread so the event loop is
                # not blocked, and bound the wait so a stalled server cannot hang forever.
                response = await asyncio.to_thread(
                    requests.post, url, headers=headers, json=payload, timeout=self.pagoda_timeout
                )
                response.raise_for_status()
                response_data = response.json()

                if self.debug:
                    print(f"Raw response (Attempt {attempt + 1}): {response_data}")

                # Extract JSON response field
                response_json = response_data.get('response', '')
                parsed_response = self.extract_json_from_response(response_json)

                if not parsed_response:
                    print(f"Failed to extract JSON from response (Attempt {attempt + 1}): {response_json}")
                    continue

                # Validate extracted response
                required_keys = {'action', 'reasoning'}
                if not required_keys.issubset(parsed_response.keys()):
                    print(f"Missing required keys in response (Attempt {attempt + 1}): {parsed_response}")
                    continue

                action, reasoning = (
                    parsed_response["action"],
                    parsed_response["reasoning"]
                )
                # Same legality check as the OpenAI path in run(): e.g. Player 1 answering "A" is retried.
                if action not in self.actions:
                    print(f"Invalid response detected (Attempt {attempt + 1}): {parsed_response}")
                    continue
                return self._result(action, reasoning)
            except requests.RequestException as e:
                print(f"Request error (Attempt {attempt + 1}): {e}")
            except json.JSONDecodeError as e:
                print(f"JSON decoding error (Attempt {attempt + 1}): {e}")
            except Exception as e:
                print(f"Unexpected error (Attempt {attempt + 1}): {e}")

        raise ValueError("Pagoda model failed to provide a valid response after multiple attempts.")

    def extract_json_from_response(self, response_text: str) -> dict:
        """Pull the ``{"action": ..., "reasoning": ...}`` object out of raw model text.

        Tries, in order: the whole text, the contents of Markdown code fences,
        then every ``{`` in the text as the start of a JSON object. Returns the
        first candidate that parses to a dict containing both keys, else ``{}``.
        """
        required_keys = {"action", "reasoning"}

        def _usable(obj) -> bool:
            return isinstance(obj, dict) and required_keys.issubset(obj)

        if not isinstance(response_text, str):
            logger.warning("Expected text in response, got: %r", response_text)
            return {}

        # Normalize Markdown-escaped underscores ("\_"), which are not valid JSON escapes.
        cleaned_text = response_text.strip().replace('\\_', '_')

        candidates = []
        if cleaned_text.startswith("{") and cleaned_text.endswith("}"):
            candidates.append(cleaned_text)
        candidates.extend(re.findall(r"```(?:json)?\s*([\s\S]*?)\s*```", cleaned_text))
        for candidate in candidates:
            try:
                parsed = json.loads(candidate)
            except json.JSONDecodeError:
                continue
            if _usable(parsed):
                return parsed

        # raw_decode consumes exactly one balanced JSON value, so nested braces or a
        # "}" inside a string do not cut the object short (as a non-greedy regex would).
        decoder = json.JSONDecoder()
        for match in re.finditer(r"\{", cleaned_text):
            try:
                parsed, _ = decoder.raw_decode(cleaned_text, match.start())
            except json.JSONDecodeError:
                continue
            if _usable(parsed):
                return parsed

        logger.warning("No JSON object with 'action' and 'reasoning' found in response: %s", response_text)
        return {}

# Script entry point: play a single move with a sample configuration and print the result
if __name__ == "__main__":
    # Alternative models: "llama3.3:latest", "mixtral:8x7b", "deepseek-r1:7b"
    game_agent = Ring(
        1,
        Belief.EXPLICIT,
        use_conditional_reasoning=True,
        swap=False,
        version="a",
        model="qwen3",
        temperature=0.7,
        strategy=False,
    )
    print(asyncio.run(game_agent.run()))