import numpy as np
import scipy.stats as st
import matplotlib.pyplot as plt
import os

from experiments.utils import get_game_folder, decode_game_name
from games.jax_game import JaxGame 

import argparse

# Command-line interface for plotting k-means abstraction exploitability.
parser = argparse.ArgumentParser(
  description="Plot k-means abstraction exploitability from saved logs."
)

parser.add_argument("--game_details", type=str, default="goofspiel|4", help="Game details")
# Passed as max_k: abstraction limits 1 .. k-1 are plotted.
parser.add_argument("--k", type=int, default=4, help="Number of clusters")
parser.add_argument("--amount_seeds", type=int, default=5, help="Number of seeds")
# type=list would split a single argument into characters ("legal" ->
# ['l', 'e', 'g', 'a', 'l']); nargs="+" collects whole space-separated words,
# e.g. `--sim_types legal policy`.
parser.add_argument("--sim_types", type=str, nargs="+", default=["legal", "policy"], help="Type of similarity to use.")


def mean_confidence_interval(data, confidence=0.95):
  """Return the mean and a Student-t confidence interval along the last axis.

  Args:
    data: Array-like of samples (arrays or nested lists). Statistics are
      computed over the last axis.
    confidence: Confidence level of the two-sided interval.

  Returns:
    ``(mean, lower, upper)``, each shaped like ``data`` without its last
    axis. Wherever scipy yields a NaN bound (e.g. a single sample, so the
    degrees of freedom are 0 and the standard error is NaN), that bound
    falls back to ``mean``.
  """
  # Accept plain lists too; the degrees-of-freedom computation needs .shape.
  data = np.asarray(data)
  mean = np.mean(data, -1)
  sem = st.sem(data, -1)
  lower, upper = st.t.interval(confidence, data.shape[-1] - 1, loc=mean, scale=sem)
  # Collapse undefined bounds to the mean so plots show no band there
  # instead of breaking on NaN.
  return mean, np.where(np.isnan(lower), mean, lower), np.where(np.isnan(upper), mean, upper)

def plot_kmeans_exploitability_from_saved_values(game: JaxGame, sim_types: list[str], max_k:int, amount_seeds:int):
  """Plot mean exploitability with a 95% confidence band per similarity type.

  For each ``sim_type`` in ``sim_types`` and each abstraction limit ``k`` in
  ``1 .. max_k - 1``, reads
  ``data/logs/tabular/goofspiel_<game.cards>_descending/<sim_type>_<k>.txt``.
  Only lines starting with ``"Seed"`` are used. Each is split on ``"|"`` into
  a seed id (the part after ``":"`` in the first field) and the two players'
  exploitabilities. The two players' values are averaged. The mean over seeds
  and its confidence interval are then plotted against ``k``. The figure is
  written to ``kmeans_exploitability.png`` inside
  ``get_game_folder(game, "plot")``.

  Args:
    game: Game descriptor. ``game.cards`` selects the log folder, and the
      object is also passed to ``get_game_folder``.
    sim_types: Similarity types to plot, e.g. ``["legal", "policy"]``.
    max_k: Exclusive upper bound on the abstraction limit ``k``.
    amount_seeds: Number of seeds; seed ids in the logs must be in
      ``[0, amount_seeds)``.

  Raises:
    FileNotFoundError: If the log for some (sim_type, k) pair is missing.
    IndexError: If a log contains a seed id >= ``amount_seeds``.
    ValueError: If a "Seed" line cannot be parsed as numbers.
  """
  # Style is keyed by sim-type name, not by position in sim_types, so a
  # reordered or partial list is still labelled/colored correctly. Unknown
  # types fall back to their own name and matplotlib's default color cycle.
  sim_styles = {
    "legal": ("Legal actions", "r"),
    "policy": ("Nash Equilibrium Strategy", "b"),
  }

  # The log folder depends on neither k nor sim_type, so build it once.
  # NOTE(review): hard-coded to goofspiel "descending" logs for any game.
  log_folder = f"data/logs/tabular/goofspiel_{game.cards}_descending/"

  # Axes: (player, sim_type, k, seed). Slot k=0 is never filled and is
  # dropped after averaging. NOTE(review): seeds absent from a log stay 0.0
  # and silently pull the mean down.
  exps = np.zeros((2, len(sim_types), max_k, amount_seeds))
  for k in range(1, max_k):
    for sim_type_idx, sim_type in enumerate(sim_types):
      path = log_folder + sim_type + "_" + str(k) + ".txt"
      with open(path, "r") as f:
        for line in f:
          if not line.startswith("Seed"):
            continue
          fields = line.split("|")
          seed = int(fields[0].split(":")[1])
          exps[:, sim_type_idx, k, seed] = (float(fields[1]), float(fields[2]))

  # Average over the two players -> (sim_type, k, seed), then drop unused k=0.
  exps = np.mean(exps, 0)[:, 1:]
  xs = np.arange(1, max_k)

  # Start a fresh figure so artists from any earlier pyplot use don't leak in.
  plt.figure()
  for sim_type_idx, sim_type in enumerate(sim_types):
    label, color = sim_styles.get(sim_type, (sim_type, None))
    mean_exps, lower_exps, upper_exps = mean_confidence_interval(exps[sim_type_idx])
    (line,) = plt.plot(xs, mean_exps, color=color, label=label)
    # Reuse the line's actual color so the band matches even for fallback types.
    plt.fill_between(xs, lower_exps, upper_exps, color=line.get_color(), alpha=0.15)

  plt.xlabel("Abstraction Limit", fontsize=18)
  plt.ylabel("Exploitability", fontsize=18)
  plt.legend(fontsize=15)
  plot_folder = get_game_folder(game, "plot") + "/"
  os.makedirs(plot_folder, exist_ok=True)
  plt.savefig(plot_folder + "kmeans_exploitability.png")
  plt.close()


if __name__ == "__main__":
  # Parse the CLI, resolve the game from its encoded name, and draw the plot.
  cli_args = parser.parse_args()
  plot_kmeans_exploitability_from_saved_values(
    game=decode_game_name(cli_args.game_details),
    sim_types=cli_args.sim_types,
    max_k=cli_args.k,
    amount_seeds=cli_args.amount_seeds,
  )
