import os
import csv
import asyncio
import random

from rps import RPS
from typing import Callable

# Results file, resolved relative to this script rather than the current
# working directory so the experiment writes to the same place no matter
# where it is launched from.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
CSV_FILE_PATH = os.path.normpath(
    os.path.join(_SCRIPT_DIR, "..", "..", "data", "rps", "rps.csv")
)

class RPSExperiment:
    """Run repeated Rock-Paper-Scissors games between LLM models and scripted opponents.

    Every (model, opponent strategy) pair is played ``num_games_per_config``
    times, each game lasting ``rounds`` rounds, and every round is appended as
    one row to the results CSV.
    """

    # Leading characters that make spreadsheet applications evaluate a cell as
    # a formula (CSV/formula injection); such text gets a "'" prefix.
    _FORMULA_PREFIXES = ("=", "+", "-", "@", "\t")

    def __init__(self, *, csv_path=None):
        """Configure the experiment and make sure the results CSV exists.

        Args:
            csv_path: File that rows are appended to. Defaults to the
                module-level ``CSV_FILE_PATH``.
        """
        self.csv_path = csv_path if csv_path is not None else CSV_FILE_PATH
        self.debug = False
        # Forwarded to RPS as ``strategy``; when True, rows are labelled
        # "<model> strategy" in the CSV.
        self.strategy = False
        self.models = ["qwen3"]
        # Opponent label -> function(history) returning the opponent's move.
        self.opponent_strategies = {
            "always_rock": lambda history: "Rock",
            "always_paper": lambda history: "Paper",
            "always_scissor": lambda history: "Scissors",
            # Cyclic opponents; uncomment to include them in the run.
            # "R-P": self.loop_R_P,
            # "P-S": self.loop_P_S,
            # "S-R": self.loop_S_R,
            # "R-P-S": self.loop_R_P_S
        }
        self.temperature = 0.7
        self.rounds = 10
        self.num_games_per_config = 30
        self.initialize_csv()

    @staticmethod
    def _cycle(history, moves):
        """Return the entry of the repeating ``moves`` cycle at position ``len(history)``.

        Only the length of ``history`` is used, so any sized sequence works.
        """
        return moves[len(history) % len(moves)]

    def loop_R_P(self, history):
        """Alternate Rock, Paper, Rock, ... by the number of rounds in ``history``."""
        return self._cycle(history, ("Rock", "Paper"))

    def loop_P_S(self, history):
        """Alternate Paper, Scissors, Paper, ... by the number of rounds in ``history``."""
        return self._cycle(history, ("Paper", "Scissors"))

    def loop_S_R(self, history):
        """Alternate Scissors, Rock, Scissors, ... by the number of rounds in ``history``."""
        return self._cycle(history, ("Scissors", "Rock"))

    def loop_R_P_S(self, history):
        """Cycle Rock, Paper, Scissors, ... by the number of rounds in ``history``."""
        return self._cycle(history, ("Rock", "Paper", "Scissors"))

    def initialize_csv(self):
        """Create the results CSV (and its parent directory) with a header row.

        The header is written only when the file is missing or empty, so
        repeated runs keep appending to the same file.
        """
        if os.path.exists(self.csv_path) and os.path.getsize(self.csv_path) > 0:
            return
        parent = os.path.dirname(self.csv_path)
        if parent:  # a bare filename has no directory to create
            os.makedirs(parent, exist_ok=True)
        with open(self.csv_path, mode="w", newline="", encoding="utf-8") as file:
            writer = csv.writer(file)
            writer.writerow([
                "idGame", "model", "opponentStrategy", "idRound",
                "playerMove", "opponentMove", "outcomeRound",
                "currentPlayerScoreGame", "motivations"
            ])

    def sanitize_motivations(self, motivations: str) -> str:
        """Make a model's free-text motivation safe to store in one CSV cell.

        Line breaks are flattened so each round stays on one physical line.
        Text that a spreadsheet would treat as a formula is prefixed with a
        single quote. Quoting and escaping of ``"`` is deliberately left to
        ``csv.writer``; doing it here as well would escape the text twice.

        Args:
            motivations: Explanation text from the model; ``None`` is stored
                as an empty string.

        Returns:
            The sanitized text, ready to be passed to ``csv.writer``.
        """
        if motivations is None:
            return ""
        # Replace "\n" before dropping "\r" so "\r\n" becomes a single space.
        sanitized = str(motivations).replace("\n", " ").replace("\r", "")
        if sanitized.startswith(self._FORMULA_PREFIXES):
            sanitized = "'" + sanitized
        return sanitized

    def log_to_csv(self, game_id, model, opponent_strategy, round_id,
                   agent_move, opponent_move, outcome, player_score_game, motivations):
        """Append one round's result as a row of the results CSV.

        The column order matches the header written by ``initialize_csv``.
        When ``self.strategy`` is set, the model column is suffixed with
        " strategy".
        """
        model_type = model + " strategy" if self.strategy else model
        row = [
            game_id, model_type, opponent_strategy, round_id,
            agent_move, opponent_move, outcome, player_score_game,
            self.sanitize_motivations(motivations)
        ]
        # The file is reopened per row so results survive an interrupted run.
        with open(self.csv_path, mode="a", newline="", encoding="utf-8") as file:
            csv.writer(file).writerow(row)

    async def run_experiment(self):
        """Play every configured game sequentially, one model/opponent pair at a time.

        Games are numbered consecutively across all configurations of a run.
        """
        # NOTE(review): numbering restarts at 1 on every run while the CSV is
        # only ever appended to, so idGame is unique per run, not per file.
        game_id = 1
        for model in self.models:
            if self.debug:
                print(f"Running model {model}")
            for strategy_name, strategy_fn in self.opponent_strategies.items():
                if self.debug:
                    print(f"Running strategy {strategy_name}")
                for _ in range(self.num_games_per_config):
                    if self.debug:
                        print(f"Running game {game_id}")
                    await self.run_game(model, strategy_name, strategy_fn, game_id)
                    game_id += 1

    async def run_game(self, model, opponent_strategy_name,
                       opponent_strategy_fn: Callable, game_id):
        """Play one game of ``self.rounds`` rounds and log every round.

        Args:
            model: Model name passed to ``RPS``.
            opponent_strategy_name: Label written to the ``opponentStrategy`` column.
            opponent_strategy_fn: Opponent move function handed to ``RPS``
                (one of the ``opponent_strategies`` values).
            game_id: Identifier passed to ``RPS`` and written to every row.
        """
        game = RPS(
            model=model,
            temperature=self.temperature,
            game_id=game_id,
            opponent_strategy_fn=opponent_strategy_fn,
            strategy=self.strategy
        )
        for round_id in range(1, self.rounds + 1):
            # play_round returns an awaitable, so each round is awaited
            # before its result is logged.
            round_data = await game.play_round(round_id)
            self.log_to_csv(
                game_id, model, opponent_strategy_name, round_id,
                round_data["Agent Move"], round_data["Opponent Move"],
                round_data["Outcome"], game.player_score_game,
                round_data["Motivations"]
            )

if __name__ == "__main__":
    # Constructing the experiment prepares the results CSV; run_experiment()
    # then plays every configured game on a fresh event loop.
    asyncio.run(RPSExperiment().run_experiment())
    print("Experiment completed. Results saved in", CSV_FILE_PATH)