import os
import asyncio
import csv
import random
import json
import re
import requests
from typing import Dict, Literal, List, Callable
from pydantic import BaseModel, ValidationError
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.messages import TextMessage
from autogen_core import CancellationToken
from autogen_ext.models.openai import OpenAIChatCompletionClient


# API keys are read from the environment once, at import time. Both are
# mandatory: importing this module raises ValueError if either one is unset
# or empty.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
PAGODA_API_KEY = os.getenv("PAGODA_API_KEY")

for _env_name, _env_value in (
    ("OPENAI_API_KEY", OPENAI_API_KEY),
    ("PAGODA_API_KEY", PAGODA_API_KEY),
):
    if not _env_value:
        raise ValueError(f"Missing {_env_name}. Set it as an environment variable.")
del _env_name, _env_value

# Relative path to the Guess CSV data (not referenced in this section of the file).
CSV_FILE_PATH = "../../data/guess/guess.csv"

# Expected structure of the model's reply. It is passed as the OpenAI client's
# response_format and is also used to validate replies from the Pagoda backend.
# NOTE(review): documented with comments rather than a class docstring, since
# pydantic presumably copies a docstring into the JSON schema sent to the model.
class AgentResponse(BaseModel):
    # Predicted opponent move; pydantic rejects any value outside these three.
    prediction: Literal["Rock", "Paper", "Scissors"]
    # Free-text explanation of how the prediction was made.
    reasoning: str

# Define Guess simulation class
class Guess:
    """Simulation of the "Guess The Next Move" Rock-Paper-Scissors variant.

    Each round the opponent strategy picks a move from the round history, and
    the player tries to predict it, either by querying a model
    (``strategy=False``) or with a hard-coded per-model heuristic
    (``strategy=True``). A correct prediction scores one point.

    The backend is chosen from the model name:
      * names containing ``":"`` -> Pagoda-hosted Ollama ``/api/generate``;
      * other names starting with ``"gpt"`` -> OpenAI API;
      * anything else -> local Ollama OpenAI-compatible endpoint.
    """

    def __init__(
        self,
        model: str,
        temperature: float,
        game_id: int,
        opponent_strategy_fn: Callable[[List[Dict]], str],
        strategy: bool = False,
        max_retries: int = 3,
    ):
        """Configure a game and build the chat client for ``model``.

        Args:
            model: Model name; also selects the backend (see class docstring).
            temperature: Sampling temperature forwarded to the backend.
            game_id: Identifier of this game (stored on the instance only).
            opponent_strategy_fn: Called with the round history; returns the
                opponent's move for the upcoming round.
            strategy: If True, predict with ``apply_strategy`` instead of a model.
            max_retries: Attempts allowed before giving up on a model reply.
        """
        self.debug = False
        self.model = model
        self.temperature = temperature
        self.game_id = game_id
        self.max_retries = max_retries
        self.history: List[Dict] = []
        self.player_score_game = 0
        self.opponent_strategy_fn = opponent_strategy_fn
        self.strategy = strategy  # Determines whether to use a model or a rule-based method

        # model_based_prediction() sends every name containing ":" to
        # run_pagoda(), so the Pagoda URL must take precedence here; otherwise
        # a name such as "gpt-oss:20b" would be posted to the OpenAI URL.
        is_pagoda_model = ":" in model
        is_openai_model = model.startswith("gpt") and not is_pagoda_model

        if is_pagoda_model:
            self.base_url = "https://ollama-ui.pagoda.liris.cnrs.fr/ollama/api/generate"
        elif is_openai_model:
            self.base_url = "https://api.openai.com/v1"
        else:
            self.base_url = "http://localhost:11434/v1"

        model_info = {
            "temperature": self.temperature,
            "function_calling": True,
            "parallel_tool_calls": True,
            "family": "unknown",
            "json_output": True,
            "vision": False
        }

        self.model_client = OpenAIChatCompletionClient(
            model=self.model,
            base_url=self.base_url,
            api_key=OPENAI_API_KEY,
            model_info=model_info,
            response_format=AgentResponse,
            # model_info describes model capabilities; the sampling temperature
            # has to be given to the client itself to take effect.
            temperature=self.temperature,
        )

    async def play_round(self, round_id: int) -> Dict:
        """Play one round of Guess The Next Move and append it to ``history``.

        The opponent's move is drawn first, from the history of previous rounds;
        then the prediction is made and scored.

        Args:
            round_id: Round number supplied by the caller (not used here).

        Returns:
            The round record that was appended to ``history``.
        """
        opponent_move = self.opponent_strategy_fn(self.history)

        if self.strategy:
            # Use a simple rule-based prediction
            move, reasoning = self.apply_strategy()
        else:
            # Use the model for prediction
            move, reasoning = await self.model_based_prediction()

        outcome = self.determine_accuracy(move, opponent_move)
        self.update_score(outcome)
        round_result = {
            "Agent Prediction": move,
            "Opponent Move": opponent_move,
            "Reasoning": reasoning,
            "Outcome": outcome
        }
        self.history.append(round_result)
        return round_result

    async def model_based_prediction(self):
        """Ask the configured model to predict the opponent's next move.

        Pagoda models (names containing ``":"``) are delegated to ``run_pagoda``.
        All others go through an autogen ``AssistantAgent``, whose reply is
        validated against ``AgentResponse``.

        Returns:
            A ``(prediction, reasoning)`` tuple.

        Raises:
            ValueError: If no valid reply is obtained within ``max_retries``.
        """
        history_summary = self.get_history_summary()
        instruction = f"""
        We are playing a simplified version of Rock-Paper-Scissors (RPS) designed to evaluate your ability to predict the opponent's next move.

        ### **Game Rules:**
        1. The opponent follows a hidden strategy (which may involve a repeating pattern or adaptive behavior).
        2. Your task is to predict the opponent’s next move: Rock, Paper, or Scissors.
        3. If your prediction is correct, you earn **1 point**; otherwise, you earn **0 points**.
        4. The game continues for multiple rounds, and your accuracy is evaluated at each round.

        ### **Game History So Far:**
        {history_summary}

        ### **Your Task:**
        Based on the game history, predict the opponent's next move.  
        Return your response in JSON format with two keys:  
        - `"prediction"`: Your predicted move (`"Rock"`, `"Paper"`, or `"Scissors"`).  
        - `"reasoning"`: A brief explanation of how you made your prediction.
        """

        is_pagoda_model = ":" in self.model
        if is_pagoda_model:
            return await self.run_pagoda(instruction)

        for attempt in range(1, self.max_retries + 1):
            # A fresh agent per attempt, presumably so a rejected reply does
            # not remain in the agent's conversation state.
            agent = AssistantAgent(
                name="Player",
                model_client=self.model_client,
                system_message="You are a helpful assistant."
            )
            if self.debug:
                print(f"Attempt {attempt}: {instruction}")
            response = await agent.on_messages(
                [TextMessage(content=instruction, source="user")],
                cancellation_token=CancellationToken(),
            )
            try:
                agent_response = AgentResponse.model_validate_json(response.chat_message.content)
            except (ValidationError, json.JSONDecodeError) as e:
                print(f"Error parsing response (Attempt {attempt}): {e}")
                continue
            # The Literal type on `prediction` already guarantees a valid move.
            return agent_response.prediction, agent_response.reasoning
        raise ValueError("Model failed to provide a valid response after multiple attempts.")

    async def run_pagoda(self, instruction: str, timeout: float = 300.0):
        """Query a Pagoda-hosted Ollama model via its ``/api/generate`` endpoint.

        Args:
            instruction: Full prompt text, sent as ``prompt``.
            timeout: Per-request timeout in seconds passed to ``requests``
                (``None`` waits indefinitely).

        Returns:
            A ``(prediction, reasoning)`` tuple.

        Raises:
            ValueError: If no valid reply is obtained within ``max_retries``.
        """
        headers = {
            "Authorization": f"Bearer {PAGODA_API_KEY}",
            "Content-Type": "application/json"
        }
        payload = {
            "model": self.model,
            "prompt": instruction,
            "stream": False,
            # Ollama reads sampling parameters from "options"; a top-level
            # "temperature" key is not part of the /api/generate request.
            "options": {"temperature": self.temperature},
        }

        for attempt in range(self.max_retries):
            try:
                # requests is blocking: run it in a worker thread so the event
                # loop is not stalled while the remote model generates.
                response = await asyncio.to_thread(
                    requests.post,
                    self.base_url,
                    headers=headers,
                    json=payload,
                    timeout=timeout,
                )
                response.raise_for_status()
                response_data = response.json()

                raw_response = response_data.get("response", "")
                parsed_json = self.extract_json_from_response(raw_response)

                if not parsed_json:
                    print(f"Failed to parse JSON (Attempt {attempt + 1}): {raw_response}")
                    continue

                agent_response = AgentResponse(**parsed_json)
            except Exception as e:
                # Best-effort retry boundary: network, HTTP, decoding and
                # validation errors are all logged and retried.
                print(f"Error in run_pagoda (Attempt {attempt + 1}): {e}")
                continue
            return agent_response.prediction, agent_response.reasoning

        raise ValueError("run_pagoda failed to get a valid response.")

    def extract_json_from_response(self, text: str) -> dict:
        """Extract a JSON object from raw model output.

        Models often surround the JSON with prose, markdown fences or reasoning
        traces that contain braces of their own. Each ``{`` is tried as the
        start of a JSON document. The first decoded object that has a
        ``"prediction"`` key is returned; failing that, the first decoded object.

        Returns:
            The decoded object, or ``{}`` if none is found (or ``text`` is not
            a string).
        """
        if not isinstance(text, str):
            return {}
        decoder = json.JSONDecoder()
        first_object = None
        start = text.find("{")
        while start != -1:
            try:
                candidate, _ = decoder.raw_decode(text, start)
            except json.JSONDecodeError:
                pass
            else:
                if "prediction" in candidate:
                    return candidate
                if first_object is None:
                    first_object = candidate
            start = text.find("{", start + 1)
        return first_object if first_object is not None else {}

    def apply_strategy(self):
        """Predict the opponent's next move with a heuristic selected by ``self.model``.

        Returns:
            A ``(move, reasoning)`` tuple; unknown models default to Scissors.
        """
        if self.model == "gpt-4.5-preview-2025-02-27":
            if not self.history:
                return random.choice(["Rock", "Paper", "Scissors"]), "No history available. Choosing randomly."
            move_counts = {"Rock": 0, "Paper": 0, "Scissors": 0}
            for round_data in self.history:
                move_counts[round_data["Opponent Move"]] += 1
            most_common_move = max(move_counts, key=move_counts.get)
            reasoning = f"Based on history, the opponent most frequently played {most_common_move}."
            return most_common_move, reasoning
        elif self.model == "mistral-small":
            if not self.history:
                return "Scissors", "No game history available."
            opponent_moves = [move['Opponent Move'] for move in self.history]
            move_count = {
                'Rock': opponent_moves.count('Rock'),
                'Paper': opponent_moves.count('Paper'),
                'Scissors': opponent_moves.count('Scissors')
            }
            max_move = max(move_count, key=move_count.get)
            reasoning = f"Predicted {max_move} because it has been played {move_count[max_move]} times."
            return max_move, reasoning
        elif self.model in ["llama3", "deepseek-r1"]:
            return "Rock", f"Fallback strategy used for model: {self.model}."
        elif self.model == "llama3.3:latest":
            if not self.history:
                # First round, make an arbitrary choice
                return "Rock", "First round guess."
            rock_count = sum(1 for r in self.history if r['Opponent Move'] == 'Rock')
            paper_count = sum(1 for r in self.history if r['Opponent Move'] == 'Paper')
            scissors_count = sum(1 for r in self.history if r['Opponent Move'] == 'Scissors')
            # Pick the move that beats the opponent's most common move
            max_count = max(rock_count, paper_count, scissors_count)
            if max_count == rock_count:
                strategy_move = "Paper"  # Paper beats Rock
            elif max_count == paper_count:
                strategy_move = "Scissors"  # Scissors beats Paper
            else:
                strategy_move = "Rock"  # Rock beats Scissors
            return strategy_move, f"Strategy chose {strategy_move} based on opponent's move history."
        elif self.model == "mixtral:8x7b":
            moves = ["Rock", "Paper", "Scissors"]
            if len(self.history) >= 3:
                # Cycle through the moves by the position of the most recent round.
                return moves[(len(self.history) - 1) % 3], "Recent move"
            if not self.history:
                return "Rock", "Recent move"
            # Otherwise answer the opponent's last move with the move it beats.
            opponent_last_move = self.history[-1]["Opponent Move"]
            beaten_by = {"Scissors": "Paper", "Paper": "Rock", "Rock": "Scissors"}
            return beaten_by.get(opponent_last_move, "Rock"), "Recent move"
        elif self.model == "deepseek-r1:7b":
            moves = ["Rock", "Paper", "Scissors"]
            return moves[len(self.history) % 3], "making decisions in a cyclic manner"
        else:
            return "Scissors", f"Unknown model '{self.model}'. Defaulting to Scissors."

    @staticmethod
    def determine_accuracy(player_move: str, opponent_move: str) -> int:
        """Return 1 if the prediction matches the opponent's move, else 0."""
        return 1 if player_move == opponent_move else 0

    def update_score(self, outcome: int):
        """Add one point to ``player_score_game`` when ``outcome`` is 1."""
        if outcome == 1:
            self.player_score_game += 1

    def get_history_summary(self) -> str:
        """Summarize the game history as prompt text for model-based predictions."""
        if not self.history:
            return "This is the first round."
        summary = "\n".join(
            [f"Round {i+1}: You guessed {r['Agent Prediction']}, Opponent played {r['Opponent Move']}. Outcome: {r['Outcome']}"
             for i, r in enumerate(self.history)]
        )
        summary += f"\nCurrent Score - You: {self.player_score_game}\n"
        return summary

    @staticmethod
    def simple_opponent_strategy(history):
        """A simple opponent strategy that cycles through Rock, Paper, Scissors."""
        moves = ["Rock", "Paper", "Scissors"]
        return moves[len(history) % 3]


async def main(model: str = "qwen3", num_rounds: int = 10, strategy: bool = False):
    """Play ``num_rounds`` rounds against an opponent that always plays Rock.

    Args:
        model: Model name passed to ``Guess``. Other names noted in this script
            are "llama3.3:latest", "mixtral:8x7b" and "deepseek-r1:7b".
        num_rounds: Number of rounds to play.
        strategy: If True, use the rule-based ``apply_strategy`` heuristics
            instead of querying the model (the default queries the model).
    """
    game = Guess(
        model=model,
        temperature=0.7,
        game_id=1,
        opponent_strategy_fn=lambda history: "Rock",
        strategy=strategy,
    )
    for round_id in range(1, num_rounds + 1):
        result = await game.play_round(round_id)
        print(f"Round {round_id}: {result}")
    print(f"Final Score: {game.player_score_game}")


if __name__ == "__main__":
    asyncio.run(main())