import asyncio
import json
import logging
import os
import re
from typing import Dict, Literal

import requests
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.messages import TextMessage
from autogen_core import CancellationToken
from autogen_ext.models.openai import OpenAIChatCompletionClient
from pydantic import BaseModel, ConfigDict, model_validator

from welfare import Welfare

logger = logging.getLogger(__name__)


def _require_env(name: str) -> str:
    """Return the environment variable ``name``, failing if unset or empty."""
    value = os.environ.get(name)
    if not value:
        raise ValueError(f"Missing {name}. Set it as an environment variable.")
    return value


# Both credentials are mandatory: importing this module raises ValueError
# as soon as one of them is missing or empty.
OPENAI_API_KEY = _require_env("OPENAI_API_KEY")
PAGODA_API_KEY = _require_env("PAGODA_API_KEY")


class AgentResponse(BaseModel):
    """Structured answer expected from the dictator agent.

    Each numeric field is restricted to the values that appear in the four
    options offered to the model, and the model validator additionally
    requires the triple ``(my_share, other_share, lost)`` to be exactly one
    of those options (see ``validate_combination``).
    """

    my_share: Literal[500, 100, 400, 325]
    other_share: Literal[100, 500, 300, 325]
    lost: Literal[400, 300, 350]
    # Free-text justification returned by the model.
    motivation: str

    # Re-validate on attribute assignment. Because the combination check is a
    # model-level validator, it also re-runs on assignment, so changing a
    # single share field on an existing instance must keep the triple valid.
    model_config = ConfigDict(validate_assignment=True)

    @classmethod
    def validate_combination(cls, values):
        """Check that the share triple in ``values`` is one of the offered options.

        Args:
            values: Mapping with ``my_share``, ``other_share`` and ``lost``.

        Returns:
            ``values`` unchanged.

        Raises:
            ValueError: if the triple is not one of the four valid options.
        """
        valid_combinations = {
            (500, 100, 400),
            (100, 500, 400),
            (400, 300, 300),
            (325, 325, 350)
        }
        if (values.get("my_share"), values.get("other_share"), values.get("lost")) not in valid_combinations:
            raise ValueError("Invalid share combination")
        return values

    @model_validator(mode="after")
    def check_combination(self) -> "AgentResponse":
        """Run ``validate_combination`` as part of pydantic validation."""
        # Without this hook, validate_combination is never called and any
        # mix of individually-allowed literals would be accepted.
        self.validate_combination(
            {"my_share": self.my_share, "other_share": self.other_share, "lost": self.lost}
        )
        return self


class DictatorSetup:
    """Play one round of a dictator game with an LLM or a canned strategy.

    The "dictator" divides $1000 between itself, another player and a loss by
    picking one of four fixed options. The chosen option is compared with the
    option associated with ``preference`` to decide whether the answer is
    consistent with that welfare preference.

    Backend routing, decided from the model name:
      * names containing ":" -> Pagoda-hosted Ollama ``/api/generate``,
        called directly over HTTP (``run_pagoda``);
      * other names starting with "gpt" -> OpenAI API through autogen;
      * anything else -> OpenAI-compatible endpoint on localhost:11434.
    """

    # When True, the raw Pagoda response of every attempt is printed.
    debug = True
    # Seconds to wait for a Pagoda HTTP response before the attempt is counted
    # as failed; requests waits forever when no timeout is given.
    request_timeout = 600.0

    # (my_share, other_share, lost) option associated with each preference.
    _CHOICES_BY_PREFERENCE = {
        Welfare.SELFISH: (500, 100, 400),
        Welfare.ALTRUISTIC: (100, 500, 400),
        Welfare.UTILITARIAN: (400, 300, 300),
        Welfare.EGALITARIAN: (325, 325, 350),
    }

    # Models for which apply_strategy returns the preference's own option.
    # "mistral-small" is accepted next to the existing "mistrall-small"
    # spelling, which looks like a typo of the Ollama model name.
    _PREFERENCE_STRATEGY_MODELS = frozenset({
        "gpt-4.5-preview-2025-02-27",
        "llama3:70b",
        "llama3",
        "mistrall-small",
        "mistral-small",
        "deepseek-r1:7b",
    })

    def __init__(self, model: str, temperature: float, preference: Welfare, strategy=False, max_retries: int=3):
        """Configure the game and, when needed, the autogen model client.

        Args:
            model: Model name; it also selects the backend (see class docstring).
            temperature: Sampling temperature forwarded to the backend.
            preference: Welfare preference described to the model and used to
                judge consistency.
            strategy: If True, ``run`` skips the LLM and returns the canned
                answer from ``apply_strategy``.
            max_retries: Number of attempts made against the Pagoda backend.
        """
        self.model = model
        self.temperature = temperature
        self.preference = preference
        self.strategy = strategy
        self.max_retries = max_retries

        # Test the Pagoda form first: run() sends every name containing ":" to
        # run_pagoda, so e.g. "gpt-oss:20b" must get Pagoda's URL and key.
        is_pagoda_model = ":" in model
        is_openai_model = model.startswith("gpt") and not is_pagoda_model

        if is_pagoda_model:
            base_url = "https://ollama-ui.pagoda.liris.cnrs.fr/ollama/api/generate"
            api_key = PAGODA_API_KEY
        elif is_openai_model:
            base_url = "https://api.openai.com/v1"
            api_key = OPENAI_API_KEY
        else:
            base_url = "http://localhost:11434/v1"
            api_key = None

        self.base_url = base_url
        self.api_key = api_key

        # The autogen client is only used by run() for non-Pagoda models when
        # no canned strategy applies; Pagoda is called via run_pagoda instead.
        self.model_client = None
        if not self.strategy and not is_pagoda_model:
            self.model_client = OpenAIChatCompletionClient(
                model=self.model,
                base_url=base_url,
                api_key=api_key,
                # Temperature is a request argument; model_info only describes
                # model capabilities, so a temperature entry there is ignored.
                temperature=self.temperature,
                model_info={
                    "function_calling": True,
                    "parallel_tool_calls": True,
                    "family": "unknown",
                    "json_output": True,
                    "vision": False
                },
                response_format=AgentResponse
            )

    async def run(self) -> Dict:
        """Play one round and return the outcome.

        Returns:
            With ``strategy`` set, whatever ``apply_strategy`` returns.
            Otherwise a dict with ``is_consistent``, ``my_share``,
            ``other_share``, ``lost`` and ``motivations``.

        Raises:
            ValueError: if the Pagoda backend never yields a valid answer, or
                (OpenAI/local path) the reply is not valid JSON or fails
                ``AgentResponse`` validation (both are ValueError subclasses).
        """
        if self.strategy:
            return self.apply_strategy(self.model)

        instruction = f"""
        Hi, let’s play a game. You are paired with another player.
        Your role is to decide how to divide $1000, and the other player simply receives your choice.
        Money can be lost in the process.
        {self.preference.description}

        How would you like to divide the money?

        Choose one of the following options:
        - my_share: 500, other_share: 100, lost: 400
        - my_share: 100, other_share: 500, lost: 400
        - my_share: 400, other_share: 300, lost: 300
        - my_share: 325, other_share: 325, lost: 350
        Return the response in JSON format with `my_share`, `other_share`, `lost` and `motivation`.
        """

        if ":" in self.model:
            return await self.run_pagoda(instruction)

        agent = AssistantAgent(
            name="Dictator",
            model_client=self.model_client,
            system_message="You are a helpful assistant."
        )
        response = await agent.on_messages(
            [TextMessage(content=instruction, source="user")],
            cancellation_token=CancellationToken(),
        )

        # With response_format=AgentResponse the reply content is a JSON string.
        response_data = response.chat_message.content
        response_dict = json.loads(response_data)
        agent_response = AgentResponse.model_validate(response_dict)

        return {
            "is_consistent": self.check_consistency(agent_response),
            "my_share": agent_response.my_share,
            "other_share": agent_response.other_share,
            "lost": agent_response.lost,
            "motivations": agent_response.motivation
        }

    async def run_pagoda(self, instruction) -> Dict:
        """Query the Pagoda Ollama ``/api/generate`` endpoint, retrying on failure.

        Each attempt posts ``instruction`` as a non-streaming generate request,
        extracts the JSON answer from the ``response`` field and validates it.
        Failures are reported with ``print`` and retried, up to
        ``max_retries`` attempts in total.

        Args:
            instruction: Prompt text sent to the model.

        Returns:
            dict with ``is_consistent``, ``my_share``, ``other_share``,
            ``lost`` and ``motivations`` (share values as returned by the model).

        Raises:
            ValueError: when no attempt produced a valid answer.
        """
        url = self.base_url
        headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
        payload = {
            "model": self.model,
            "prompt": instruction,
            "stream": False,
            # Ollama reads sampling parameters from "options"; a top-level
            # "temperature" key is ignored by /api/generate.
            "options": {"temperature": self.temperature},
        }
        required_keys = {"my_share", "other_share", "lost", "motivation"}

        for attempt in range(1, self.max_retries + 1):
            try:
                # requests is blocking: run it in a worker thread so the event
                # loop is not frozen while the model generates.
                response = await asyncio.to_thread(
                    requests.post,
                    url,
                    headers=headers,
                    json=payload,
                    timeout=self.request_timeout,
                )
                response.raise_for_status()
                response_data = response.json()

                if self.debug:
                    print(f"Raw response (Attempt {attempt}): {response_data}")

                # The generated text is in the "response" field.
                response_json = response_data.get('response', '')
                parsed_response = self.extract_json_from_response(response_json)

                if not parsed_response:
                    print(f"Failed to extract JSON from response (Attempt {attempt}): {response_json}")
                    continue

                if not required_keys.issubset(parsed_response.keys()):
                    print(f"Missing required keys in response (Attempt {attempt}): {parsed_response}")
                    continue

                my_share = parsed_response["my_share"]
                other_share = parsed_response["other_share"]
                lost = parsed_response["lost"]
                motivation = parsed_response["motivation"]

                if not (0 <= my_share <= 1000 and 0 <= other_share <= 1000 and 0 <= lost <= 1000
                        and my_share + other_share + lost <= 1000):
                    print(f"Invalid response values (Attempt {attempt}): {parsed_response}")
                    continue

                # Building AgentResponse enforces the allowed values; a pydantic
                # ValidationError (or a TypeError from non-numeric shares above)
                # falls into the generic handler below and triggers a retry.
                is_consistent = self.check_consistency(
                    AgentResponse(my_share=my_share, other_share=other_share, lost=lost, motivation=motivation)
                )
                return {
                    "is_consistent": is_consistent,
                    "my_share": my_share,
                    "other_share": other_share,
                    "lost": lost,
                    "motivations": motivation
                }

            except requests.RequestException as e:
                print(f"Request error (Attempt {attempt}): {e}")
            except json.JSONDecodeError as e:
                print(f"JSON decoding error (Attempt {attempt}): {e}")
            except Exception as e:
                print(f"Unexpected error (Attempt {attempt}): {e}")

        raise ValueError("Pagoda model failed to provide a valid response after multiple attempts.")

    def extract_json_from_response(self, response_text: str) -> dict:
        """Extract the first JSON object carrying the expected answer keys.

        Markdown-escaped underscores (``my\\_share``) are normalized and any
        ``<think>...</think>`` reasoning section is dropped. A fenced
        ```` ```json ```` block is searched first, then the whole text. Objects
        are decoded with ``json.JSONDecoder.raw_decode`` starting at each "{",
        so nested braces and braces inside strings are handled.

        Args:
            response_text: Raw text generated by the model (may be None/empty).

        Returns:
            The first decoded dict containing ``my_share``, ``other_share``,
            ``lost`` and ``motivation``, or ``{}`` if none is found.
        """
        expected_keys = {"my_share", "other_share", "lost", "motivation"}

        cleaned_text = (response_text or "").strip().replace('\\_', '_')
        # Reasoning models can emit a <think> preamble that may itself contain
        # draft JSON; only the final answer should be considered.
        cleaned_text = re.sub(r"<think>[\s\S]*?</think>", "", cleaned_text).strip()

        candidates = []
        fenced = re.search(r"```json\s*([\s\S]*?)\s*```", cleaned_text)
        if fenced:
            candidates.append(fenced.group(1).strip())
        candidates.append(cleaned_text)

        decoder = json.JSONDecoder()
        for text in candidates:
            for brace in re.finditer(r"\{", text):
                try:
                    obj, _ = decoder.raw_decode(text, brace.start())
                except json.JSONDecodeError:
                    continue
                if isinstance(obj, dict) and expected_keys.issubset(obj):
                    return obj

        logger.warning("No JSON object with the expected keys found in response: %s", response_text)
        return {}

    def check_consistency(self, agent_response: AgentResponse) -> bool:
        """Return True if the chosen option is the one tied to ``self.preference``.

        Preferences without an associated option are never consistent.
        """
        expected = self._CHOICES_BY_PREFERENCE.get(self.preference)
        if expected is None:
            return False
        chosen = (agent_response.my_share, agent_response.other_share, agent_response.lost)
        return chosen == expected

    def apply_strategy(self, model: str) -> Dict:
        """Return a canned answer for ``model`` instead of querying it.

        Args:
            model: Model name selecting the canned behaviour.

        Returns:
            * For models in ``_PREFERENCE_STRATEGY_MODELS``: the option tied to
              ``self.preference`` with ``motivations`` and ``is_consistent=True``,
              or ``{"error": "Preference strategy not defined"}`` when the
              preference has no associated option.
            * For ``"qwen3"``: a fixed split per preference (keys ``my_share``,
              ``other_share`` and ``lost`` only).

        Raises:
            ValueError: for any other model.
        """
        if model in self._PREFERENCE_STRATEGY_MODELS:
            choice = self._CHOICES_BY_PREFERENCE.get(self.preference)
            if choice is None:
                return {"error": "Preference strategy not defined"}
            my_share, other_share, lost = choice
            return {
                "my_share": my_share,
                "other_share": other_share,
                "lost": lost,
                "motivations": "preference dictates how the resources are distributed",
                "is_consistent": True,
            }

        if model == "qwen3":
            # NOTE(review): these splits are not among the four offered options
            # and the dicts omit "motivations"/"is_consistent"; confirm intended.
            if self.preference == Welfare.EGALITARIAN:
                return {'my_share': 500, 'other_share': 500, 'lost': 0}
            if self.preference == Welfare.UTILITARIAN:
                return {'my_share': 0, 'other_share': 1000, 'lost': 0}
            return {'my_share': 0, 'other_share': 0, 'lost': 1000}

        raise ValueError(f"No predefined strategy for model {model!r}")


if __name__ == "__main__":
    # Demo round against a Pagoda-hosted model.
    # Other model names used here: "mixtral:8x7b", "llama3.3:latest".
    preference = Welfare.EGALITARIAN
    game_agent = DictatorSetup(
        model="deepseek-r1:7b",
        temperature=0.7,
        preference=preference,
        strategy=False,
    )
    response = asyncio.run(game_agent.run())
    print(response)
