import numpy as np
import matplotlib.pyplot as plt
import os


def compute_mean_rnad(save_folder: str) -> float:
  """Return the mean RNaD value found in the `*_rnad.txt` logs of a folder.

  Every file in `save_folder` whose name ends with `_rnad.txt` is scanned for
  lines starting with `P1: ` and `P2: `. Per file, the last 30 values of each
  player are averaged; the per-file means are then averaged over all files
  and finally over the two players.

  Logs that hold no value for one of the players are skipped (with a printed
  notice), so a single empty or truncated log cannot turn the result into NaN.

  Args:
    save_folder: Directory that contains the log files.

  Returns:
    The averaged value, or NaN when no usable log was found.
  """
  p1_file_means = []
  p2_file_means = []
  for file_name in os.listdir(save_folder):
    if not file_name.endswith("_rnad.txt"):
      continue
    p1_values = []
    p2_values = []
    with open(os.path.join(save_folder, file_name), "r") as f:
      for line in f:
        if line.startswith("P1: "):
          p1_values.append(float(line.split(":")[1].strip()))
        elif line.startswith("P2: "):
          p2_values.append(float(line.split(":")[1].strip()))
    if not p1_values or not p2_values:
      # np.mean([]) would be NaN and contaminate the overall average.
      print(f"Skipping RNaD log without results: {file_name}")
      continue
    # Only the last 30 logged values of each player enter the per-file mean.
    p1_file_means.append(np.mean(p1_values[-30:]))
    p2_file_means.append(np.mean(p2_values[-30:]))
  if not p1_file_means:
    return float("nan")
  return (np.mean(p1_file_means) + np.mean(p2_file_means)) / 2

# Similarity metric is a list of tuples of (name, seed_offset)
def load_saved_data(ks: list[int], similarity_metric: list[tuple[str, int]], evaluation_experiments: list[str], save_folder: str, max_iter: int, amount_seeds: int) -> np.ndarray:
  """Load per-iteration values of both players from the evaluation logs.

  For every combination of k, similarity metric, seed and evaluation, the file
  `{save_folder}/seed_{seed + k * 10 + seed_offset}_{evaluation}.txt` is read.
  Lines starting with `P1: ` / `P2: ` carry the values of the two players; a
  `P2: ` line completes the current iteration. Reading stops after `max_iter`
  iterations.

  Args:
    ks: k values; each one enters the seed in the log file name as `k * 10`.
    similarity_metric: `(name, seed_offset)` tuples. Only the offset is used
      here, to locate the log file.
    evaluation_experiments: Evaluation names used as file name suffix.
    save_folder: Directory that contains the log files.
    max_iter: Number of iterations to read from each file.
    amount_seeds: Number of seeds (0 .. amount_seeds - 1) per setting.

  Returns:
    Array of shape `(2, len(ks), amount_seeds, len(evaluation_experiments),
    len(similarity_metric), max_iter)`; the first axis is the player
    (0 = P1, 1 = P2). Entries of missing files or missing iterations stay 0;
    a message is printed for each such file.
  """
  results = np.zeros((2, len(ks), amount_seeds, len(evaluation_experiments), len(similarity_metric), max_iter))

  for k_i, k in enumerate(ks):
    for sim_id, (_, sim_offset) in enumerate(similarity_metric):
      for seed in range(amount_seeds):
        final_seed = seed + k * 10 + sim_offset
        for eval_id, evaluation in enumerate(evaluation_experiments):
          file_name = f"{save_folder}/seed_{final_seed}_{evaluation}.txt"
          if not os.path.exists(file_name):
            print(f"Did not find log for: {file_name}")
            continue
          iter_id = 0
          with open(file_name, "r") as f:
            for line in f:
              if line.startswith("P1: "):
                results[0, k_i, seed, eval_id, sim_id, iter_id] = float(line.split(":")[1].strip())
              elif line.startswith("P2: "):
                results[1, k_i, seed, eval_id, sim_id, iter_id] = float(line.split(":")[1].strip())
                # The P2 line is the last one of an iteration.
                iter_id += 1
              if iter_id >= max_iter:
                break
          if iter_id < max_iter:
            print(f"Did not find max iter for: {file_name}")

  return results


def plot_similarity_metrics_bar(ks: list[int], save_folder: str, max_iter: int, amount_seeds: int, similarity_metric: list[tuple[str, int]], evaluation_experiments: list[str], rnad_exploitability: float) -> None:
  """Save one grouped bar chart per evaluation, comparing similarity metrics.

  The data comes from `load_saved_data`. Values are averaged over the two
  players and over the last 30 loaded iterations. Each figure shows one group
  of bars per k and, inside a group, one bar per similarity metric. Bar height
  is the mean over seeds, the error bar is `1.96 * std / sqrt(amount_seeds)`.
  A dashed horizontal line marks `rnad_exploitability`.

  Figures are written to
  `{save_folder}/plots/evaluation_<evaluation label>_bar.png`.

  Args:
    ks: k values, one bar group each (x axis, "Abstraction Limit").
    save_folder: Folder with the logs; also selects the fixed y-range when it
      contains "goofspiel", "oshi" or "leduc".
    max_iter: Number of iterations loaded per log.
    amount_seeds: Number of seeds per setting.
    similarity_metric: `(name, seed_offset)` tuples; every name must be a key
      of `similarity_labels` below.
    evaluation_experiments: Evaluation names; every name must be a key of
      `evaluation_labels` below.
    rnad_exploitability: Value drawn as the horizontal reference line.
  """
  similarity_labels = {
    "legal_actions": "Legal Actions",
    "policy": "RNaD Strategy",
    "legal_policy": "Legal Actions + RNaD Strategy",
    "action_history": "Action History",
    "action_history_legal_policy": "Action History + Legal Actions + RNaD Strategy"
  }
  evaluation_labels = {
    "isets_per_depth": "Each subgame",
    "full_game": "Full game",
    "no_dynamics_full_game": "Only abstraction",
    "trained_dynamics_full_game": "Trained Dynamics Full Game",
    "rnad": "RNaD"
  }
  results = load_saved_data(ks, similarity_metric, evaluation_experiments, save_folder, max_iter, amount_seeds)
  # Average over players and keep the last 30 iterations.
  results = np.mean(results, axis=0)[..., -30:]
  # Remaining axes: (k, seed, evaluation, similarity).
  mean_results = np.mean(results, axis=-1)

  # Mean and 95% confidence interval across seeds.
  seed_mean = np.mean(mean_results, axis=1)
  seed_std = np.std(mean_results, axis=1)
  seed_sem = seed_std / np.sqrt(amount_seeds)
  confidence_interval = 1.96 * seed_sem

  plots_dir = f"{save_folder}/plots"
  os.makedirs(plots_dir, exist_ok=True)

  # One color per similarity metric.
  colors = plt.cm.Set3(np.linspace(0, 1, len(similarity_metric)))

  # A group contains one bar per similarity metric, so the bar width has to
  # be derived from the number of metrics (not evaluations); otherwise the
  # groups overlap as soon as the two counts differ.
  bar_width = 0.7 / max(len(similarity_metric), 1)
  x = np.arange(len(ks))
  label_font = 18

  for eval_id, evaluation in enumerate(evaluation_experiments):
    plt.figure(figsize=(15, 8))

    # Fixed y-range per game so that figures of one game are comparable.
    if "goofspiel" in save_folder:
      plt.ylim(0, 0.4)
    elif "oshi" in save_folder:
      plt.ylim(0, 0.08)
    elif "leduc" in save_folder:
      plt.ylim(0, 0.6)

    for sim_id, (similarity, _) in enumerate(similarity_metric):
      # Shift the bars of this metric inside every group.
      pos = x + sim_id * bar_width
      heights = seed_mean[:, eval_id, sim_id]
      errors = confidence_interval[:, eval_id, sim_id]
      bars = plt.bar(pos, heights, width=bar_width,
                     label=similarity_labels[similarity], color=colors[sim_id], yerr=errors,
                     capsize=5)

      # Write the mean above the error bar of each bar.
      for bar, value, ci in zip(bars, heights, errors):
        text_y = bar.get_height() + ci + 0.02
        plt.text(bar.get_x() + bar.get_width()/2., text_y,
                f'{value:.3f}', ha='center', va='bottom', fontsize=15, rotation=0)
    plt.axhline(y=rnad_exploitability, color='b', linestyle='--', label=f'RNaD ({rnad_exploitability:.3f})', alpha=0.2)

    plt.xlabel("Abstraction Limit", fontsize=label_font + 4)
    plt.ylabel("Exploitability", fontsize=label_font + 4)
    # Center the tick of every group below its bars.
    plt.xticks(x + bar_width * (len(similarity_metric) - 1) / 2, ks, fontsize=label_font)

    # Add the RNaD value to the y-ticks.
    current_yticks = list(plt.yticks()[0])
    if rnad_exploitability not in current_yticks:
      current_yticks.append(round(rnad_exploitability, 3))
      current_yticks.sort()
      plt.yticks(current_yticks, fontsize=label_font)
    else:
      # Keep the tick font consistent even when no tick had to be added.
      plt.tick_params(axis="y", labelsize=label_font)

    plt.legend(fontsize=label_font)
    plt.grid(True, linestyle='--', alpha=0.7)

    # Adjust layout to prevent text cutoff.
    plt.tight_layout()
    plt.savefig(f"{plots_dir}/evaluation_{evaluation_labels[evaluation].replace(' ', '_')}_bar.png", dpi=300, bbox_inches='tight')
    plt.close()
    
  
def plot_results_per_k(ks: list[int], save_folder: str, max_iter: int, amount_seeds: int, similarity_metric: list[tuple[str, int]], evaluation_experiments: list[str], rnad_exploitability: float) -> None:
  """Save a line plot of the mean result per k, one line per evaluation.

  The data comes from `load_saved_data`. Values are averaged over the two
  players and the last 30 loaded iterations; the line shows the mean over
  seeds and the shaded band `1.96 * std / sqrt(amount_seeds)` around it.
  A dashed horizontal line marks `rnad_exploitability`. The figure is written
  to `{save_folder}/plots/mean_results.png`.

  Args:
    ks: k values, plotted at positions 1..len(ks) and used as x tick labels.
    save_folder: Folder with the logs; the plot goes into its `plots` folder.
    max_iter: Number of iterations loaded per log.
    amount_seeds: Number of seeds per setting.
    similarity_metric: Exactly one `(name, seed_offset)` tuple.
    evaluation_experiments: Evaluation names; every name must be a key of
      `evaluation_labels` below. At most two are supported (two colors).
    rnad_exploitability: Value drawn as the horizontal reference line.

  Raises:
    ValueError: If `similarity_metric` does not hold exactly one entry.
  """
  if len(similarity_metric) != 1:
    raise ValueError("plot_results_per_k expects exactly one similarity metric")

  results = load_saved_data(ks, similarity_metric, evaluation_experiments, save_folder, max_iter, amount_seeds)
  # Average over players and keep the last 30 iterations.
  results = np.mean(results, axis=0)[..., -30:]

  evaluation_labels = {
    "no_dynamics": "Information abstraction",
    "with_dynamics": "Information abstraction + dynamics",
  }

  # Axes after the mean: (k, seed, evaluation, similarity). Drop only the
  # similarity axis; a bare np.squeeze would also remove the k, seed or
  # evaluation axis whenever one of them has length 1.
  mean_results = np.mean(results, axis=-1)[..., 0]

  seed_mean = np.mean(mean_results, axis=1)  # Average over seeds
  seed_std = np.std(mean_results, axis=1)  # Standard deviation over seeds
  seed_sem = seed_std / np.sqrt(amount_seeds)  # Standard error of the mean
  confidence_interval = 1.96 * seed_sem  # 95% confidence interval
  plots_dir = f"{save_folder}/plots"
  os.makedirs(plots_dir, exist_ok=True)
  colors = ["r", "b"]
  positions = list(range(1, len(ks) + 1))
  plt.ylim(0, 3)
  plt.axhline(y=rnad_exploitability, color='b', linestyle='--', label=f'RNaD ({rnad_exploitability:.3f})', alpha=0.2)
  for eval_id, evaluation in enumerate(evaluation_experiments):
    line = seed_mean[:, eval_id]
    band = confidence_interval[:, eval_id]
    plt.plot(positions, line, label=f"{evaluation_labels[evaluation]}", color=colors[eval_id])
    plt.fill_between(positions, line - band, line + band, alpha=0.2, color=colors[eval_id])

  # Label the positions with the actual k values.
  plt.xticks(positions, ks)

  # Add the RNaD value to the y-ticks.
  current_yticks = list(plt.yticks()[0])
  if rnad_exploitability not in current_yticks:
    current_yticks.append(round(rnad_exploitability, 3))
    current_yticks.sort()
    plt.yticks(current_yticks)

  plt.legend()
  plt.xlabel("Abstraction Limit")
  plt.ylabel("Exploitability")
  plt.savefig(f"{plots_dir}/mean_results.png", dpi=300, bbox_inches='tight')
  plt.close()
    
def goofspiel_5_plot():
  """Create the bar charts for the logs in `muzero_logs/goofspiel_5_descending`."""
  log_dir = "muzero_logs/goofspiel_5_descending"
  k_values = [5, 10, 20, 30]
  iterations = 99
  seeds_per_setting = 10

  # (similarity name, seed offset); commented-out entries are not plotted.
  metrics = [
    ("legal_actions", 0),
    # ("policy", 1000),
    ("legal_policy", 2000),
    # ("action_history", 3000),
    ("action_history_legal_policy", 4000),
  ]
  evaluation_names = [
    "isets_per_depth",
    "full_game",
    "no_dynamics_full_game",
    # "trained_dynamics_full_game",
  ]

  # Reference value for the horizontal line, taken from the same folder.
  rnad_baseline = compute_mean_rnad(log_dir)
  plot_similarity_metrics_bar(
    k_values, log_dir, iterations, seeds_per_setting,
    metrics, evaluation_names, rnad_baseline,
  )
      

def oshi_zumo_3_5_plot():
  """Create the bar charts for the logs in `muzero_logs/oshi_zumo_3_5`."""
  log_dir = "muzero_logs/oshi_zumo_3_5"
  k_values = [5, 10, 15]
  iterations = 35
  seeds_per_setting = 10

  # (similarity name, seed offset); commented-out entries are not plotted.
  # NOTE(review): "legal_policy" and "action_history_legal_policy" share the
  # offset 2000, so both read the same log files and produce identical bars.
  # goofspiel_5_plot uses 4000 for the latter - confirm against the logs.
  metrics = [
    ("legal_actions", 0),
    # ("policy", 1000),
    ("legal_policy", 2000),
    # ("action_history", 3000),
    ("action_history_legal_policy", 2000),
  ]
  evaluation_names = [
    "isets_per_depth",
    "full_game",
    "no_dynamics_full_game",
    # "trained_dynamics_full_game",
    # "rnad"
  ]

  # Reference value for the horizontal line, taken from the same folder.
  rnad_baseline = compute_mean_rnad(log_dir)
  plot_similarity_metrics_bar(
    k_values, log_dir, iterations, seeds_per_setting,
    metrics, evaluation_names, rnad_baseline,
  )
    
     
def leduc_plot():
  """Create the per-k line plot for the logs in `muzero_logs/leduc`."""
  log_dir = "muzero_logs/leduc"
  k_values = list(range(1, 7))
  iterations = 35
  seeds_per_setting = 10

  # Single (similarity name, seed offset) pair.
  metrics = [("iset_policy", 1000)]
  evaluation_names = ["no_dynamics", "with_dynamics"]

  # Reference value for the horizontal line, taken from the same folder.
  rnad_baseline = compute_mean_rnad(log_dir)
  plot_results_per_k(
    k_values, log_dir, iterations, seeds_per_setting,
    metrics, evaluation_names, rnad_baseline,
  )
  
    
if __name__ == "__main__":
  # Only the goofspiel and oshi-zumo figures are produced by default;
  # leduc_plot() is defined above but not called here.
  for make_plot in (goofspiel_5_plot, oshi_zumo_3_5_plot):
    make_plot()