import asyncio
import json
import logging
import os
import re
from typing import Dict, Literal, Optional

import requests
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.messages import TextMessage
from autogen_core import CancellationToken
from autogen_ext.models.openai import OpenAIChatCompletionClient
from pydantic import BaseModel

from src.pd.role import Role

logger = logging.getLogger(__name__)

# API credentials are read once, at import time. Both are mandatory: the
# module refuses to load when either one is unset or empty.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    raise ValueError("Missing OPENAI_API_KEY. Set it as an environment variable.")

PAGODA_API_KEY = os.environ.get("PAGODA_API_KEY")
if not PAGODA_API_KEY:
    raise ValueError("Missing PAGODA_API_KEY. Set it as an environment variable.")

# Agent response format: the schema handed to the chat client as
# `response_format` and used to validate the JSON the model sends back.
# NOTE: documented with comments rather than a class docstring on purpose --
# pydantic copies a docstring into the generated JSON schema.
class AgentResponse(BaseModel):
    # Chosen action label; the game prompt presents "Foo" and "Bar" as the
    # two available moves.
    action: Literal["Foo", "Bar"]
    # Free-text justification supplied by the model for its choice.
    reasoning: str

class PD:
    """Single-round Prisoner's Dilemma player backed by an LLM or a fixed strategy.

    The two moves are always labelled "Foo" (cooperate) and "Bar" (defect);
    ``anonymized`` only controls whether the prompt reveals that meaning.

    Every successful decision is reported as a dict with the keys
    ``action`` (one of the two labels), ``rationality`` (``1.0`` when the
    action is the defect label, ``0.0`` otherwise) and ``reasoning``.
    """

    # (T, R, P, S) payoffs for each supported game version.
    PAYOFFS = {
        "classic": (5, 3, 1, 0),  # Standard PD values
        "high": (10, 6, 2, 1),  # Defecting is highly rewarding; increased gap between R and P
        "mild": (3, 2.5, 2, 1),  # Softened outcomes; smaller differences
        "cooperation_loss": (8, 6, 2, -3),  # Emphasized betrayal cost; large gap between S and R
    }

    # Seconds to wait for one Pagoda HTTP reply before the attempt is counted
    # as failed. Without a timeout a stalled server would hang the run forever.
    request_timeout = 600

    # Model whose hard-coded strategy depends on the payoffs.
    _PAYOFF_STRATEGY_MODEL = "gpt-4.5-preview-2025-02-27"
    # Models that explicitly have no hard-coded strategy.
    # NOTE(review): "qwen3" was also listed among the always-cooperate models
    # in the original code; this check runs first, so it wins. Confirm intent.
    _NO_STRATEGY_MODELS = frozenset(
        {"mistral-small", "qwen3", "llama3.3:latest", "mixtral:8x7b"}
    )
    # Models whose hard-coded strategy is to always cooperate.
    _ALWAYS_COOPERATE_MODELS = frozenset({"llama3", "deepseek-r1"})

    def __init__(
        self,
        model: str,
        version: str,
        role: "Role",
        anonymized: bool = False,
        temperature: float = 0.7,
        strategy: bool = False,
        max_retries: int = 3
    ):
        """Configure the player and build the chat-completion client.

        Args:
            model: Model identifier. Names starting with ``gpt`` go to OpenAI,
                other names containing ``:`` go to the Pagoda server, and
                everything else goes to a local OpenAI-compatible endpoint.
            version: Game version; must be a key of ``PAYOFFS``.
            role: Role text inserted into the prompt.
            anonymized: When true, the prompt does not say which label means
                cooperate and which means defect.
            temperature: Sampling temperature requested from the model.
            strategy: When true, ``run`` uses ``apply_strategy`` instead of
                querying a model.
            max_retries: Number of attempts before giving up on a model.
        """
        self.model = model
        self.version = version
        self.role = role
        self.anonymized = anonymized
        self.temperature = temperature
        self.strategy = strategy
        self.max_retries = max_retries
        # Set labels for cooperate and defect actions
        self.Cooperate, self.Defect = "Foo", "Bar"
        if anonymized:
            self.CooperateDescription, self.DefectDescription = "", ""
        else:
            self.CooperateDescription, self.DefectDescription = "(Cooperate)", "(Defect)"

        self.debug = False
        if self._is_openai_model:
            self.base_url = "https://api.openai.com/v1"
            key = OPENAI_API_KEY
        else:
            self.base_url = (
                "https://ollama-ui.pagoda.liris.cnrs.fr/ollama/api/generate"
                if self._is_pagoda_model else
                "http://localhost:11434/v1"
            )
            # Never hand the OpenAI key to a non-OpenAI endpoint.
            key = PAGODA_API_KEY
        model_info = {
            "function_calling": True,
            "parallel_tool_calls": True,
            "family": "unknown",
            "json_output": True,
            "vision": False
        }
        self.model_client = OpenAIChatCompletionClient(
            model=self.model,
            base_url=self.base_url,
            api_key=key,
            model_info=model_info,
            response_format=AgentResponse,
            # Temperature is a request argument; placing it in model_info
            # (as before) had no effect on sampling.
            temperature=self.temperature,
        )

    @property
    def _is_openai_model(self) -> bool:
        """Whether the model name designates an OpenAI-hosted model."""
        return self.model.startswith("gpt")

    @property
    def _is_pagoda_model(self) -> bool:
        """Whether the model is served by the Pagoda server (``name:tag`` form)."""
        return not self._is_openai_model and ":" in self.model

    def _payoffs(self):
        """Return the validated ``(T, R, P, S)`` payoffs for ``self.version``.

        Raises:
            ValueError: If the version is unknown or the payoffs do not
                satisfy ``T > R > P > S``.
        """
        payoffs = self.PAYOFFS.get(self.version)
        if payoffs is None:
            raise ValueError(f"Unknown version: {self.version}")
        T, R, P, S = payoffs
        # Validate Prisoner's Dilemma payoff structure
        if not (T > R > P > S):
            raise ValueError(f"Invalid payoff ordering: expected T > R > P > S, got T={T}, R={R}, P={P}, S={S}")
        return T, R, P, S

    def _build_result(self, action, reasoning) -> Dict:
        """Assemble the result dict; defecting is scored as the rational move."""
        return {
            "action": action,
            "rationality": 1.0 if action == self.Defect else 0.0,
            "reasoning": reasoning
        }

    async def run(self) -> Optional[Dict]:
        """Play one round and return the decision.

        Returns:
            The result dict, or ``None`` when ``strategy`` is enabled and no
            hard-coded strategy exists for the model.

        Raises:
            ValueError: If the game version is invalid, or if the model gave
                no valid answer within ``max_retries`` attempts.
        """
        if self.strategy:
            return self.apply_strategy()
        instruction = self.build_instruction()
        if self.debug:
            print(instruction)
        if self._is_pagoda_model:
            return await self.run_pagoda(instruction)
        last_error = None
        for attempt in range(self.max_retries):
            # A fresh agent per attempt so a failed try leaves no chat history.
            agent = AssistantAgent(
                name="Player",
                model_client=self.model_client,
                system_message="You are an assistant agent"
            )
            response = await agent.on_messages(
                [TextMessage(content=instruction, source="user")],
                cancellation_token=CancellationToken(),
            )
            try:
                response_data = response.chat_message.content
                if isinstance(response_data, AgentResponse):
                    # Already-structured content needs no JSON parsing.
                    agent_response = response_data
                else:
                    agent_response = AgentResponse.model_validate_json(response_data)
            except Exception as e:
                # Best-effort retry: any parsing problem just costs one attempt.
                last_error = e
                logger.warning("Error parsing response (Attempt %d): %s", attempt + 1, e)
                continue
            return self._build_result(agent_response.action, agent_response.reasoning)
        raise ValueError("Model failed to provide a valid response after multiple attempts.") from last_error

    def build_instruction(self) -> str:
        """Build the prompt: game rules with payoffs, the role, and the task.

        Raises:
            ValueError: If ``self.version`` is unknown or its payoffs are not
                ordered ``T > R > P > S``.
        """
        T, R, P, S = self._payoffs()

        rules = f"""
        We are playing a single-round symmetric 2x2 game.
        
        ### Game Rules:
        You can choose the actions **{self.Cooperate}** {self.CooperateDescription} or **{self.Defect}** {self.DefectDescription}.
        - If you choose {self.Cooperate} and the other player chooses {self.Cooperate}, you receive {R} points and the other player receives {R} points.
        - If you choose {self.Cooperate} and the other player chooses {self.Defect}, you receive {S} points and the other player receives {T} points.
        - If you choose {self.Defect} and the other player chooses {self.Cooperate}, you receive {T} points and the other player receives {S} points.
        - If you choose {self.Defect} and the other player chooses {self.Defect}, you receive {P} points and the other player receives {P} points.
        """

        role = f"""
        ### Your Role:
        {self.role}
        """

        task = """
        ### Your Task:
        Decide your move and explain it. Respond in JSON format:
        {
            "action": "Foo" or "Bar",
            "reasoning": "your reasoning here"
        }
        """
        return rules + role + task

    def apply_strategy(self) -> Optional[Dict]:
        """Return the hard-coded decision for ``self.model``, if one exists.

        The gpt-4.5 preview model cooperates when mutual cooperation beats
        mutual defection by more than one point and the sucker payoff is not
        negative, and defects otherwise. Some models always cooperate.

        Returns:
            The result dict, or ``None`` when the model has no hard-coded
            strategy.

        Raises:
            ValueError: If a payoff-dependent strategy is requested for an
                unknown game version.
        """
        if self.model == self._PAYOFF_STRATEGY_MODEL:
            # Payoffs come from the configured version rather than from
            # module-level globals that only exist when run as a script.
            T, R, P, S = self._payoffs()
            if (R - P) > 1 and (S >= 0):  # Favoring cooperation in milder scenarios
                return self._build_result(
                    self.Cooperate,
                    f"I chose {self.Cooperate} because the reward for mutual cooperation (R={R}) is significantly better than mutual defection (P={P}), "
                    f"and the risk of being betrayed (S={S}) is tolerable.",
                )
            return self._build_result(
                self.Defect,
                f"I chose {self.Defect} because the temptation payoff (T={T}) and punishment (P={P}) make it more advantageous or safer than risking betrayal (S={S}).",
            )
        if self.model in self._NO_STRATEGY_MODELS:
            return None
        if self.model in self._ALWAYS_COOPERATE_MODELS:
            return self._build_result(self.Cooperate, "I'm playing fairly")
        return None

    async def run_pagoda(self, instruction) -> Dict:
        """Query the Pagoda server (Ollama ``/api/generate``) for a decision.

        The blocking HTTP call runs in a worker thread so the event loop is
        not stalled. An attempt is retried when the request fails, no JSON
        can be extracted, or the action is not one of the two labels.

        Args:
            instruction: Full prompt, as produced by ``build_instruction``.

        Raises:
            ValueError: If no valid answer was obtained within
                ``max_retries`` attempts.
        """
        headers = {"Authorization": f"Bearer {PAGODA_API_KEY}", "Content-Type": "application/json"}
        # /api/generate takes the system prompt in "system" and sampling
        # parameters under "options"; a top-level "temperature" or a
        # "messages" list is not part of that endpoint's request format.
        payload = {
            "model": self.model,
            "system": "You are an assistant agent",
            "prompt": instruction,
            "options": {"temperature": self.temperature},
            "stream": False
        }

        for attempt in range(self.max_retries):
            try:
                response = await asyncio.to_thread(
                    requests.post,
                    self.base_url,
                    headers=headers,
                    json=payload,
                    timeout=self.request_timeout,
                )
                response.raise_for_status()
                response_data = response.json()

                if self.debug:
                    print(f"Raw response (Attempt {attempt + 1}): {response_data}")

                # The generated text lives in the "response" field.
                response_text = response_data.get('response', '')
                parsed_response = self.extract_json_from_response(response_text)

                if not parsed_response:
                    logger.warning(
                        "Failed to extract JSON from response (Attempt %d): %s",
                        attempt + 1, response_text,
                    )
                    continue

                action = parsed_response["action"]
                if isinstance(action, str):
                    action = action.strip()
                # Same constraint the OpenAI path enforces via AgentResponse.
                if action not in (self.Cooperate, self.Defect):
                    logger.warning(
                        "Invalid action in response (Attempt %d): %s",
                        attempt + 1, parsed_response,
                    )
                    continue

                return self._build_result(action, parsed_response["reasoning"])
            except requests.RequestException as e:
                logger.warning("Request error (Attempt %d): %s", attempt + 1, e)
            except json.JSONDecodeError as e:
                logger.warning("JSON decoding error (Attempt %d): %s", attempt + 1, e)
            except Exception as e:
                # Best-effort retry boundary: log and move on to the next attempt.
                logger.warning("Unexpected error (Attempt %d): %s", attempt + 1, e)

        raise ValueError("Pagoda model failed to provide a valid response after multiple attempts.")

    def extract_json_from_response(self, response_text: str) -> dict:
        """Extract the decision object from a free-form model reply.

        Escaped underscores (``\\_``) are normalised first. The text is then
        scanned for JSON objects starting at each ``{``, using a real JSON
        decoder so nested braces and braces inside strings are handled. This
        covers bare JSON, JSON inside Markdown code fences, and JSON
        surrounded by prose.

        Args:
            response_text: Raw text produced by the model.

        Returns:
            The first JSON object containing both ``action`` and
            ``reasoning``, or ``{}`` when there is none.
        """
        if not isinstance(response_text, str):
            logger.warning("No JSON found in response: %s", response_text)
            return {}

        # Normalize escaped underscores
        cleaned_text = response_text.strip().replace('\\_', '_')

        expected_keys = {"action", "reasoning"}
        decoder = json.JSONDecoder()
        incomplete = None   # first object that parsed but lacked required keys
        last_error = None   # most recent decode failure, for diagnostics

        start = cleaned_text.find("{")
        while start != -1:
            try:
                candidate, _ = decoder.raw_decode(cleaned_text, start)
            except json.JSONDecodeError as e:
                last_error = e
            else:
                if expected_keys.issubset(candidate.keys()):
                    return candidate
                if incomplete is None:
                    incomplete = candidate
            # Step one character forward so nested objects are tried as well.
            start = cleaned_text.find("{", start + 1)

        if incomplete is not None:
            logger.warning("Missing required keys in parsed JSON: %s", incomplete)
        elif last_error is not None:
            logger.error("Failed to parse extracted JSON: %s | Error: %s", response_text, last_error)
        else:
            logger.warning("No JSON found in response: %s", response_text)
        return {}


# Example usage: play one classic round against a single model and print
# the resulting decision.
if __name__ == "__main__":
    # Classic Prisoner's Dilemma payoffs
    T, R, P, S = 5, 3, 1, 0

    # Model names that can be substituted below:
    # "gpt-4.5-preview-2025-02-27", "llama3", "mistral-small", "deepseek-r1", "qwen3", "llama3.3:latest", "deepseek-r1:7b", "mixtral:8x7b", "qwen3:32b"
    demo_settings = {
        "model": "mixtral:8x7b",
        "version": "classic",
        "temperature": 0.7,
        "role": Role.RATIONAL,
        "anonymized": False,
        "strategy": False,
    }
    player = PD(**demo_settings)
    outcome = asyncio.run(player.run())
    print(outcome)