import os
import csv
import asyncio
import random
from http.cookiejar import debug

from mp import MP
from typing import Callable

# Results file shared by every game of the experiment. This is a relative path,
# so it is resolved against the current working directory at run time, not
# against the location of this script.
CSV_FILE_PATH = "../../data/mp/mp_claude35.csv"


class MPExperiment:
    """Run repeated Matching Pennies games between LLM models and scripted opponents.

    Every round of every game is appended as one row to ``CSV_FILE_PATH``; the
    column order is given by ``CSV_HEADER``.
    """

    # Column names of the results CSV. ``log_to_csv`` writes its rows in exactly
    # this order, so header and data can no longer drift apart.
    CSV_HEADER = (
        "idGame", "model", "opponentStrategy", "idRound",
        "playerMove", "prediction", "opponentMove", "outcomeRound",
        "currentPlayerScoreGame", "predictionRound", "currentPlayerPredictionScoreGame",
        "temperature", "reasoning",
    )

    def __init__(self):
        self.debug = False  # print progress messages while running
        # Forwarded to MP(); when True the CSV "model" column gets a " strategy" suffix.
        self.strategy = False
        # Other model identifiers previously used with this script:
        # "gpt-4.1", "aws-claude-3-7-sonnet", "aws-claude-3-5-sonnet", "aws-claude-3-haiku",
        # "qwen3", "llama3", "llama3.3", "mixtral", "mistral-small", "deepseek-r1"
        self.models = ["claude-3-7-sonnet"]
        # Opponent name -> callable(history) returning "Head" or "Tail".
        # Uncomment entries to include more opponents in the run.
        self.opponent_strategies = {
            # "always_head": lambda history: "Head",
            # "always_tail": lambda history: "Tail",
            "H-T": self.loop_H_T,
            # "T-H": self.loop_T_H,
        }
        self.temperature = 0
        self.rounds = 10
        self.num_games_per_config = 30
        self.initialize_csv()

    def loop_H_T(self, history):
        """Alternate Head, Tail, Head, ... — "Head" whenever ``len(history)`` is even."""
        return "Head" if len(history) % 2 == 0 else "Tail"

    def loop_T_H(self, history):
        """Alternate Tail, Head, Tail, ... — "Tail" whenever ``len(history)`` is even."""
        return "Tail" if len(history) % 2 == 0 else "Head"

    def initialize_csv(self):
        """Create the results CSV with its header row if the file does not exist yet.

        An existing file is left untouched so that repeated runs append to it.
        """
        if os.path.exists(CSV_FILE_PATH):
            return
        directory = os.path.dirname(CSV_FILE_PATH)
        if directory:  # os.makedirs("") raises for a bare file name
            os.makedirs(directory, exist_ok=True)
        with open(CSV_FILE_PATH, mode="w", newline="", encoding="utf-8") as file:
            csv.writer(file).writerow(self.CSV_HEADER)

    def sanitize_reasoning(self, reasoning: str) -> str:
        """Flatten free-text reasoning into a single-line, spreadsheet-safe CSV value.

        Line breaks are collapsed so each round stays on one physical line, and a
        leading ``=``, ``+``, ``-`` or ``@`` is prefixed with ``'`` so spreadsheet
        applications do not evaluate the text as a formula.

        Quoting and quote-escaping are deliberately left to ``csv.writer``. Doing
        them here as well would double-escape the value and store literal ``"``
        characters in the cell.

        Args:
            reasoning: The model's reasoning text; ``None`` is treated as empty.

        Returns:
            The cleaned text, unquoted.
        """
        if reasoning is None:
            return ""
        sanitized = str(reasoning).replace("\r", "").replace("\n", " ")
        if sanitized.startswith(("=", "+", "-", "@")):
            sanitized = "'" + sanitized
        return sanitized

    def log_to_csv(self, game_id, model, opponent_strategy, round_id,
                   agent_move, prediction, opponent_move, outcome,
                   player_score_game, prediction_round_score, prediction_total_score, temperature, reasoning):
        """Append one round's results to the results CSV.

        Values are written in ``CSV_HEADER`` order. When ``self.strategy`` is set,
        the model name is suffixed with " strategy" to distinguish those runs.
        """
        model_label = model + " strategy" if self.strategy else model
        row = [
            game_id, model_label, opponent_strategy, round_id,
            agent_move, prediction, opponent_move, outcome,
            player_score_game, prediction_round_score, prediction_total_score,
            temperature, self.sanitize_reasoning(reasoning),
        ]
        with open(CSV_FILE_PATH, mode="a", newline="", encoding="utf-8") as file:
            csv.writer(file).writerow(row)

    async def run_experiment(self):
        """Play ``num_games_per_config`` games for every (model, opponent) pair.

        Games run one after another. ``game_id`` increases across the whole run
        and starts again at 1 on every invocation.
        """
        game_id = 1
        for model in self.models:
            if self.debug:
                print(f"Running model {model}")
            for strategy_name, strategy_fn in self.opponent_strategies.items():
                if self.debug:
                    print(f"Running strategy {strategy_name}")
                for _ in range(self.num_games_per_config):
                    if self.debug:
                        print(f"Running game {game_id}")
                    await self.run_game(model, strategy_name, strategy_fn, game_id)
                    game_id += 1

    async def run_game(self, model, opponent_strategy_name, opponent_strategy_fn, game_id):
        """Play one ``self.rounds``-round game against the given opponent and log each round.

        A round's prediction score is 1.0 when the model's "Prediction" equals
        the opponent's actual move, otherwise 0.0.
        """
        game = MP(
            model=model,
            temperature=self.temperature,
            game_id=game_id,
            prediction=True,
            opponent_strategy_fn=opponent_strategy_fn,
            strategy=self.strategy
        )
        for round_id in range(1, self.rounds + 1):
            # play_round is a coroutine. Its result is read below via the keys
            # "Your Move", "Prediction", "Opponent Move", "Outcome" and "Reasoning".
            round_data = await game.play_round(round_id, self.rounds)
            prediction_round_score = 1.0 if round_data.get("Prediction") == round_data.get("Opponent Move") else 0.0
            prediction_total_score = game.prediction_score

            self.log_to_csv(
                game_id, model, opponent_strategy_name, round_id,
                round_data["Your Move"], round_data["Prediction"],
                round_data["Opponent Move"], round_data["Outcome"],
                game.player_score_game, prediction_round_score, prediction_total_score, self.temperature,
                round_data["Reasoning"]
            )


if __name__ == "__main__":
    # Script entry point: run every configured game, then report where results went.
    mp_experiment = MPExperiment()
    experiment_coro = mp_experiment.run_experiment()
    asyncio.run(experiment_coro)
    print(f"Experiment completed. Results saved in {CSV_FILE_PATH}")