import os
from pathlib import Path
from typing import List, Optional

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import pandas as pd
import plotly.graph_objects as go
import seaborn as sns
from PIL import Image
from plotly.graph_objs import Layout, Scene, XAxis, YAxis, ZAxis, FigureWidget, Data
from pylab import cm
from scipy.linalg import schur, eigvals

def create_trait_eval_mat(A: np.ndarray, X: pd.DataFrame, T: str, types=None):
    """
    Aggregate an agent-level evaluation matrix into a trait-level one.

    Agents are grouped by their value in column ``T`` of ``X``. Entry
    ``[i, j]`` of the result is the mean of ``A`` over every (row agent with
    trait ``types[i]``, column agent with trait ``types[j]``) pair. The
    diagonal is then set to 0.

    Args:
        A: (num_agents, num_agents) evaluation matrix whose rows/columns are
            aligned positionally with the rows of ``X``.
        X: Trait table with one row per agent (a pandas DataFrame; anything
            supporting ``X[T]`` element-wise ``==`` also works).
        T: Name of the trait column used for grouping.
        types: Ordered trait values defining the rows/columns of the result.
            When omitted (``None`` or empty), the distinct values of ``X[T]``
            are used in order of first appearance.

    Returns:
        (len(types), len(types)) numpy array. A type with no agents yields
        NaN entries (mean of an empty selection).
    """
    if types is None or len(types) == 0:
        # dict.fromkeys de-duplicates while keeping first-appearance order, so the
        # row/column order is reproducible; a set of strings is not (hash randomization).
        types = list(dict.fromkeys(X[T]))

    A = np.asarray(A)
    column = np.asarray(X[T])
    # Positional indices of the agents belonging to each type, computed once.
    members = [np.flatnonzero(column == t) for t in types]

    N = len(types)
    A_sub = np.zeros((N, N))
    for i, rows in enumerate(members):
        for j, cols in enumerate(members):
            A_sub[i, j] = A[np.ix_(rows, cols)].mean()

    np.fill_diagonal(A_sub, 0.0)
    return A_sub

def plot_image_list_vertical(images, save_path):
    """
    Stack PIL images top-to-bottom on one RGB canvas and save it.

    The canvas is as wide as the widest image and as tall as all images
    combined. Every image is left-aligned, and uncovered pixels stay at the
    canvas fill colour.

    Args:
        images: Sequence of PIL images, pasted in order starting at the top.
        save_path: Destination path handed to ``Image.save``.
    """
    dims = [img.size for img in images]
    canvas_width = max(w for w, _ in dims)
    canvas_height = sum(h for _, h in dims)
    canvas = Image.new('RGB', (canvas_width, canvas_height))

    top = 0
    for img in images:
        canvas.paste(img, (0, top))
        top += img.size[1]

    canvas.save(save_path)

def plot_image_list_horizontal(images, save_path):
    """
    Place PIL images side by side (left to right) on one RGB canvas and save it.

    The canvas width is the sum of the image widths and its height is that of
    the tallest image. Every image is top-aligned, and uncovered pixels stay
    at the canvas fill colour.

    Args:
        images: Sequence of PIL images, pasted in order starting at the left.
        save_path: Destination path handed to ``Image.save``.
    """
    dims = [img.size for img in images]
    canvas_width = sum(w for w, _ in dims)
    canvas_height = max(h for _, h in dims)
    canvas = Image.new('RGB', (canvas_width, canvas_height))

    left = 0
    for img in images:
        canvas.paste(img, (left, 0))
        left += img.size[0]

    canvas.save(save_path)

def get_discgame_angles(disc_game):
    """
    Convert disc-game points to polar coordinates.

    Args:
        disc_game: Iterable of points whose first two entries are (x, y).

    Returns:
        Tuple ``(angles, radi)`` of lists, one entry per point:
            angles: counter-clockwise angle from the positive x-axis, in
                degrees, within [0, 360).
            radi: Euclidean distance of the point from the origin.
    """
    full_turn = 2 * np.pi
    angles, radi = [], []
    for point in disc_game:
        x, y = point[0], point[1]
        angles.append(np.rad2deg(np.arctan2(y, x) % full_turn))
        radi.append(np.sqrt(x ** 2 + y ** 2))
    return angles, radi

class GameUtil:
    """
    Analysis and plotting helpers for square evaluation ("game") matrices.

    Analysis methods decompose an evaluation matrix ``A`` into disc games
    (real Schur decomposition) or summarise it. Plotting methods draw with
    matplotlib/seaborn/plotly and write image files to disk.
    """

    def __init__(self, output_dir: str) -> None:
        # Directory where combined figures (e.g. the Pokemon 3-panel) are written.
        self.output_dir = output_dir

    def generate_pokemon_3panel(self, A: npt.ArrayLike, X: npt.ArrayLike, exp_dir: str, n: int = 4, cluster_n: int = 2,
                                attack_chart_path: str = "./data/experiment4/chart.csv"):
        """
        Build the 3-panel Pokemon figure and save it to
        ``<output_dir>/final_pokemon_attack.png``.

        Each panel is also saved individually under ``exp_dir``:
            1. ``pokemon_rpss.png``: disc game ``cluster_n`` with per-type mean
               positions. Fire/Water/Grass are joined by a shaded triangle to
               show their rock-paper-scissors relationship.
            2. ``pokemon_heatmap.png``: type-level performance matrix built from
               the rank-2 reconstruction of disc game ``cluster_n``.
            3. ``attack_matrix.png``: ``chart - chart.T`` for a subset of types,
               read from ``attack_chart_path``.

        Args:
            A: (num_pokemon, num_pokemon) evaluation matrix.
            X: Trait matrix with the 9 columns listed in ``traits`` below. The
                "Type 1" column holds an integer code indexing ``types``.
            exp_dir: Directory for the individual panel images.
            n: Number of disc games to extract (must exceed ``cluster_n``).
            cluster_n: Index of the disc game to plot.
            attack_chart_path: CSV attack chart whose index and columns are type names.
        """
        disc_games, _, U_vecs, Q_vecs = self.get_disc_games(A, n=n)

        traits = ['Type 1', 'Type 2', 'HP', 'Attack', 'Defense',
                  'Sp. Atk', 'Sp. Def', 'Speed', 'Generation']
        # Type names, indexed by the integer code stored in the "Type 1" column.
        types = ["Normal", "Fire", "Water", "Electric", "Grass", "Ice", "Fighting", "Poison", "Ground",
                 "Flying", "Psychic", "Bug", "Rock", "Ghost", "Dragon", "Dark", "Steel", "Fairy"]
        sub_types = ["Rock", "Ground", "Water", "Grass", "Bug", "Fire"]
        heat_map_font = 22

        disc_game_cluster = disc_games[cluster_n]
        ratings = np.asarray(A).mean(axis=1)
        X_df = pd.DataFrame(X, columns=traits)

        # ---- Panel 1: disc game with per-type mean positions ----
        data_1 = np.concatenate((disc_game_cluster, ratings.reshape((len(ratings), 1)), X), axis=1)
        df_1 = pd.DataFrame(data_1, columns=["x", "y", "ratings"] + traits)
        mean_df = df_1.groupby(["Type 1"]).mean()
        # Name each group from its own code instead of assuming all 18 codes are present.
        mean_df.index = [types[int(code)] for code in mean_df.index]

        mean_disc_game = mean_df[["x", "y"]].to_numpy()  # per-type (x, y) coordinates
        indexes = list(mean_df.index)

        # Highlighted types get their own colour and a big marker; the rest are small grey dots.
        highlight = {"Grass": "green", "Fire": "red", "Water": "blue",
                     "Bug": "orange", "Ground": "brown", "Rock": "silver"}
        colors = ["grey"] * len(indexes)
        sizes = [20] * len(indexes)
        for type_name, color in highlight.items():
            idx = indexes.index(type_name)
            colors[idx] = color
            sizes[idx] = 400

        plt.figure(figsize=(8, 8.46))
        self.plot_disc_game(disc_game_cluster, rating=np.array(["grey"] * len(disc_game_cluster)),
                            sizes=np.array([5] * len(disc_game_cluster)), set_colorbar=False,
                            color_name="Type", trim_axis=True)
        plt.scatter(mean_disc_game[:, 0], mean_disc_game[:, 1], c=np.array(colors), s=np.array(sizes))

        # Triangle joining Fire, Water and Grass to make the RPS cycle visible.
        fire_x, fire_y = mean_disc_game[indexes.index("Fire"), :]
        water_x, water_y = mean_disc_game[indexes.index("Water"), :]
        grass_x, grass_y = mean_disc_game[indexes.index("Grass"), :]
        plt.plot([water_x, grass_x], [water_y, grass_y], color="C1")
        plt.plot([fire_x, grass_x], [fire_y, grass_y], color="C1")
        plt.plot([water_x, fire_x], [water_y, fire_y], color="C1")
        plt.fill([fire_x, water_x, grass_x], [fire_y, water_y, grass_y], color="C1", alpha=0.3)

        plt.title("Disc game 2 RPS relationships", fontsize=18)
        rps_path = os.path.join(exp_dir, "pokemon_rpss.png")
        plt.savefig(rps_path, bbox_inches='tight', pad_inches=0.1, dpi=200)
        plt.close()

        # ---- Panel 2: type-level performance heat map ----
        X_df["types"] = [types[int(code)] for code in X_df["Type 1"]]
        # Rank-2 reconstruction of A restricted to the chosen disc game.
        A_2 = Q_vecs[cluster_n] @ U_vecs[cluster_n] @ Q_vecs[cluster_n].T
        A_types = create_trait_eval_mat(A_2, X_df, "types", sub_types)

        sns.set(font_scale=2.0)
        self.plot_pokemon_heatmap(A_types, np.array(sub_types), title="Pokemon Performance Matrix", figsize=(8, 8))
        heatmap_path = os.path.join(exp_dir, "pokemon_heatmap.png")
        plt.savefig(heatmap_path, bbox_inches='tight', pad_inches=0.1, dpi=200)
        plt.close()

        # ---- Panel 3: skew-symmetric part of the attack chart ----
        plt.figure(figsize=(8, 8))
        plt.title("Skew Attack Matrix", fontsize=18)
        # pyplot.get_cmap replaces matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9.
        cmap = plt.get_cmap('PiYG', 7)
        attack_chart = pd.read_csv(attack_chart_path, index_col=0)
        sub_df = attack_chart[sub_types].loc[sub_types]
        skew = sub_df - sub_df.T
        sns.heatmap(skew, annot=skew.to_numpy(), xticklabels=sub_types, yticklabels=sub_types, linewidths=0.5,
                    fmt="", cbar=False, cmap=cmap, annot_kws={"size": heat_map_font})
        attack_path = os.path.join(exp_dir, "attack_matrix.png")
        plt.savefig(attack_path, bbox_inches='tight', pad_inches=0.1, dpi=200)
        plt.close()

        # Combine the three panels side by side.
        images = [Image.open(p) for p in (rps_path, heatmap_path, attack_path)]
        plot_image_list_horizontal(images, os.path.join(self.output_dir, "final_pokemon_attack.png"))

    def get_blotto_game_metrics(self, path: str, tol: float = 0.05) -> dict:
        """
        Compute game statistics for every saved game directory under ``path``.

        Each sub-directory of ``path`` is one game and must contain ``P.npy``
        (the evaluation matrix is taken as ``P - 0.5``) and ``X.npy`` (trait
        matrix). Plain files are skipped. A directory name is split on "_",
        and fields 1, 2 and 3 are recorded as ``K``, ``payout`` and ``N``
        (kept as strings).

        Args:
            path: Directory holding the game sub-directories (a trailing
                separator is optional).
            tol: Error tolerance forwarded to ``get_low_rank_number``.

        Returns:
            Dict mapping each game directory name to the dict from
            ``get_game_stats``, extended with "low_rank", "N", "K" and "payout".
        """
        print(f"Creating blotto metrics with error tol of {tol}")

        game_metrics = {}
        for file in os.listdir(path):
            file_path = os.path.join(path, file)
            # Skip plain files before parsing: only game directories follow the naming scheme.
            if os.path.isfile(file_path):
                continue

            # Parse the entry's own name, not the full path; underscores in ``path``
            # would otherwise shift the fields.
            file_info = file.split("_")
            K, payout, N = file_info[1], file_info[2], file_info[3]

            print(f"Loading game from {file_path}")
            X = np.load(os.path.join(file_path, "X.npy"))
            P = np.load(os.path.join(file_path, "P.npy"))
            A = P - 0.5

            metrics = self.get_game_stats(X, A)
            metrics["low_rank"] = self.get_low_rank_number(A=A, error_tol=tol)
            metrics["N"] = N
            metrics["K"] = K
            metrics["payout"] = payout
            game_metrics[file] = metrics

        return game_metrics

    def get_disc_game_images(self, n: int, path: str, colors: npt.ArrayLike, eigs: List, disc_games: list,
                             order: Optional[List] = None, name="horizontal_disc_games.png",
                             conditions: Optional[dict] = None) -> List:
        """
        Render the first ``n`` disc games as separate images.

        Each image is saved as ``discgame_<i>.png`` (1-based) under ``path``,
        which is created if missing, and then re-opened with PIL.

        Args:
            n: Number of disc games to render.
            path: Output directory.
            colors: Per-point colour values passed to ``plot_disc_game``.
            eigs: Not used; kept for interface compatibility.
            disc_games: Disc-game embeddings, e.g. from ``get_disc_games``.
            order: Optional drawing order of the points.
            name: Not used; kept for interface compatibility.
            conditions: Optional spec forwarded to ``plot_rep_blotto_points``.

        Returns:
            List of opened PIL images, one per disc game.
        """
        Path(path).mkdir(parents=True, exist_ok=True)

        sizes = np.array([400] * len(colors))
        images = []
        for d_i in range(n):
            plt.figure(figsize=(10, 10))
            self.plot_disc_game(disc_games[d_i], sizes=sizes, rating=colors, title="", trim_axis=True, order=order)
            if conditions:
                self.plot_rep_blotto_points(disc_games[d_i], conditions)
            save_path = os.path.join(path, f"discgame_{d_i + 1}.png")
            plt.savefig(save_path, bbox_inches='tight', pad_inches=0.1, dpi=200)
            plt.close()
            images.append(Image.open(save_path))

        return images

    @staticmethod
    def _schur_planes(U: np.ndarray) -> list:
        """
        List the candidate 2-D planes of a real Schur form ``U``.

        Returns:
            List of ``(indices, magnitude)`` pairs. Each 2x2 diagonal block (a
            complex-conjugate eigenvalue pair) is one plane, with its
            eigenvalue modulus as magnitude. Leftover 1x1 blocks are paired in
            order, with magnitude = largest |diagonal entry|. An unpaired 1x1
            block (odd count) is dropped.
        """
        size = U.shape[0]
        planes = []
        singles = []
        i = 0
        while i < size:
            # In real Schur form the sub-diagonal is non-zero only inside 2x2 blocks.
            if i + 1 < size and U[i + 1, i] != 0:
                planes.append(([i, i + 1], np.abs(eigvals(U[i:i + 2, i:i + 2]))[0]))
                i += 2
            else:
                singles.append(i)
                i += 1
        # Pairing keeps the candidate count at N // 2 for even-sized skew-symmetric
        # inputs (whose real eigenvalues are ~0), as the index-pairing scheme did.
        for a, b in zip(singles[0::2], singles[1::2]):
            planes.append(([a, b], max(abs(U[a, a]), abs(U[b, b]))))
        return planes

    def get_disc_games(self, A: npt.ArrayLike, n: int = 4) -> tuple:
        """
        Decompose ``A`` into its (at most) ``n`` strongest disc games.

        Uses the real Schur decomposition ``A = Q U Q^T`` and ranks the planes
        from ``_schur_planes`` by magnitude, largest first.

        Args:
            A: (N, N) evaluation matrix, typically skew-symmetric.
            n: Maximum number of disc games to return. Fewer are returned when
                ``A`` has fewer candidate planes (e.g. N // 2 < n).

        Returns:
            Tuple ``(disc_games, eigs, U_vecs, Q_vecs)`` of equal-length lists:
                disc_games: (N, 2) embeddings ``(Q_2k^T A)^T / sqrt(w)``.
                eigs: eigenvalue modulus ``w`` of each plane. It is replaced
                    by 1 when below 0.1 so tiny planes are not inflated.
                U_vecs: 2x2 Schur block of each plane, oriented so ``U[1, 0] <= 0``.
                Q_vecs: (N, 2) Schur vectors of each plane. For skew-symmetric
                    ``A``, ``Q_vecs[k] @ U_vecs[k] @ Q_vecs[k].T`` is that
                    plane's rank-2 part of ``A``.
        """
        A = np.array(A)  # private copy

        U, Q = schur(A, output="real")
        planes = self._schur_planes(U)
        magnitudes = np.array([mag for _, mag in planes])
        ranked = np.argsort(magnitudes)[::-1]

        disc_games = []
        eigs = []
        U_vecs = []  # used for debugging / reconstruction
        Q_vecs = []  # used for debugging / reconstruction

        for k in ranked[:n]:
            idx = planes[k][0]
            Q_2k = Q[:, idx]
            U_2k = U[np.ix_(idx, idx)]

            if U_2k[1, 0] > 0:
                # Orient the block as [[a, w], [-w, a]]: transposing U_2k and swapping
                # the two Schur vectors leaves Q_2k @ U_2k @ Q_2k.T unchanged.
                U_2k = U_2k.T
                Q_2k = Q_2k[:, ::-1]

            w_2k = np.abs(eigvals(U_2k))[0]
            print(f"Eigenvalue {k + 1} = {w_2k}")

            if w_2k < 0.1:  # TODO: handle (near-)zero eigenvalues better, e.g. inform the caller.
                w_2k = 1

            Y_2k = (Q_2k.T @ A).T / np.sqrt(w_2k)  # projection of A onto the plane

            disc_games.append(Y_2k)
            eigs.append(w_2k)
            U_vecs.append(U_2k)
            Q_vecs.append(Q_2k)

        return disc_games, eigs, U_vecs, Q_vecs

    def get_game_stats(self, X: npt.ArrayLike, A: npt.ArrayLike) -> dict:
        """
        Split ``A`` into transitive and cyclic parts and summarise their size.

        With ratings ``r = A.mean(axis=1)``, the transitive part is
        ``Ft[i, j] = r[i] - r[j]`` and the cyclic part is ``Fc = A - Ft``.

        Args:
            X: (num_agents, num_attributes) trait matrix; only its shape is used.
            A: (N, N) evaluation matrix.

        Returns:
            Dict with:
                "ft", "fc": Frobenius norms of Ft and Fc divided by N**2,
                    rounded to 3 decimals.
                "intrans": unrounded fc / ft, then rounded. NumPy yields inf or
                    nan, with a warning, when ft is 0.
                "num_agents", "num_attributes": taken from ``X``'s shape.
        """
        num_agents = np.shape(X)[0]
        num_attributes = np.shape(X)[1]

        A = np.asarray(A)
        N = A.shape[0]

        ratings = A.mean(axis=1)
        # Vectorised form of Ft[i, j] = ratings[i] - ratings[j].
        Ft = ratings[:, None] - ratings[None, :]
        Fc = A - Ft

        norm = N ** 2
        ft_norm = np.linalg.norm(Ft) / norm
        fc_norm = np.linalg.norm(Fc) / norm

        return {
            "ft": np.round(ft_norm, 3),
            "fc": np.round(fc_norm, 3),
            "intrans": np.round(fc_norm / float(ft_norm), 3),
            "num_agents": num_agents,
            "num_attributes": num_attributes,
        }

    def get_low_rank_number(self, A: npt.ArrayLike, error_tol: float = 0.05) -> int:
        """
        Return how many disc games (planes) are needed to approximate ``A``.

        Only every other singular value ``s[0], s[2], ...`` is counted (one per
        plane; the singular values of a skew-symmetric matrix come in equal
        pairs), against half of ``(1 - error_tol)`` times the total squared sum.

        Args:
            A: (N, N) evaluation matrix.
            error_tol: Allowed fraction of uncaptured squared norm, in [0, 1).

        Returns:
            Number of planes ``k``; 0 for an all-zero ``A``.

        Raises:
            AssertionError: if ``error_tol`` is outside [0, 1).
        """
        assert 0 <= error_tol < 1, 'errTol not in [0,1)'

        s = np.linalg.svd(A, compute_uv=False)
        totalSquareSum = np.sum(s ** 2)
        requiredSquareSum = ((1 - error_tol) * totalSquareSum) / 2
        curSquareSum = 0
        k = 0
        # The bound on k stops the loop once every plane is counted, so round-off
        # (notably at error_tol == 0) cannot run past the end of ``s``.
        while curSquareSum < requiredSquareSum and 2 * k < len(s):
            curSquareSum += s[2 * k] ** 2
            k += 1

        print(f"Number of requried planes = {k} to have an error Tol of {error_tol}")
        return k

    def plot_3d_strat_space(self, X: npt.ArrayLike, colors: npt.ArrayLike, title: str, width=1600, height=1600,
                            colorscale="twilight"):
        """
        Interactive 3-D scatter of the first three columns of ``X``.

        Args:
            X: (N, >=3) trait/strategy matrix; columns 0-2 become K1, K2, K3.
            colors: Per-point values mapped through ``colorscale``.
            title: Figure title.
            width, height: Figure size in pixels.
            colorscale: Plotly colour scale name.

        Returns:
            A plotly ``FigureWidget``.
        """
        X = np.asarray(X)

        trace1 = go.Scatter3d(
            x=X[:, 0],
            y=X[:, 1],
            z=X[:, 2],
            mode='markers',
            marker=dict(
                size=16,
                color=colors,
                colorscale=colorscale,
                opacity=1.0,
                colorbar=dict(thickness=40,
                              tickfont=dict(size=36, color="black"))
            )
        )
        camera = dict(eye=dict(x=1.0, y=1.0, z=1.0))

        tickfont = dict(color="black", size=18)
        layout = Layout(
            title=title,
            width=width,
            height=height,
            showlegend=False,
            scene=Scene(
                xaxis=XAxis(title=dict(text="K1", font=dict(size=36)), tickfont=tickfont),
                yaxis=YAxis(title=dict(text="K2", font=dict(size=36)), tickfont=tickfont),
                zaxis=ZAxis(title=dict(text="K3", font=dict(size=36)), tickfont=tickfont)
            )
        )

        # A plain list replaces the deprecated plotly ``Data`` wrapper.
        fig = FigureWidget(data=[trace1], layout=layout)
        fig.update_layout(scene_camera=camera)
        return fig

    def plot_disc_game(self, A_2k: npt.ArrayLike, rating: npt.ArrayLike, sizes: Optional[List] = None,
                       axis=None, plot_strat: bool = False, plot_vector_field: bool = True, title: str = "",
                       set_colorbar: bool = False, color_name: str = "Rating", colorbar_fmt=None,
                       colorbar_ticks=None, trim_axis: bool = True, order: Optional[List] = None):
        """
        Scatter one disc game over a circulating background vector field.

        Args:
            A_2k: (N, 2) disc-game embedding (columns are x and y).
            rating: Per-point colour values (numbers or colour names), one per row.
            sizes: Per-point marker sizes; defaults to 100 for every point.
            axis: Matplotlib Axes to draw on; defaults to the ``pyplot`` module.
            plot_strat: If True, label points whose ``size / 1000`` exceeds 0.05 with that value.
            plot_vector_field: If True, draw the field from ``plot_vector_field``.
            title: Plot title (an empty string clears it).
            set_colorbar: If True, add a colorbar for the scatter labelled ``color_name``.
            color_name: Colorbar label.
            colorbar_fmt: Optional colorbar tick format; when given, ``colorbar_ticks`` is used too.
            colorbar_ticks: Colorbar tick positions (used only with ``colorbar_fmt``).
            trim_axis: If True, hide the x and y axes of the plot.
            order: Drawing order of the points; defaults to row order.

        Returns:
            The ``axis`` drawn on.
        """
        A_2k = np.asarray(A_2k)
        print(f"Plotting disc game with shape = {A_2k.shape}")
        if axis is None:
            axis = plt

        if sizes is None or len(sizes) == 0:  # not passed: default marker size
            sizes = np.array([100] * A_2k.shape[0])
        # Arrays allow fancy indexing with ``order`` even when lists were passed.
        sizes = np.asarray(sizes)
        rating = np.asarray(rating)

        # Square window centred on the origin, 10% beyond the furthest coordinate.
        max_val = 1.1 * np.abs(A_2k[:, :2]).max()
        if max_val == 0:
            # All points at the origin: avoid zero-width limits and 0/0 in the vector field.
            max_val = 1.0

        axis.axis([-max_val, max_val, -max_val, max_val])

        if plot_vector_field:
            self.plot_vector_field(skip_size=4, x_start=-max_val, x_finish=max_val,
                                   y_start=-max_val, y_finish=max_val, axis=axis)

        if order is None or len(order) == 0:
            order = list(range(A_2k.shape[0]))

        # Keep the handle: the colorbar below needs the scatter as its mappable.
        sc = axis.scatter(A_2k[order, 0], A_2k[order, 1], c=rating[order], s=sizes[order])

        if plot_strat:
            for idx, p in enumerate(A_2k):
                if (sizes[idx] / 1000.0) > 0.05:
                    axis.text(p[0], p[1], str(sizes[idx] / 1000))

        if set_colorbar:
            if colorbar_fmt:
                cb = plt.colorbar(sc, format=colorbar_fmt, ticks=colorbar_ticks)
            else:
                cb = plt.colorbar(sc)
            cb.set_label(color_name)

        if trim_axis:
            # Hide the axes actually drawn on; fall back to the current pyplot axes.
            target = axis if hasattr(axis, "get_xaxis") else plt.gca()
            target.get_yaxis().set_visible(False)
            target.get_xaxis().set_visible(False)

        if hasattr(axis, "set_title"):
            axis.set_title(title)
        else:
            axis.title(title)

        return axis

    def plot_eigenvalues(self, eigs: List[List[float]], labels: List, save_path: str, figsize: tuple = (12, 6)):
        """
        Save a log-log scatter of one or more eigenvalue spectra.

        Args:
            eigs: One sequence of positive eigenvalue magnitudes per series. Every
                series is plotted against x = 1..len(eigs[0]).
            labels: Legend label for each series (indexed like ``eigs``).
            save_path: Output image path.
            figsize: Matplotlib figure size.
        """
        print("Running plot eigenvalues")
        print(len(eigs))
        plt.figure(figsize=figsize)
        log_x = np.log(np.arange(1, len(eigs[0]) + 1))
        for i, eig_i in enumerate(eigs):
            plt.scatter(log_x, np.log(eig_i), label=labels[i])

        plt.legend()
        # Both axes hold natural-log values.
        plt.xlabel("log(Disc Game)")
        plt.ylabel("log(Eigenvalue)")
        plt.savefig(save_path)
        plt.close()

    def plot_horizontal_disc_game_colored_by_traits(self, path: str, colors: List[npt.ArrayLike], eigs: List,
                                                    disc_game: list, name="horizontal_disc_games.png"):
        """
        Render one disc game several times, once per colouring, and join the images.

        Images are saved as ``discgame_k=<i>.png`` (1-based) under ``path``,
        which is created if missing. The strip is saved as ``path/name``.

        Args:
            path: Output directory.
            colors: One per-point colour array per image.
            eigs: Not used; kept for interface compatibility.
            disc_game: (N, 2) disc-game embedding.
            name: File name of the combined horizontal image.
        """
        Path(path).mkdir(parents=True, exist_ok=True)

        # One marker size per point of the disc game (not per colour array).
        sizes = np.array([400] * len(disc_game))
        images = []
        for c_i, c in enumerate(colors):
            plt.figure(figsize=(10, 10))
            self.plot_disc_game(disc_game, sizes=sizes, rating=c, title="", trim_axis=True)
            save_path = os.path.join(path, f"discgame_k={c_i + 1}.png")
            plt.savefig(save_path, bbox_inches='tight', pad_inches=0.1, dpi=200)
            plt.close()
            images.append(Image.open(save_path))

        plot_image_list_horizontal(images, save_path=os.path.join(path, name))

    def plot_horizontal_disc_games(self, n: int, path: str, colors: npt.ArrayLike, eigs: List, disc_games: list,
                                   point_size: int = 400, image_size: int = 10, order: Optional[List] = None,
                                   name="horizontal_disc_games.png"):
        """
        Render the first ``n`` disc games and join them into one horizontal strip.

        Images are saved as ``discgame_<i>.png`` (1-based) under ``path``,
        which is created if missing. The strip is saved as ``path/name``.

        Args:
            n: Number of disc games to render.
            path: Output directory.
            colors: Per-point colour values shared by every image.
            eigs: Not used; kept for interface compatibility.
            disc_games: Disc-game embeddings, e.g. from ``get_disc_games``.
            point_size: Marker size for every point.
            image_size: Width and height of each figure, in inches.
            order: Optional drawing order of the points.
            name: File name of the combined horizontal image.
        """
        Path(path).mkdir(parents=True, exist_ok=True)

        sizes = np.array([point_size] * len(colors))
        images = []
        for d_i in range(n):
            plt.figure(figsize=(image_size, image_size))
            self.plot_disc_game(disc_games[d_i], sizes=sizes, rating=colors, title="", trim_axis=True, order=order)
            save_path = os.path.join(path, f"discgame_{d_i + 1}.png")
            plt.savefig(save_path, bbox_inches='tight', pad_inches=0.1, dpi=200)
            plt.close()
            images.append(Image.open(save_path))

        plot_image_list_horizontal(images, save_path=os.path.join(path, name))

    def plot_pokemon_heatmap(self, A: npt.ArrayLike, types, title: str = "Performance Matrix",
                             figsize: tuple = (10, 10), heat_map_font: float = 22):
        """
        Draw an annotated PiYG heat map of ``A`` on a new figure.

        The figure is neither saved nor closed here; the caller does that.

        Args:
            A: Square matrix to display.
            types: Tick labels used for both axes.
            title: Figure title.
            figsize: Matplotlib figure size.
            heat_map_font: Font size of the cell annotations.
        """
        plt.figure(figsize=figsize)
        plt.title(title, fontsize=18)
        sns.heatmap(A, xticklabels=types, yticklabels=types, annot=True, linewidths=.5, fmt='.2g',
                    cmap='PiYG', cbar=False, annot_kws={"size": heat_map_font})

    def plot_rep_blotto_points(self, disc_game: npt.ArrayLike, conditions: Optional[dict] = None):
        """
        Overlay representative points on the current disc-game plot.

        Args:
            disc_game: (N, 2) disc-game embedding.
            conditions: Mapping of name -> ``(conds, color, marker)``. Each
                element of ``conds`` selects rows of ``disc_game`` (presumably
                a boolean mask or index array — confirm with callers). The mean
                of each selection is drawn as one size-500 marker.
        """
        if not conditions:
            return
        disc_game = np.asarray(disc_game)
        for conds, color, geo in conditions.values():
            for cond in conds:
                point_cluster = disc_game[cond].mean(axis=0)
                plt.scatter(point_cluster[0], point_cluster[1], c=color, s=500, marker=geo)

    @staticmethod
    def plot_vector_field(skip_size, x_start: int = 1, x_finish: int = 1, y_start: int = 1, y_finish: int = 1,
                          axis=None):
        """
        Draw a unit-length clockwise circulation field, (x, y) -> (y, -x) / |(x, y)|.

        Args:
            skip_size: Draw an arrow only at every ``skip_size``-th grid point per axis.
            x_start, x_finish: x-range of the 50x50 grid.
            y_start, y_finish: y-range of the 50x50 grid.
            axis: Axes (or pyplot module) to draw on; defaults to ``pyplot``.
        """
        x_vec, y_vec = np.meshgrid(np.linspace(x_start, x_finish), np.linspace(y_start, y_finish))
        radius = np.sqrt(x_vec ** 2 + y_vec ** 2)
        dir_u = y_vec / radius
        dir_v = -x_vec / radius
        skip = (slice(None, None, skip_size), slice(None, None, skip_size))

        target = axis if axis is not None else plt
        target.quiver(x_vec[skip], y_vec[skip], dir_u[skip], dir_v[skip], alpha=0.2, scale=None, color="black")