# Analyze saved archive snapshots: render archive heatmaps and compute
# per-snapshot metrics, saved as .npy files for graphing later.

import numpy as np
import pandas
from ribs.archives import ArchiveDataFrame, CVTArchive, GridArchive
from ribs.visualize import cvt_archive_heatmap, grid_archive_heatmap
import matplotlib.pyplot as plt
from tqdm import trange, tqdm
import argparse
import os

import seaborn as sns

# ---- Experiment configuration -------------------------------------------
# These values select which saved run is analysed; they appear in the input
# and output paths built further down.
POPSIZE = 1000
ITERS = 1000
REPEATS = 5
# One archive snapshot is read per `save_interval` iterations, so there are
# ITERS // save_interval snapshots in total.
save_interval = 10
LEVELS = [1, 4]
METHODS = ["lex", "lex_space", "me"]

# Offset handed to the archive for its QD-score statistic.
qd_score_offset = -10
LINKS = 8
# Every link of the arm has unit length.
link_lengths = [1] * LINKS


file_name = "arm_data/8"

# ---- Metric storage ------------------------------------------------------
# All four metric arrays share the same layout:
#   axis 0 -> snapshot index, axis 1 -> repeat,
#   axis 2 -> index into METHODS, axis 3 -> index into LEVELS.
_metric_shape = (ITERS // save_interval, REPEATS, len(METHODS), len(LEVELS))

scores = np.zeros(_metric_shape)
mean_scores = np.zeros(_metric_shape)
qd_scores = np.zeros(_metric_shape)
coverage = np.zeros(_metric_shape)


# Replay every saved archive snapshot into a fresh CVTArchive and record, for
# each (snapshot, repeat, method, level) combination: the best objective, the
# mean objective, the archive's QD score and its coverage.  A heatmap of the
# final snapshot is written to disk for each (level, method, repeat).

# The measure space is a square whose half-width is the summed link lengths.
arm_reach = np.sum(link_lengths)
measure_ranges = [(-arm_reach, arm_reach), (-arm_reach, arm_reach)]

# Generate the CVT centroids once and reuse them for every replayed archive.
# Constructing a CVTArchive without explicit centroids regenerates them each
# time (expensive for 10000 cells) and, with no seed, yields a different
# tessellation per archive, which would make coverage / QD score
# incomparable between snapshots.
shared_centroids = CVTArchive(
    solution_dim=LINKS,
    cells=10000,
    ranges=measure_ranges,
    use_kd_tree=True,
).centroids

snapshot_count = ITERS // save_interval

for ITER in trange(0, snapshot_count):
    # Every snapshot's heatmap would be written to the same file, so only
    # the last snapshot's image can survive; skip rendering the others.
    is_last_snapshot = ITER == snapshot_count - 1

    for method_i, method in enumerate(METHODS):
        dir_name = f"{file_name}/{ITERS}_{POPSIZE}/{method}"

        for level_i, level in enumerate(LEVELS):
            # "me" is only analysed at level 1; the metric entries for its
            # other levels are never written by this loop.
            if method == "me" and level != 1:
                continue

            for repeat in range(REPEATS):
                adf = ArchiveDataFrame(
                    pandas.read_csv(f"{dir_name}/spread_{level}/archive_{repeat}_{ITER}.csv")
                )
                obj = adf.objective_batch()

                # Re-insert all saved entries into an archive built on the
                # shared tessellation so the statistics are comparable.
                result_archive = CVTArchive(
                    solution_dim=LINKS,
                    cells=10000,
                    qd_score_offset=qd_score_offset,  # offset for calculating QD score
                    ranges=measure_ranges,
                    custom_centroids=shared_centroids,
                    use_kd_tree=True,
                )
                result_archive.add(adf.solution_batch(), obj, adf.measures_batch())

                if is_last_snapshot:
                    plt.figure(figsize=(8, 6))
                    cvt_archive_heatmap(result_archive, plot_centroids=False, lw=0.1, vmax=0, vmin=-2)
                    plt.title(f"Arm Final, Pop {POPSIZE}, type: {method}, level: {level}")

                    spread_folder = f"{file_name}/spreads_{ITERS}_{POPSIZE}/{level}/{method}/{repeat}"

                    # The image is saved next to `spread_folder` (as
                    # "<repeat>.png"), so only its parent directory is needed.
                    os.makedirs(os.path.dirname(spread_folder), exist_ok=True)

                    plt.savefig(f"{spread_folder}.png")
                    plt.clf()
                    plt.close()

                # Objective statistics come straight from the saved entries.
                scores[ITER, repeat, method_i, level_i] = np.max(obj)
                mean_scores[ITER, repeat, method_i, level_i] = np.mean(obj)

                # QD score and coverage come from the rebuilt archive.
                archive_stats = result_archive.stats
                qd_scores[ITER, repeat, method_i, level_i] = archive_stats.qd_score
                coverage[ITER, repeat, method_i, level_i] = archive_stats.coverage

# Persist each metric array to its own .npy file so it can be graphed later.
_metric_outputs = (
    (f"./{file_name}/all_scores_{ITERS}_{POPSIZE}.npy", scores),
    (f"./{file_name}/all_mean_scores_{ITERS}_{POPSIZE}.npy", mean_scores),
    (f"./{file_name}/all_qd_scores_{ITERS}_{POPSIZE}.npy", qd_scores),
    (f"./{file_name}/all_coverage_{ITERS}_{POPSIZE}.npy", coverage),
)
for _out_path, _metric_values in _metric_outputs:
    np.save(_out_path, _metric_values)