import matplotlib.axes
import matplotlib.pyplot as plt
import matplotlib
import jax.numpy as jnp
from jax.tree_util import tree_map
from src.alg import InnerLoopData
from typing import List
from src.utils import dist_equilibrium



def plot_2d_trajectory(trajectory: list, ax: matplotlib.axes.Axes, name=None, color=None):
    """Draw a sequence of 2-D points as one connected line on ``ax`` and show the grid.

    Args:
      trajectory: iterable of (x, y) pairs, in drawing order.
      ax: axes to draw on.
      name: legend label for the line.
      color: matplotlib colour spec; None lets the axes pick the colour.
    """
    xs, ys = list(zip(*trajectory))
    ax.plot(xs, ys, label=name, color=color)
    # Explicit True: a bare ``ax.grid()`` *toggles* grid visibility, so drawing an
    # even number of trajectories on the same axes would leave the grid switched off.
    ax.grid(True)

def simplex_projection(x):
  """Map a length-3 weight vector to 2-D plotting coordinates.

  Each weight scales one fixed triangle vertex and the results are summed:
  x[0] -> (0, 1), x[1] -> (-1/2, -1/2), x[2] -> (1/2, -1/2).
  """
  assert x.shape[0] == 3
  top = jnp.array([0, 1])
  bottom_left = jnp.array([-1/2, -1/2])
  bottom_right = jnp.array([1/2, -1/2])
  return x[0] * top + x[1] * bottom_left + x[2] * bottom_right

def plot_corners(ax: matplotlib.axes.Axes):
  """Outline the projected simplex on ``ax`` as a closed black triangle."""
  unit_vectors = (jnp.array([1, 0, 0]), jnp.array([0, 1, 0]), jnp.array([0, 0, 1]))
  vertices = [simplex_projection(e) for e in unit_vectors]
  # Repeat the first vertex at the end so the last edge closes the triangle.
  closed_path = vertices + vertices[:1]
  for (x_start, y_start), (x_end, y_end) in zip(closed_path[:-1], closed_path[1:]):
    ax.plot([x_start, x_end], [y_start, y_end], color="black")

def plot_simplex_points(ax: matplotlib.axes.Axes,  points, name=None):
  """Project every simplex point in the pytree ``points`` to 2-D and draw them as one thin line."""
  projected = tree_map(simplex_projection, points)
  # Transpose the sequence of 2-D points into (xs, ys) columns for ax.plot.
  columns = zip(*projected)
  ax.plot(*columns, linewidth=1, label=name)

def plot_simplex_trajectory(axs, runs: list, names):
    """Draw each run's two players' strategies on their own simplex triangle.

    Args:
      axs: indexable pair of axes; axs[0] is player 1, axs[1] is player 2.
      runs: each run is a sequence of (player1_point, player2_point) pairs.
      names: legend label per run, matched to ``runs`` by position.
    """
    player_axes = (axs[0], axs[1])
    for ax in player_axes:
        ax.set_axis_off()
    for ax, title in zip(player_axes, ("Player 1", "Player 2")):
        ax.set_title(title)
    for ax in player_axes:
        plot_corners(ax)

    for run, label in zip(runs, names):
        # Split the list of per-step pairs into one sequence per player.
        player1_points, player2_points = zip(*run)
        plot_simplex_points(player_axes[0], player1_points, label)
        plot_simplex_points(player_axes[1], player2_points, label)
    

def plot_parameter_latent_2d(axs, parameter_latent_lists: list, names=None, subset=None):
    """Plot parameter trajectories on axs[0] and latent trajectories on axs[1].

    Run ``i`` is drawn with colour ``cmap(i)`` from the "tab10" colormap on both
    axes, so a run's parameter and latent curves share a colour.

    Args:
      axs: indexable pair of axes.
      parameter_latent_lists: sequence of (parameter_trajectory, latent_trajectory)
        pairs; each trajectory is a sequence of 2-D points.
      names: optional legend labels indexed by run index (used on axs[1] only).
      subset: optional run indices to draw; a falsy value draws every run.
    """
    ax0 = axs[0]
    ax1 = axs[1]
    ax0.set_title("$\\theta_t$ trajectory")
    ax1.set_title("$z_t = g(\\theta_t)$ trajectory")
    cmap = plt.get_cmap("tab10")
    if not subset:
        subset = list(range(len(parameter_latent_lists)))
    for i in subset:
        params, latents = parameter_latent_lists[i]
        color = cmap(i)
        plot_2d_trajectory(params, ax0, None, color)
        label = names[i] if names else None
        # Colour by keyword: passed positionally it would land in `name`, giving the
        # line a tuple as its label and a colour that no longer matches ax0.
        plot_2d_trajectory(latents, ax1, label, color=color)
  

def plot_dist_equilibrium(ax: matplotlib.axes.Axes, distances, equilibrium: jnp.array, ci=None, time=None, name=None):
  """Plot a distance-to-equilibrium curve on a log-scaled y axis.

  Args:
    ax: axes to draw on.
    distances: 1-D sequence of distances, one per step.
    equilibrium: not read by this function; kept so existing callers keep working.
    ci: optional half-width of a shaded band drawn over [distances - ci, distances + ci].
    time: optional x values (labelled "time (s)"); when omitted the step index is used.
    name: legend label for the curve.
  """
  # `is not None` rather than truthiness: an array of times has no unambiguous truth value.
  if time is not None:
    xs = time
    lines = ax.plot(time, distances, label=name)
    ax.set_xlabel("time (s)")
  else:
    xs = list(range(len(distances)))
    lines = ax.plot(distances, label=name)

  if ci is not None:
    # asarray so the +/- arithmetic also works when `distances` is a plain list.
    dist_arr = jnp.asarray(distances)
    col = lines[-1].get_color()
    ax.fill_between(xs, dist_arr - ci, dist_arr + ci, color=col, alpha=.1)

  ax.set_yscale('log')
  

def plot_runs_dist_equilibrium(ax, latents_lists, equilibrium, names=None):
  """Plot, for every run, the distance of its latents to ``equilibrium`` per outer iteration.

  Distances come from ``dist_equilibrium``; ``names[idx]`` labels run ``idx`` when
  ``names`` is truthy.
  """
  ax.set_ylabel(r"$\frac{1}{2}\|z_t - z_{\ast}\|^2$")
  ax.set_xlabel("Outer loop iterations")

  for idx, latents in enumerate(latents_lists):
    label = None
    if names:
      label = names[idx]
    plot_dist_equilibrium(ax, dist_equilibrium(latents, equilibrium), equilibrium, name=label)
  
  
def plot_benchmark_times(names, times, filename=None):
  """Draw a bar chart of benchmark times in a new figure, saving it when ``filename`` is set.

  ``names`` label the bars and ``times`` set their heights; x tick labels are
  rotated 35 degrees so long names stay readable.
  NOTE(review): the title hard-codes "10k Updates" -- confirm callers time exactly that many.
  """
  fig, ax = plt.subplots()
  ax.set_title("Time for 10k Updates (s)")
  ax.bar(names, times)
  ax.tick_params(axis='x', rotation=35)

  if not filename:
    return
  fig.savefig(filename, bbox_inches='tight')

def plot_loss_ratio(ax, data_list: List[List[InnerLoopData]], names=None, subset=None):
  """Plot, per run, sqrt(loss_end / loss_beg) for each outer-loop iteration.

  NOTE(review): the title shows the plain loss ratio while the curve is its square
  root -- confirm this is intended.

  Args:
    ax: axes to draw on.
    data_list: one sequence of per-iteration records per run; each record exposes
      ``loss_beg`` and ``loss_end``.
    names: optional legend labels indexed by run index.
    subset: optional run indices to draw; a falsy value draws every run.
  """
  # Raw string: `\e` in a regular string literal is an invalid escape sequence
  # (SyntaxWarning on Python 3.12+). The resulting title text is unchanged.
  ax.set_title(r"$\ell_t(\theta_{t+1})/\ell_t(\theta_t)$")
  ax.set_xlabel("Outer loop iterations")
  cmap = plt.get_cmap("tab10")
  if not subset:
    subset = list(range(len(data_list)))

  for i in subset:
    loss_ratio = [jnp.sqrt(d.loss_end / d.loss_beg) for d in data_list[i]]
    label = names[i] if names else None
    # Always colour run i with cmap(i) so runs match across panels whether or not names are given.
    ax.plot(loss_ratio, label=label, color=cmap(i))

def plot_inner_loop_data(data_list: List[List[InnerLoopData]], names=None, title=""):
  """Create a 1x3 figure of per-run inner-loop diagnostics and return it.

  Panels: ``loss_beg``, sqrt(loss_end / loss_beg), and ``min_singular_value``,
  each plotted against iteration index. The figure is returned (not saved) so
  callers can save, show or close it.

  Args:
    data_list: one sequence of per-iteration records per run.
    names: optional legend labels, indexed by run; the legend goes on the last panel.
    title: figure super-title.
  """
  fig, (ax0, ax1, ax2) = plt.subplots(nrows=1, ncols=3)
  fig.suptitle(title)

  ax0.set_title("Loss at Beginning of Trajectory")
  ax1.set_title("Loss Ratio")
  ax2.set_title("Min SV")

  for i, data in enumerate(data_list):
    ax0.plot([d.loss_beg for d in data])
    ax1.plot([jnp.sqrt(d.loss_end / d.loss_beg) for d in data])
    label = names[i] if names else None
    ax2.plot([d.min_singular_value for d in data], label=label)

  if names:
    ax2.legend()
  return fig
  
  
def figure_matching_pennies(trajectories, equilibrium, data_list, names, subsets, filename):
  """Build the four-panel matching-pennies figure and save it to ``filename``.

  Panels, left to right:
    - distance of each run's latents (``trajectories[k][1]``) to ``equilibrium``;
    - parameter trajectories;
    - latent trajectories;
    - the inner-loop loss ratio.
  ``subsets[0]`` selects runs for the trajectory panels and ``subsets[1]`` for the
  loss-ratio panel.
  """
  fig, axs = plt.subplots(nrows=1, ncols=4, figsize=(16, 3.75))
  dist_ax = axs[0]

  plot_runs_dist_equilibrium(dist_ax, [run[1] for run in trajectories], equilibrium, names)
  plot_parameter_latent_2d(axs[1:3], trajectories, names, subsets[0])
  plot_loss_ratio(axs[3], data_list, names, subsets[1])

  # Shared legend comes from the distance panel, which draws every run.
  legend_handles, legend_labels = dist_ax.get_legend_handles_labels()
  fig.legend(legend_handles, legend_labels, loc='upper center', bbox_to_anchor=(0.5, 1.1), ncol=5)
  fig.savefig(filename, bbox_inches='tight')

def shorten_name(name):
    """Return a compact label ``"Alg(<inner>)"`` for an ``"Alg(k=v, ...)"`` name string.

    The algorithm name is everything before the first "(". If the argument list
    contains ``inner=<value>`` (anywhere in the list), that value is kept in
    parentheses; otherwise only the algorithm name is returned.

    >>> shorten_name("SGA(inner=5, lr=0.1)")
    'SGA(5)'
    >>> shorten_name("GD(lr=0.1)")
    'GD'
    """
    import re

    alg, _, args = name.partition("(")
    # Stop at "," or ")" so the last argument does not drag in the closing paren,
    # and use \b so names such as "min_inner=" or an "inner..." algorithm name
    # are not mistaken for the inner argument.
    match = re.search(r"\binner\s*=\s*([^,)]*)", args)
    if match is None:
        return alg
    return f"{alg}({match.group(1).strip()})"


def figure_rps(trajectories, equilibrium, data_list, names, filename):
  """Build the 2x2 rock-paper-scissors figure, save it to ``filename`` and close it.

  Panels:
    - top row: each player's latent trajectory on the simplex;
    - bottom-left: distance of the latents (``trajectories[k][1]``) to ``equilibrium``;
    - bottom-right: the inner-loop loss ratio.
  The shared legend is taken from the top-left panel.
  """
  fig, axs = plt.subplots(nrows=2, ncols=2)
  fig.subplots_adjust(wspace=0.2, hspace=0.1)
  latents_lists = [t[1] for t in trajectories]
  plot_simplex_trajectory(axs[0], latents_lists, names)
  plot_runs_dist_equilibrium(axs[1][0], latents_lists, equilibrium, names)
  plot_loss_ratio(axs[1][1], data_list, names)
  handles, labels = axs[0][0].get_legend_handles_labels()
  fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 1.05), ncol=3)
  fig.savefig(filename, bbox_inches='tight')
  # close(fig), not clf(): clf() only empties the figure and leaves it registered
  # with pyplot, so repeated calls would accumulate open figures.
  plt.close(fig)

def figure_rps_rand_init(distances, cis, equilibrium, names, filename):
  """Plot several distance curves with shaded bands, save to ``filename`` and close the figure.

  Args:
    distances: one distance sequence per run.
    cis: matching band half-widths, one per run.
    equilibrium: forwarded to ``plot_dist_equilibrium``.
    names: legend labels, one per run.
    filename: output path for ``fig.savefig``.
  """
  fig, ax = plt.subplots()
  ax.set_ylabel(r"$\frac{1}{2}\|z_t - z_{\ast}\|^2$")
  ax.set_xlabel("Outer loop iterations")

  for d, ci, name in zip(distances, cis, names):
    plot_dist_equilibrium(ax, d, equilibrium, ci, name=name)
  handles, labels = ax.get_legend_handles_labels()

  # Use more legend columns when there are many runs.
  ncol = 5 if len(names) > 6 else 3

  fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 1.05), ncol=ncol)
  fig.savefig(filename, bbox_inches='tight')
  # close(fig), not clf(): clf() only empties the figure and leaves it registered
  # with pyplot, so repeated calls would accumulate open figures.
  plt.close(fig)