import json
from unicodedata import name
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import os
import pandas as pd
import seaborn as sns
import toml

from PIL import Image
from pathlib import Path
from typing import List
from scipy.io import savemat
from scipy.optimize import curve_fit


from principal_tradeoff_analysis.util import GameUtil, plot_image_list_horizontal, plot_image_list_vertical, get_discgame_angles

class PTAExperiment:
    """
    Main driver class for running the PTA experiments.

    Construct it from a TOML config with ``load_from_config`` and call ``run``; each
    enabled entry of ``experiments`` is dispatched to the matching ``run_*`` method.
    """

    # Plotly colorscale ([position, hex colour] stops) used to colour disc-game phase in the
    # 3D strategy-space plots. Both end stops are blue, so the map (roughly) wraps around,
    # which suits an angle-like quantity such as the phase returned by get_discgame_angles.
    _PHASE_COLORSCALE = (
        (0.0, '#2e21ea'),
        (0.05, '#571ef4'),
        (0.1, '#7d31f8'),
        (0.15, '#a43efb'),
        (0.2, '#ce45fa'),
        (0.25, '#ef55f1'),
        (0.3, '#f977d8'),
        (0.35, '#fb9db5'),
        (0.4, '#fbbe90'),
        (0.45, '#fbda67'),
        (0.5, '#f0ed35'),
        (0.55, '#d2e919'),
        (0.6, '#acdb12'),
        (0.65, '#83cd0e'),
        (0.7, '#56be0c'),
        (0.75, '#31ac28'),
        (0.8, '#3e9755'),
        (0.85, '#447f83'),
        (0.9, '#3465ad'),
        (0.95, '#2646cf'),
        (1.0, '#2b24e8'),
    )

    # Pokemon trait names, kept for reference (e.g. when calling plot_phase_vs_trait).
    # NOTE(review): presumably the column order of the Pokemon X matrix; confirm against the data.
    _POKEMON_TRAITS = (
        'Type 1',
        'Type 2',
        'HP',
        'Attack',
        'Defense',
        'Sp. Atk',
        'Sp. Def',
        'Speed',
        'Generation',
    )

    @classmethod
    def load_from_config(cls, config_path: str) -> "PTAExperiment":
        """
        Instantiate the class from a TOML config file.

        The top-level TOML keys are passed as keyword arguments to ``__init__``.

        Raises:
            ValueError: If the text after the last '.' in ``config_path`` is not ``toml``.
        """
        config_path = str(config_path)  # also accept pathlib.Path
        if config_path.split(".")[-1] != "toml":
            raise ValueError("Only Support TOML config")
        config = toml.load(config_path)
        return cls(**config)

    def __init__(self, output_path: str, experiments: dict, run_exp_1: bool, run_exp_2: bool, run_exp_3: bool,
                 run_exp_4: bool, run_eigenvalue_plot: bool, **kwargs) -> None:
        """
        Args:
            output_path: Prefix that final figure file names are appended to directly
                (so it should end with a path separator); also passed to ``GameUtil``.
            experiments: Mapping of experiment name to that experiment's settings dict.
            run_exp_1, run_exp_2, run_exp_3, run_exp_4: Whether ``run`` executes experiments 1-4.
            run_eigenvalue_plot: Whether ``run`` executes the eigenvalue comparison plot.
            **kwargs: Any other config keys; accepted so extra TOML entries do not break
                construction, and otherwise ignored.
        """
        self.experiments: dict = experiments
        self.pta_util = GameUtil(output_path)

        self._output_path = output_path
        # Flags controlling which experiments ``run`` executes
        self._run_eigenvalue_plots = run_eigenvalue_plot
        self._run_exp_1 = run_exp_1
        self._run_exp_2 = run_exp_2
        self._run_exp_3 = run_exp_3
        self._run_exp_4 = run_exp_4

    @classmethod
    def _phase_colorscale(cls) -> list:
        """Return a fresh, mutable copy of the phase colorscale as a list of [position, colour] lists."""
        return [list(stop) for stop in cls._PHASE_COLORSCALE]

    def plot_phase_vs_trait(self, disc_game: npt.ArrayLike, trait: npt.ArrayLike, trait_name: str, save_path: str):
        """
        Scatter-plot a trait (y-axis) against disc-game phase (x-axis) and save the figure.

        Args:
            disc_game: Disc-game embedding passed to ``get_discgame_angles``.
            trait: One trait value per point of the disc game.
            trait_name: Name of the trait, used in the title.
            save_path: File the figure is saved to.
        """
        phase, _ = get_discgame_angles(disc_game)

        plt.figure(figsize=(10, 10))
        # The x-axis is the phase, so the title says so (it previously said "Radius").
        plt.title(f"{trait_name} vs Phase")
        plt.scatter(phase, trait)
        plt.savefig(save_path)
        plt.close()

    def _save_colored_disc_games(self, n: int, tmp_dir: str, save_name: str, colors: npt.ArrayLike,
                                 disc_games: List, eigs: List, conditions: dict, disc_game_as_row: bool) -> None:
        """Render all ``n`` disc games coloured by ``colors`` into ``tmp_dir + save_name``."""
        # Points are drawn in order of increasing colour value
        order = np.argsort(colors)
        if disc_game_as_row:
            images = self.pta_util.get_disc_game_images(n=n, path=tmp_dir, name=save_name, colors=colors,
                                                        disc_games=disc_games, eigs=eigs, conditions=conditions,
                                                        order=order)
            plot_image_list_vertical(images=images, save_path=tmp_dir + save_name)
        else:
            self.pta_util.plot_horizontal_disc_games(n=n, path=tmp_dir, name=save_name, colors=colors,
                                                     disc_games=disc_games, eigs=eigs, order=order)

    def _save_strategy_space_panel(self, X: npt.ArrayLike, disc_games: List, n: int, tmp_dir: str, color_by: str,
                                   panel_path: str, disc_game_as_row: bool) -> None:
        """Plot the 3D strategy space once per disc game, coloured by "phase" or "radius", and tile the images."""
        images = []
        for d_i in range(n):
            phase, radius = get_discgame_angles(disc_games[d_i])
            if color_by == "phase":
                colors = [np.round(c, 2) for c in phase]
                colorscale = self._phase_colorscale()
            else:
                colors = radius
                colorscale = "Viridis"
            fig = self.pta_util.plot_3d_strat_space(X, colors=colors, title="", colorscale=colorscale,
                                                    width=1500, height=1590)
            # File name includes color_by so the phase and radius panels don't overwrite each other
            path = tmp_dir + f"disc_game_{d_i + 1}_{color_by}.png"
            fig.write_image(path)
            images.append(Image.open(path))

        if disc_game_as_row:
            plot_image_list_vertical(images, panel_path)
        else:
            plot_image_list_horizontal(images, panel_path)

    def create_full_blotto_plot(self, n: int, tmp_dir: str, X: npt.ArrayLike, F: npt.ArrayLike, figure_name: str,
                                plot_phase: bool = True, plot_radius: bool = True, conditions: "dict | None" = None,
                                disc_game_indexes: "List | None" = None, disc_game_as_row: bool = True):
        """
        Create the composite Blotto figure (Blotto examples 1 and 2).

        The figure tiles a panel of disc games coloured by rating, one panel per column of
        ``X``, and optionally panels of the 3D strategy space coloured by disc-game phase and
        radius. Intermediate images and ``eigs.mat`` go to ``tmp_dir``; the composite is saved
        to ``self._output_path + figure_name``.

        Args:
            n: Number of disc games to extract from ``F``.
            tmp_dir: Prefix (ending in a separator) for intermediate files.
            X: Strategy matrix; its rows are assumed to align with the rows of ``F``, and each
                column gets its own colour panel.
            F: Payoff matrix; its row means are used as the ratings.
            figure_name: File name of the composite figure inside ``self._output_path``.
            plot_phase: Add the strategy-space panel coloured by phase.
            plot_radius: Add the strategy-space panel coloured by radius.
            conditions: Marker conditions forwarded to ``get_disc_game_images`` (row layout only).
            disc_game_indexes: If non-empty, only these disc games (and eigenvalues) are plotted.
            disc_game_as_row: Lay each disc game out along a row (True) or a column (False).
        """
        conditions = {} if conditions is None else conditions
        disc_game_indexes = [] if disc_game_indexes is None else disc_game_indexes

        disc_games, eigs, _, _ = self.pta_util.get_disc_games(F, n=n)

        # Saved before any subsetting, so the file holds every eigenvalue returned
        savemat(tmp_dir + "eigs.mat", {"eigs": eigs})

        if len(disc_game_indexes) > 0:  # Only look at subset
            disc_games = [disc_games[i] for i in disc_game_indexes]
            eigs = [eigs[i] for i in disc_game_indexes]
            n = len(disc_game_indexes)  # Only interested in # of disc games given by indexes

        ratings = F.mean(axis=1)
        # NOTE(review): defensive copy kept from the original; presumably guards against
        # plotting helpers mutating the caller's array.
        X = X.copy()

        # First panel: colour by rating
        rating_save_name = "exp2_horizonal_ratings.png"
        print(eigs)
        print(n)
        self._save_colored_disc_games(n, tmp_dir, rating_save_name, ratings, disc_games, eigs, conditions,
                                      disc_game_as_row)

        # Then one panel per allocation column; each gets its own file name in both layouts
        alloc_images = []
        for k in range(X.shape[1]):
            save_name = f"horizontal_alloc{k + 1}.png"
            alloc_images.append(tmp_dir + save_name)
            self._save_colored_disc_games(n, tmp_dir, save_name, X[:, k], disc_games, eigs, conditions,
                                          disc_game_as_row)
            plt.close()

        panel_paths = [tmp_dir + rating_save_name] + alloc_images

        if plot_phase:
            phase_img_path = tmp_dir + "horizontal_phase_strats.png"
            self._save_strategy_space_panel(X, disc_games, n, tmp_dir, "phase", phase_img_path, disc_game_as_row)
            panel_paths.append(phase_img_path)

        if plot_radius:
            radius_img_path = tmp_dir + "horizontal_radius_strats.png"
            self._save_strategy_space_panel(X, disc_games, n, tmp_dir, "radius", radius_img_path, disc_game_as_row)
            panel_paths.append(radius_img_path)

        images = [Image.open(panel_path) for panel_path in panel_paths]

        # Panels are stacked perpendicular to the direction the disc games run in
        if disc_game_as_row:
            plot_image_list_horizontal(images, self._output_path + figure_name)
        else:
            plot_image_list_vertical(images, self._output_path + figure_name)

    def create_pokemon_discgame_plots(self, n: int, tmp_dir: str, X: npt.ArrayLike, F: npt.ArrayLike, figure_name: str,
                                      image_size: int = 10, point_size: int = 400,
                                      disc_game_indexes: "List | None" = None):
        """
        Create the Pokemon disc-game figure.

        The original description calls this the paper's Figure 3 (disc games 1, 2 and 4);
        ``run_experiment4`` saves it as ``figure_4.png``. The figure holds a panel of all
        disc games coloured by rating, and a panel of individual disc games coloured by
        selected columns of ``X``.

        Args:
            n: Number of disc games to extract from ``F``.
            tmp_dir: Prefix (ending in a separator) for intermediate images.
            X: Per-Pokemon feature matrix; its columns are used as colours.
            F: Payoff matrix; its row means are used as the ratings.
            figure_name: File name of the composite figure inside ``self._output_path``.
            image_size: Width and height (inches) of each single disc-game figure.
            point_size: Marker size for every point.
            disc_game_indexes: If non-empty, only these disc games (and eigenvalues) are kept.
        """
        disc_game_indexes = [] if disc_game_indexes is None else disc_game_indexes

        disc_games, eigs, _, _ = self.pta_util.get_disc_games(F, n=n)

        if len(disc_game_indexes) > 0:  # Only look at subset
            disc_games = [disc_games[i] for i in disc_game_indexes]
            eigs = [eigs[i] for i in disc_game_indexes]
            n = len(disc_game_indexes)  # Only interested in # of disc games given by indexes

        ratings = F.mean(axis=1)
        X = X.copy()

        # First do the color by rating
        rating_save_name = "horizonal_ratings.png"
        print(eigs)
        print(n)
        self.pta_util.plot_horizontal_disc_games(n=n, path=tmp_dir, name=rating_save_name, colors=ratings,
                                                 disc_games=disc_games, eigs=eigs, image_size=image_size,
                                                 point_size=point_size)

        # (index into the possibly-subset disc_games list, column of X used as colour).
        # Columns 0 and 8 are 'Type 1' and 'Generation' if X follows _POKEMON_TRAITS order.
        panels = [(1, 0), (1, 8), (2, 8)]
        disc_images = []
        for d_i, x_j in panels:
            color = X[:, x_j]
            plt.figure(figsize=(image_size, image_size))
            sizes = np.full(len(color), point_size)
            self.pta_util.plot_disc_game(disc_games[d_i], sizes=sizes, rating=color, title="", trim_axis=True)
            save_path = tmp_dir + f"discgame_{d_i + 1}_color={x_j + 1}.png"
            plt.savefig(save_path, bbox_inches='tight', pad_inches=0.1, dpi=200)
            plt.close()
            disc_images.append(Image.open(save_path))

        alloc_path = tmp_dir + "allocation_colors.png"
        plot_image_list_horizontal(disc_images, alloc_path)

        images = [Image.open(tmp_dir + rating_save_name), Image.open(alloc_path)]

        plot_image_list_horizontal(images, self._output_path + figure_name)

    def run(self):
        """
        Run every enabled experiment in ``self.experiments``, in insertion order.

        Recognised keys are ``experiment1``-``experiment4`` and ``plot_eigenvalues``;
        any other key is reported and then ignored.
        """
        # key -> (attribute holding the enable flag, runner, number for "Running", ordinal for "Skipping")
        numbered = {
            "experiment1": ("_run_exp_1", self.run_experiment1, "1", "first"),
            "experiment2": ("_run_exp_2", self.run_experiment2, "2", "second"),
            "experiment3": ("_run_exp_3", self.run_experiment3, "3", "third"),
            "experiment4": ("_run_exp_4", self.run_experiment4, "4", "fourth"),
        }
        for exp_name, exp in self.experiments.items():
            print(f"Looking at experiment {exp_name}")
            if exp_name == "plot_eigenvalues":
                if self._run_eigenvalue_plots:
                    self.run_eigenvalue_plot(exp)
                continue
            if exp_name not in numbered:
                continue
            flag_attr, runner, number, ordinal = numbered[exp_name]
            if getattr(self, flag_attr):
                print(f"Running Experiment {number}")
                runner(exp)
            else:
                print(f"Skipping the {ordinal} Experiment")

    def run_eigenvalue_plot(self, vars: dict):
        """
        Produce the appendix plot of eigenvalues vs disc game for several games.

        ``vars`` is the ``experiments.plot_eigenvalues`` section of the TOML config. Keys read:
        ``create_game_data`` (regenerate the built-in Blotto games under ``data_path``),
        ``data_path``, ``games`` (directory prefixes, each containing ``F.npy``), ``D``
        (number of disc games) and ``labels`` (one per game).
        """
        if vars["create_game_data"]:
            from principal_tradeoff_analysis.blotto import create_blotto_data
            # (K, payoff, N) passed straight to create_blotto_data; payoff has one entry per K
            games = [
                (3, [1, 1, 1], 45),
                (3, [1, 2, 4], 45),
                (3, [2, 3, 4], 45),
                (4, [1, 1, 1, 1], 20),
                (4, [1, 2, 3, 4], 20),
            ]
            for K, payoff, N in games:
                game_path = vars["data_path"] + "games/" + f"blotto_{K}_{N}_payoff={payoff}/"
                Path(game_path).mkdir(parents=True, exist_ok=True)
                _, P, X = create_blotto_data(K=K, N=N, payoff=payoff, game_type="weighted")
                F = P - 0.5
                np.save(game_path + "F.npy", F)
                np.save(game_path + "P.npy", P)
                np.save(game_path + "X.npy", X)

        games = vars["games"]
        n = vars["D"]
        labels = vars["labels"]
        # NOTE(review): unlike the other experiments this ignores self._output_path; confirm intended.
        output_path = "./output/" + "eigenvalue_compare.png"
        game_eigs = []
        for game_path in games:
            print(f"Getting eigen values for {game_path}")
            F = np.load(game_path + "F.npy")
            _, eigs, _, _ = self.pta_util.get_disc_games(F, n=n)
            game_eigs.append(np.array(eigs))

        self.pta_util.plot_eigenvalues(game_eigs, labels, output_path, figsize=(8, 6))

    def run_experiment1(self, vars: dict) -> None:
        """
        Plot intransitivity vs N and complexity vs N for a family of Blotto games.

        This figure is currently relegated to the appendix. Depending on flags in ``vars``,
        it (re)generates the games, recomputes their metrics into ``metrics.json`` (otherwise
        an existing file is loaded), and renders ``figure_1.png`` into ``self._output_path``.
        """
        data_path = vars["data_path"] + "experiment1/"
        games_path = data_path + "games/"
        metrics_path = data_path + "metrics.json"

        Path(games_path).mkdir(parents=True, exist_ok=True)

        if vars["create_blotto_data"]:
            from principal_tradeoff_analysis.blotto import create_blotto_games
            blotto_params = vars["blotto_params"]
            n = blotto_params["n_range"]
            payouts = blotto_params["payouts"]
            ks = blotto_params["k"]
            # Step 1 below N=15, then step 5 above it, because runtime grows with N.
            # NOTE(review): n[1] is excluded when n[1] <= 15 but may be included above 15; confirm intended.
            s1 = 15 if n[1] > 15 else n[1]
            n_s = list(range(n[0], s1))
            if n[1] > 15:
                n_s += list(range(s1, n[1] + 1, 5))
            create_blotto_games(games_path, ks=ks, payouts=payouts, n_s=n_s)

        if vars["create_blotto_metric_data"]:
            # Saves a dict (json) with the intransitivity and complexity for each game
            metrics = self.pta_util.get_blotto_game_metrics(games_path)
            with open(metrics_path, 'w', encoding='utf-8') as f:
                json.dump(metrics, f, ensure_ascii=False, indent=4)
        else:
            # Assumes the metrics file already exists; read with the same encoding it is written with
            with open(metrics_path, encoding='utf-8') as f:
                metrics = json.load(f)

        assert len(metrics) > 0, "No data in metrics file!"

        if vars["generate_figure1"]:
            metric_data_df = pd.DataFrame.from_dict(metrics, orient="index")
            metric_data_df = metric_data_df.astype({"N": "int32"}).sort_values(by="N")
            tmp_dir = data_path + "tmp/"
            Path(tmp_dir).mkdir(parents=True, exist_ok=True)

            hue_order = ["[1, 1, 1]", "[1, 2, 4]", "[2, 3, 4]", "[1, 1, 1, 1]", "[1, 2, 3, 4]"]

            # sns.relplot is figure-level and creates its own figure, so no plt.figure() is
            # needed; an explicit one beforehand was never drawn on or closed (figure leak).
            grid = sns.relplot(data=metric_data_df, x="N", y="intrans", hue="payout", kind="line",
                               hue_order=hue_order)
            grid.set(ylabel="Fc / Ft")
            plt.savefig(tmp_dir + "intrans.png")
            plt.close()

            grid = sns.relplot(data=metric_data_df, x="N", y="low_rank", hue="payout", kind="line",
                               hue_order=hue_order, legend=False)
            grid.set(ylabel="complexity")
            plt.savefig(tmp_dir + "complex.png")
            plt.close()

            images = [Image.open(tmp_dir + "intrans.png"), Image.open(tmp_dir + "complex.png")]
            plot_image_list_horizontal(images, self._output_path + "figure_1.png")

    @staticmethod
    def _representative_conditions(X: npt.ArrayLike) -> dict:
        """
        Build marker conditions highlighting hand-picked allocations of a 3-column Blotto ``X``.

        Each value is ``[boolean row masks, colour, matplotlib marker]``, the format that was
        formerly passed to ``create_full_blotto_plot(conditions=...)``.
        """
        high = [X[:, 0] == 45, X[:, 1] == 45, X[:, 2] == 45]
        mid = [
            (X[:, 0] == 20) & (X[:, 1] == 20),
            (X[:, 0] == 20) & (X[:, 2] == 20),
            (X[:, 1] == 20) & (X[:, 2] == 20),
        ]
        mid_high = [
            (X[:, 0] == 29) & (X[:, 1] == 16),
            (X[:, 0] == 29) & (X[:, 2] == 16),
            (X[:, 1] == 29) & (X[:, 0] == 16),
            (X[:, 1] == 29) & (X[:, 2] == 16),
            (X[:, 2] == 29) & (X[:, 0] == 16),
            (X[:, 2] == 29) & (X[:, 1] == 16),
        ]
        uniform = (X[:, 0] == 15) & (X[:, 1] == 15) & (X[:, 2] == 15)
        return {
            "high": [high, "red", "v"],
            "mid": [mid, "blue", "^"],
            "mid_high": [mid_high, "green", ">"],
            "uniform": [[uniform], "yellow", ","],
        }

    def run_experiment2(self, vars: dict) -> None:
        """
        Generate the first Blotto example (``figure_2.png``); not shown in the 9-page paper.

        Set ``show_representative_points = true`` in the experiment's config to overlay the
        hand-picked allocations from ``_representative_conditions`` (off by default; not used in the paper).
        """
        output_path = vars["data_path"] + "experiment2/"
        tmp_dir = output_path + "tmp/"
        Path(tmp_dir).mkdir(parents=True, exist_ok=True)

        blotto_params = vars["blotto_params"]

        if vars["create_blotto_data"]:
            print("Creating Data For Experiment 2")
            from principal_tradeoff_analysis.blotto import create_blotto_data
            K = blotto_params["K"]
            N = blotto_params["N"]
            payoff = blotto_params["payoff"]
            seed = blotto_params["seed"]
            game_path = output_path + f"blotto_{K}_{N}_seed={seed}_payoff={payoff}/"
            Path(game_path).mkdir(parents=True, exist_ok=True)
            _, P, X = create_blotto_data(K=K, N=N, seed=seed, payoff=payoff, game_type="weighted")
            F = P - 0.5
            np.save(game_path + "F.npy", F)
            np.save(game_path + "P.npy", P)
            np.save(game_path + "X.npy", X)

        game_path = vars["game_path"]
        F = np.load(game_path + "F.npy")
        X = np.load(game_path + "X.npy")
        n = blotto_params["D"]

        if vars["generate_figure2"]:
            print("Generating figure 2")
            disc_game_indexes = list(range(n))
            if vars.get("show_representative_points", False):
                conditions = self._representative_conditions(X)
            else:
                conditions = {}

            self.create_full_blotto_plot(n=n, tmp_dir=tmp_dir, F=F, X=X, disc_game_indexes=disc_game_indexes,
                                         figure_name="figure_2.png", conditions=conditions, plot_phase=True,
                                         plot_radius=False, disc_game_as_row=True)

    def run_experiment3(self, vars: dict) -> None:
        """
        Generate the second Blotto example (``figure_3.png``); not shown in the 9-page paper.
        """
        output_path = vars["data_path"] + "experiment3/"
        tmp_dir = output_path + "tmp/"
        Path(tmp_dir).mkdir(parents=True, exist_ok=True)

        blotto_params = vars["blotto_params"]
        if vars["create_blotto_data"]:
            print("Creating Data For experiment 3")
            from principal_tradeoff_analysis.blotto import create_blotto_data
            K = blotto_params["K"]
            N = blotto_params["N"]
            payoff = blotto_params["payoff"]
            seed = blotto_params["seed"]
            game_path = output_path + f"blotto_{K}_{N}_seed={seed}_payoff={payoff}/"
            Path(game_path).mkdir(parents=True, exist_ok=True)
            _, P, X = create_blotto_data(K=K, N=N, seed=seed, payoff=payoff, game_type="weighted")
            F = P - 0.5
            np.save(game_path + "F.npy", F)
            np.save(game_path + "P.npy", P)
            np.save(game_path + "X.npy", X)

        if vars["generate_figure3"]:
            print("Creating Figure 3")
            game_path = vars["game_path"]
            F = np.load(game_path + "F.npy")
            X = np.load(game_path + "X.npy")
            n = blotto_params["D"]
            self.create_full_blotto_plot(n=n, tmp_dir=tmp_dir, F=F, X=X, plot_phase=True, plot_radius=False,
                                         figure_name="figure_3.png")

    def run_experiment4(self, vars: dict) -> None:
        """
        Generate the Pokemon figures: the disc-game figure (``figure_4.png``) and/or the
        three-panel attack-matrix figure, depending on flags in ``vars``.
        """
        output_path = vars["data_path"] + "experiment4/"
        tmp_dir = output_path + "tmp/"
        Path(tmp_dir).mkdir(parents=True, exist_ok=True)
        game_path = vars["game_path"]
        F = np.load(game_path + "F.npy")
        X = np.load(game_path + "X.npy")

        if vars["generate_all_disc_games"]:
            n = 5
            disc_game_indexes = [0, 1, 3]
            self.create_pokemon_discgame_plots(n=n, tmp_dir=tmp_dir, F=F, X=X, figure_name="figure_4.png",
                                               disc_game_indexes=disc_game_indexes)

        if vars["generate_attack_matrix"]:
            self.pta_util.generate_pokemon_3panel(A=F, X=X, n=4, cluster_n=1, exp_dir=output_path)
    