# Analyze saved archives: render archive heatmaps and save summary metrics for graphing later.

import numpy as np
import pandas
from ribs.archives import ArchiveDataFrame, CVTArchive, GridArchive
from ribs.visualize import cvt_archive_heatmap, grid_archive_heatmap
import matplotlib.pyplot as plt
from tqdm import trange, tqdm
import argparse
import os

import seaborn as sns

# Command-line options. Both flags are optional and fall back to the defaults
# below; argparse never applies `default` to an option that is also marked
# `required=True`, so the two were previously in conflict.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--popsize",
    default=500,
    type=int,
    help="population size of the runs to analyze (part of the input directory name)",
)
parser.add_argument(
    "--iters",
    default=1000,
    type=int,
    help="iteration count of the runs to analyze (part of the input directory name)",
)

args = parser.parse_args()

POPSIZE = args.popsize
ITERS = args.iters
# Number of repeated runs read per (method, level) combination.
REPEATS = 10
# ITERS // save_interval archive snapshots are read per run
# (presumably one archive was saved every `save_interval` iterations -- confirm).
save_interval = 10
# Values substituted into the `spread_{level}` input directory names.
LEVELS = [1, 4, 16, 64]
METHODS = ["lex_time", "lex_space", "me"]

# Passed to GridArchive as `qd_score_offset`.
qd_score_offset = 0

# Root directory for both the input archives and the generated outputs.
file_name = "knight_data"

# Every metric buffer is indexed as [snapshot, repeat, method, level].
metrics_shape = (ITERS // save_interval, REPEATS, len(METHODS), len(LEVELS))
scores = np.zeros(metrics_shape)
mean_scores = np.zeros(metrics_shape)
qd_scores = np.zeros(metrics_shape)
coverage = np.zeros(metrics_shape)


# Walk every saved snapshot of every (method, level, repeat) run, rebuild the
# archive from its CSV, and record max objective, mean objective, QD score and
# coverage into the pre-allocated metric arrays.

# Index of the last snapshot. Heatmaps are written to a path that does not
# depend on the snapshot index, so only the last one would survive anyway;
# rendering just that one yields the same files without the repeated plotting.
last_snapshot = ITERS // save_interval - 1

for ITER in trange(0, ITERS // save_interval):
    for method_i, method in enumerate(METHODS):
        # Input directory is identical for every level/repeat of this method.
        dir_name = f"{file_name}/{ITERS}_{POPSIZE}/{method}"

        for level_i, level in enumerate(LEVELS):
            # "me" is only read for level 1; its slots for the other levels
            # keep the zeros the metric arrays were initialised with.
            if method == "me" and level != 1:
                continue

            for repeat in range(REPEATS):
                adf = ArchiveDataFrame(
                    pandas.read_csv(f"{dir_name}/spread_{level}/archive_{repeat}_{ITER}.csv")
                )

                # Extract each batch once and reuse it below.
                sols = adf.solution_batch()
                obj = adf.objective_batch()
                meas = adf.measures_batch()

                # Re-insert everything into a fresh archive so that its
                # statistics (QD score, coverage) can be read back.
                result_archive = GridArchive(
                    solution_dim=66,
                    dims=[8, 8],
                    ranges=[(0, 8), (0, 8)],
                    qd_score_offset=qd_score_offset,
                )
                result_archive.add(sols, obj, meas)

                if ITER == last_snapshot:
                    plt.figure(figsize=(8, 6))
                    grid_archive_heatmap(result_archive)
                    plt.title(f"Kheperax Final, Pop {POPSIZE}, type: {method}, level: {level}")

                    spread_path = (
                        f"{file_name}/spreads_{ITERS}_{POPSIZE}/{level}/{method}/{repeat}.png"
                    )
                    # Create only the directory that will contain the image
                    # (not an extra empty directory named after the image).
                    os.makedirs(os.path.dirname(spread_path), exist_ok=True)

                    plt.savefig(spread_path)
                    plt.close()

                cell = (ITER, repeat, method_i, level_i)
                scores[cell] = np.max(obj)
                mean_scores[cell] = np.mean(obj)
                qd_scores[cell] = result_archive.stats.qd_score
                coverage[cell] = result_archive.stats.coverage

# Persist every metric array to disk so it can be graphed later.
metric_outputs = {
    f"./{file_name}/all_scores_{ITERS}_{POPSIZE}.npy": scores,
    f"./{file_name}/all_mean_scores_{ITERS}_{POPSIZE}.npy": mean_scores,
    f"./{file_name}/all_qd_scores_{ITERS}_{POPSIZE}.npy": qd_scores,
    f"./{file_name}/all_coverage_{ITERS}_{POPSIZE}.npy": coverage,
}
for out_path, metric_values in metric_outputs.items():
    np.save(out_path, metric_values)