# -*- coding: utf-8 -*-
"""
Created on Mon Apr 29 21:43:15 2024

@author: javie
"""
# visualization

import numpy as np
import jax.numpy as jnp

from sklearn.cluster import KMeans
from scipy.ndimage import gaussian_filter1d

import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go
from ipywidgets import widgets

def vis_basis(basis,shape,cluster=True):
    """Show a random element of the span of ``basis`` as an image.

    The basis is first materialized as a dense matrix ``Q`` via
    ``basis @ jnp.eye(k)``. A standard-normal random vector ``v`` is then
    mapped to ``Q @ (Q.T @ v)``; the original code describes this as projecting
    onto the equivariant subspace, which presumes ``Q`` has orthonormal columns
    (NOTE(review): confirm for the basis objects passed in). The result is
    reshaped to ``shape`` and drawn with ``plt.imshow``.

    Parameters
    ----------
    basis : linear operator or array
        Object supporting ``.shape`` and ``basis @ jnp.eye(basis.shape[-1])``;
        represents the "basis" of the corresponding vector space.
    shape : tuple
        Sizes of (rep_out, rep_in) used to reshape the projected vector.
    cluster : bool, optional
        If True (default), the projected values are grouped with KMeans into
        ``Q.shape[-1]`` clusters and the cluster labels are drawn instead of
        the raw values, which improves color separation.

    Returns
    -------
    None. Draws on the current matplotlib axes and hides the axis decorations.
    """
    n_basis = basis.shape[-1]
    dense_q = basis @ jnp.eye(n_basis)  # dense matrix even if `basis` is an operator
    sample = np.random.randn(dense_q.shape[0])
    projected = dense_q @ (dense_q.T @ sample)
    if cluster:
        # Replace each value by its cluster id so that near-equal entries share a color.
        clustering = KMeans(n_clusters=dense_q.shape[-1], n_init=10)
        projected = clustering.fit(projected.reshape(-1, 1)).labels_
    plt.imshow(projected.reshape(shape))
    plt.axis('off')

def vis(repin,repout,cluster=True):
    """Visualize the equivariant basis for linear maps from ``repin`` to ``repout``.

    Parameters
    ----------
    repin : emlp representation
        Representation of the group action on the input space. Must support
        ``repin >> repout`` and ``.size()``.
    repout : emlp representation
        Representation of the group action on the output space.
    cluster : bool, optional
        Forwarded to ``vis_basis``: whether similar values are clustered for
        better color separation. The default is True.

    Returns
    -------
    None. The image is drawn by ``vis_basis``.
    """
    equivariant_q = (repin >> repout).equivariant_basis()
    image_shape = (repout.size(), repin.size())  # rows = output, columns = input
    vis_basis(equivariant_q, image_shape, cluster)
    
    
    
# def plot_losses(train_losses, N_p=40, particle_positions_teacher=None, tau = 0):
#   train_loss_arr = jnp.array(train_losses).squeeze()

#   norm_penalization = ((particle_positions_teacher**2).sum(axis = 1)).mean() if not particle_positions_teacher is None else 0

#   sns.set_style("whitegrid")
#   plt.plot(gaussian_filter1d(train_loss_arr,2), label = '$N_p={}$'.format(N_p))
#   sns.set_style("whitegrid")
#   plt.yscale('log')
#   plt.legend()
#   plt.xlabel('iteration (K)',fontsize=15)
#   plt.legend(fontsize=12)
#   plt.xticks(fontsize=12)
#   plt.yticks(fontsize=12)
#   plt.ylabel('loss (test)',fontsize=15)
#   plt.axhline(y=tau*norm_penalization, color='r', linestyle='-')
#   plt.show()
  
  
def _as_curve_list(train_losses):
    """Normalize ``train_losses`` into a list of loss curves, one per plotted line."""
    if not isinstance(train_losses, (list, tuple)):
        return [train_losses]
    # A flat sequence of scalars (or size-1 arrays), e.g. per-iteration losses
    # appended during training, is ONE curve rather than many one-point curves.
    if train_losses and all(np.ndim(np.squeeze(np.asarray(x))) == 0 for x in train_losses):
        return [train_losses]
    return list(train_losses)


def plot_losses(train_losses, particle_positions_teacher=None, tau = 0, labels = None, N_p=40):
    """Plot one or several smoothed training-loss curves on a log-scaled y axis.

    Each curve is smoothed with ``gaussian_filter1d(curve, 2)`` before plotting.
    A red horizontal reference line is drawn at
    ``tau * mean_i(sum_j teacher[i, j] ** 2)`` when that value is positive.
    A non-positive value cannot be displayed on the log-scaled axis, so the
    line is skipped in that case.

    Parameters
    ----------
    train_losses : array, or list/tuple of arrays or scalars
        A single loss curve, or a sequence of curves to overlay. A sequence
        whose elements are all scalars (or size-1 arrays) is treated as a
        single curve, e.g. per-iteration losses collected during training.
    particle_positions_teacher : array, optional
        Array of shape (N_particles, dimension) containing the "particles of
        the teacher network". The default is None (reference value 0).
    tau : float, optional
        Regularization parameter considered during training. The default is 0.
    labels : str or list of str, optional
        Legend label(s), one per curve. A single string labels a single curve.
        The default is None, which labels every curve ``$N_p={N_p}$``.
    N_p : int, optional
        Number of particles used for training the network. It is only used in
        the default labels (to be deprecated). The default is 40.

    Returns
    -------
    None. Draws on the current matplotlib figure and calls ``plt.show()``.

    Raises
    ------
    ValueError
        If fewer labels than curves are given. Previously, the extra curves
        were silently dropped.
    """
    norm_penalization = 0
    if particle_positions_teacher is not None:
        norm_penalization = ((particle_positions_teacher**2).sum(axis=1)).mean()

    curves = _as_curve_list(train_losses)
    if labels is None:
        labels = ['$N_p={}$'.format(N_p)] * len(curves)
    elif isinstance(labels, str):
        labels = [labels]  # avoid zipping a string character by character
    if len(labels) < len(curves):
        raise ValueError(
            "got {} labels for {} loss curves".format(len(labels), len(curves)))

    sns.set_style("whitegrid")
    for curve, label in zip(curves, labels):
        # squeeze drops singleton axes; atleast_1d keeps a one-point curve filterable.
        curve_arr = jnp.atleast_1d(jnp.array(curve).squeeze())
        plt.plot(gaussian_filter1d(curve_arr, 2), label=label)
    plt.yscale('log')
    plt.xlabel('iteration (K)', fontsize=15)
    plt.ylabel('loss (test)', fontsize=15)
    plt.legend(fontsize=12)
    plt.xticks(fontsize=12)
    plt.yticks(fontsize=12)
    reference_loss = tau * norm_penalization
    if reference_loss > 0:
        plt.axhline(y=reference_loss, color='r', linestyle='-')
    plt.show()
   
    
   
def _teacher_direction_lines(teacher_coords, half_length, n_points=50):
    """Sample points on the line through the origin along each teacher particle's direction.

    ``teacher_coords`` has shape (dimension, N_particles). The result has shape
    (N_particles, n_points, dimension): entry ``[n, a]`` is ``t[a] * u_n``, where
    ``t = linspace(-half_length, half_length, n_points)`` and ``u_n`` is column
    ``n`` normalized to unit length. A zero-norm particle gives an all-zero line
    instead of NaNs.
    """
    norms = jnp.linalg.norm(teacher_coords, axis=0)
    norms = jnp.where(norms > 0, norms, 1.0)  # guard against division by zero
    directions = (teacher_coords / norms).T
    t = jnp.linspace(-half_length, half_length, n_points)
    return jnp.einsum('a,nc->nac', t, directions)


def create_plot_list(teacher_coords = None, particle_coords = None, equivariant_coords = None, c_lims = (-2,2), t_style = (4, "diamond") , p_style = (1.5, "circle"), eq_style = (0.35,0.8), with_lineplots = True):
    """Build the plotly traces for a 4D particle plot (x1, x2, x3 as position, x4 as color).

    Every coordinate array is unpacked into exactly four rows, so dimension must be 4.

    Parameters
    ----------
    teacher_coords : array, optional
        Array of shape (4, N_particles) containing the "particles of the
        teacher network". The default is None.
    particle_coords : array, optional
        Array of shape (4, N_particles) containing the "particles of the
        student network". The default is None.
    equivariant_coords : array, optional
        Array of shape (4, N_endpoints) containing the "endpoints of the
        equivariant space", drawn as a mesh. The default is None.
    c_lims : tuple of floats, optional
        Limits (cmin, cmax) of the shared color scale. The default is (-2,2).
    t_style : tuple of the form (float, str), optional
        (size, marker_symbol) of the teacher markers. The default is (4, "diamond").
    p_style : tuple of the form (float, str), optional
        (size, marker_symbol) of the student markers. The default is (1.5, "circle").
    eq_style : tuple of the form (float, float), optional
        (mesh_opacity, half-length of the lines through each teacher particle).
        The default is (0.35,0.8).
    with_lineplots : bool, optional
        Whether a red line through the origin is drawn along the direction of
        every teacher particle. The default is True.

    Returns
    -------
    plots : list of plotly.graph_objects traces
        The traces in this order: mesh, teacher markers, student markers, then
        one line per teacher particle. Only the parts whose inputs are given
        are included.
    """
    cmin, cmax = c_lims
    plots = []
    if equivariant_coords is not None:
        ex1, ex2, ex3, ex4 = equivariant_coords
        plots.append(go.Mesh3d(x=ex1, y=ex2, z=ex3, opacity=eq_style[0], intensity=ex4,
                               colorscale='Viridis', cmax=cmax, cmin=cmin, name='y',
                               showscale=True, hoverinfo='none'))
    if teacher_coords is not None:
        gtx1, gtx2, gtx3, gtx4 = teacher_coords
        plots.append(go.Scatter3d(name='optimal positions', x=gtx1, y=gtx2, z=gtx3, mode='markers',
                                  marker=dict(size=t_style[0], color=gtx4, cmax=cmax, cmin=cmin,
                                              colorscale='Viridis', showscale=True,
                                              symbol=t_style[1]),
                                  hovertemplate='<b>Teacher Particle</b><extra></extra>'))
    if particle_coords is not None:
        x1, x2, x3, x4 = particle_coords
        plots.append(go.Scatter3d(name='NN parameters', x=x1, y=x2, z=x3, mode='markers',
                                  marker=dict(size=p_style[0], color=x4, cmax=cmax, cmin=cmin,
                                              colorscale='Viridis', showscale=True,
                                              symbol=p_style[1]),
                                  hoverinfo='none'))

    if with_lineplots and teacher_coords is not None:
        for line in _teacher_direction_lines(teacher_coords, eq_style[1]):
            # The 4th coordinate is sampled too but not drawn; unpacking still enforces 4 rows.
            lx1, lx2, lx3, _ = line.T
            plots.append(go.Scatter3d(x=lx1, y=lx2, z=lx3,
                                      marker=dict(size=0.01),  # effectively hide the markers
                                      line=dict(color='red', width=0.5),
                                      showlegend=False, hoverinfo="none"))
    return plots

def get_transpose(x):
    """Return ``x.T``, or ``None`` unchanged when ``x`` is ``None``."""
    if x is None:
        return None
    return x.T

def get_permuted(x, permutation = (0, 1, 3, 2)):
    """Reorder the rows (first axis) of ``x``, passing ``None`` through unchanged.

    Parameters
    ----------
    x : array or None
        Array with at least two dimensions, e.g. a (4, N_particles) coordinate array.
    permutation : sequence of int, optional
        New row order. The default ``(0, 1, 3, 2)`` swaps the third and fourth
        rows. An immutable tuple is used as the default instead of a list, so
        the default cannot be shared and mutated across calls.

    Returns
    -------
    array or None
        ``x[permutation, :]``, or ``None`` if ``x`` is ``None``.
    """
    if x is None:
        return None
    # Index with an integer ndarray so NumPy does fancy (row) indexing and JAX
    # arrays accept the index as well (JAX may reject plain Python sequences).
    return x[np.asarray(permutation), :]

def particle_plot(particle_positions_teacher, 
                  particle_positions_model, 
                  equivariant_space_limits,  
                  title = "4D plot", 
                  c_lims = (-2,2), 
                  t_style = (4, "diamond") , 
                  p_style = (1.5, "circle"), 
                  eq_style = (0.35,0.8), 
                  with_lineplots = True, 
                  double=False):
    """Draw teacher particles, student particles and the space E^G in a 3D plotly figure.

    Coordinates x1, x2 and x3 are used as position and x4 as color; see
    ``create_plot_list``. With ``double=True``, a second panel is added. It has
    the 3rd and 4th coordinates swapped via ``get_permuted``, so its z axis
    shows x4 and its color shows x3.

    Parameters
    ----------
    particle_positions_teacher : array or None
        Array of shape (N_particles, 4) containing the "particles of the teacher network".
    particle_positions_model : array or None
        Array of shape (N_particles, 4) containing the "particles of the student network".
    equivariant_space_limits : array or None
        Array of shape (N_endpoints, 4) containing the "endpoints" for drawing the space E^G.
    title : str, optional
        Title of the plot. The default is "4D plot".
    c_lims : tuple of floats, optional
        Limits of the color scale. The default is (-2,2).
    t_style : tuple of the form (float, str), optional
        (size, marker_symbol) of the teacher particles. The default is (4, "diamond").
    p_style : tuple of the form (float, str), optional
        (size, marker_symbol) of the student particles. The default is (1.5, "circle").
    eq_style : tuple of the form (float, float), optional
        (mesh_opacity, half-length of the teacher lines). The default is (0.35,0.8).
    with_lineplots : bool, optional
        Whether lines are traced along every teacher particle direction. The default is True.
    double : bool, optional
        Whether to build a dual (side-by-side) plot. The default is False.

    Returns
    -------
    ipywidgets.HBox or None
        With ``double=True``, an HBox holding both figures for display in a
        Jupyter notebook. Otherwise the single figure is shown with
        ``fig.show()`` and None is returned.
    """
    style_kwargs = dict(c_lims=c_lims, t_style=t_style, p_style=p_style,
                        eq_style=eq_style, with_lineplots=with_lineplots)
    coords = [get_transpose(arr) for arr in (particle_positions_teacher,
                                             particle_positions_model,
                                             equivariant_space_limits)]
    main_fig = go.Figure(data=create_plot_list(*coords, **style_kwargs))

    def make_layout(width, height, zaxis_title, eye_x):
        # Shared layout; only size, z-axis title and camera x-position vary.
        return dict(
            title=dict(text=title, x=0.5, font=dict(size=20), automargin=False),
            scene=dict(
                xaxis_title='x1',
                yaxis_title='x2',
                zaxis_title=zaxis_title,
                xaxis_showspikes=False,
                yaxis_showspikes=False,
                zaxis_showspikes=False,
                camera=dict(
                    up=dict(x=0, y=0, z=1),
                    center=dict(x=0, y=0, z=0),
                    eye=dict(x=eye_x, y=-1.25, z=1.25),
                ),
            ),
            width=width,
            height=height,
            autosize=False,
            legend=dict(yanchor="bottom", y=0.99, xanchor="left", x=0.01),
        )

    if not double:
        main_fig.update_layout(make_layout(800, 600, 'x3', 1.25))
        main_fig.show()
        return None

    swapped_coords = [get_permuted(c) for c in coords]
    side_fig = go.Figure(data=create_plot_list(*swapped_coords, **style_kwargs))
    main_fig.update_layout(make_layout(600, 450, 'x3', 1.25))
    side_fig.update_layout(make_layout(600, 450, 'x4', -1.25))
    return widgets.HBox([go.FigureWidget(main_fig), go.FigureWidget(side_fig)])


def particle_plot_animation(particle_positions_teacher, 
                  particle_positions_model_L, 
                  equivariant_space_limits,  
                  title = "4D plot", 
                  c_lims = (-2,2), 
                  t_style = (4, "diamond") , 
                  p_style = (1.5, "circle"), 
                  eq_style = (0.35,0.8), 
                  with_lineplots = True):
    """Animate the student particles over a sequence of snapshots in a 3D plotly figure.

    Snapshot ``k`` is frame ``'frame{k}'`` and slider step ``k``, for every
    ``k``, including the initial snapshot 0. The figure opens on snapshot 0.
    The x/y/z axis ranges are fixed to ``c_lims`` so the view does not rescale
    between frames.

    Parameters
    ----------
    particle_positions_teacher : array, list of arrays, or None
        Array of shape (N_particles, 4) containing the "particles of the
        teacher network". It is either one static array, or a list with one
        array per snapshot.
    particle_positions_model_L : sequence of arrays
        One array of shape (N_particles, 4) per snapshot, containing the
        "particles of the student network". Must be non-empty.
    equivariant_space_limits : array or None
        Array of shape (N_endpoints, 4) containing the "endpoints" for drawing the space E^G.
    title : str, optional
        Title of the plot. The default is "4D plot".
    c_lims : tuple of floats, optional
        Limits of the color scale, also used as the range of all three axes.
        The default is (-2,2).
    t_style : tuple of the form (float, str), optional
        (size, marker_symbol) of the teacher particles. The default is (4, "diamond").
    p_style : tuple of the form (float, str), optional
        (size, marker_symbol) of the student particles. The default is (1.5, "circle").
    eq_style : tuple of the form (float, float), optional
        (mesh_opacity, half-length of the teacher lines). The default is (0.35,0.8).
    with_lineplots : bool, optional
        Whether lines are traced along every teacher particle direction. The default is True.

    Returns
    -------
    None. The animated figure is shown with ``fig.show()``.

    Raises
    ------
    ValueError
        If ``particle_positions_model_L`` is empty, or if a per-snapshot
        teacher list is shorter than it.
    """
    n_snapshots = len(particle_positions_model_L)
    if n_snapshots == 0:
        raise ValueError("particle_positions_model_L must contain at least one snapshot")
    teacher_list = isinstance(particle_positions_teacher, list)
    if teacher_list and len(particle_positions_teacher) < n_snapshots:
        raise ValueError("particle_positions_teacher has fewer snapshots than particle_positions_model_L")

    equivariant_coords = get_transpose(equivariant_space_limits)

    def snapshot_plots(k):
        # Traces for snapshot k; the teacher is either static or given per snapshot.
        teacher = particle_positions_teacher[k] if teacher_list else particle_positions_teacher
        return create_plot_list(get_transpose(teacher), get_transpose(particle_positions_model_L[k]),
                                equivariant_coords, c_lims=c_lims, t_style=t_style, p_style=p_style,
                                eq_style=eq_style, with_lineplots=with_lineplots)

    # One frame per snapshot, INCLUDING snapshot 0, so slider step k shows
    # snapshot k and the initial state can be revisited.
    frames = [go.Frame(data=snapshot_plots(k), name=f'frame{k}') for k in range(n_snapshots)]
    fig1 = go.Figure(data=snapshot_plots(0), frames=frames)

    layout_config = dict(
        title=dict(text=title, x=0.5, font=dict(size=20), automargin=False),
        scene=dict(
            xaxis_title='x1',
            yaxis_title='x2',
            zaxis_title='x3',
            xaxis_showspikes=False,
            yaxis_showspikes=False,
            zaxis_showspikes=False,
            camera=dict(
                up=dict(x=0, y=0, z=1),
                center=dict(x=0, y=0, z=0),
                eye=dict(x=1.25, y=-1.25, z=1.25),
            ),
        ),
        width=800,
        height=600,
        autosize=False,
        legend=dict(yanchor="bottom", y=0.99, xanchor="left", x=0.01),
    )
    fig1.update_layout(layout_config)

    def frame_args(duration):
        # Animation options shared by the play button and the slider steps.
        return {
            "frame": {"duration": duration},
            "mode": "immediate",
            "fromcurrent": True,
            "transition": {"duration": duration, "easing": "linear"},
        }

    sliders = [
        {
            "pad": {"b": 10, "t": 60},
            "len": 0.9,
            "x": 0.1,
            "y": 0,
            "steps": [
                {"args": [[frame.name], frame_args(0)],
                 "label": str(k),
                 "method": "animate",
                 } for k, frame in enumerate(frames)
            ],
        }
    ]

    updatemenus = [
        {
            "buttons": [
                {"args": [None, frame_args(1)], "label": "Play", "method": "animate"},
                # Animating to [None] with zero duration stops a running animation.
                {"args": [[None], frame_args(0)], "label": "Pause", "method": "animate"},
            ],
            "direction": "left",
            "pad": {"r": 10, "t": 70},
            "type": "buttons",
            "x": 0.1,
            "y": 0,
        }
    ]

    fig1.update_layout(updatemenus=updatemenus, sliders=sliders)
    # Fix the axis ranges so the scene does not rescale between frames.
    fig1.update_layout(scene=dict(xaxis=dict(range=[c_lims[0], c_lims[1]], autorange=False),
                                  yaxis=dict(range=[c_lims[0], c_lims[1]], autorange=False),
                                  zaxis=dict(range=[c_lims[0], c_lims[1]], autorange=False)))
    fig1.show()

