import numpy as np
import argparse
#import torch

from ribs.archives import GridArchive, CVTArchive
from ribs.emitters import GaussianEmitter
from ribs.visualize import grid_archive_heatmap, cvt_archive_heatmap
from discreteEmitter import discreteEmitter
from ribs.schedulers import Scheduler

import matplotlib.pyplot as plt

from lexicase import select_from_scores
from arm_sol import ArmSolution
from arm_domain import get_objectives, get_qd_measures, visualize

import seaborn as sns

import sys
import os

from tqdm import trange

def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    ``type=bool`` cannot be used for this: argparse would call ``bool("False")``,
    and every non-empty string is truthy, so ``--elitism False`` would enable
    elitism.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string stays False, matching what bool("") gave before.
    if text in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Command-line interface. The arguments marked required must always be given,
# so their `default` values are never actually used by argparse.
parser = argparse.ArgumentParser()
parser.add_argument("--level", required=True, default=64, type=int)
parser.add_argument("--popsize", default=500, required=True, type=int)
parser.add_argument("--iters", default=1000, required=True, type=int)
parser.add_argument("--repeats", default=10, required=True, type=int)

# shuffle reorders the scores randomly so they don't correspond correctly for
# lexicase (the flag is parsed here but not read elsewhere in the visible script)
parser.add_argument("--shuffle", default=False, required=False, type=_str2bool)
parser.add_argument("--name", default="lex", required=False)
parser.add_argument("--elitism", required=False, default=False, type=_str2bool)
parser.add_argument("--method", required=True, default="lex", type=str)
parser.add_argument("--links", required=True, default=12, type=int)
parser.add_argument("--deagg", required=False, default="space", type=str)

args = parser.parse_args()

# Module-level experiment settings derived from the parsed arguments.
N = 8
POPSIZE = args.popsize
ITERS = args.iters
REPEATS = args.repeats
ELITISM = args.elitism
level = args.level
METHOD = args.method
LINKS = args.links
deagg = args.deagg

# The link count has to be an exact multiple of `level`. Non-positive levels
# are rejected explicitly: level == 0 would otherwise surface as a bare
# ZeroDivisionError from the modulo below. ValueError is still an Exception,
# so callers catching the old generic Exception keep working.
if level <= 0 or LINKS % level != 0:
    raise ValueError(
        f"LINK LEVEL NOT VALID: --links ({LINKS}) must be a multiple of "
        f"a positive --level ({level})"
    )

# One unit-length entry per link.
link_lengths = np.ones((LINKS))

# Author's note, using Popoviciu's inequality on variances:
#   var <= (M - m)^2 / 4, i.e. the variance is at most ~10.

# Results are grouped by link count, then by iterations/population size, then
# by run name. Lexicase runs additionally carry the de-aggregation mode.
run_tag = args.name
if METHOD == "lex":
    run_tag = f"{run_tag}_{str(deagg)}"
dir_name = f"arm_data/{str(LINKS)}/{str(ITERS)}_{str(POPSIZE)}/{run_tag}"

# Create the run directory and its per-level subdirectory if they are missing.
for out_dir in (dir_name, f"{dir_name}/spread_{level}/"):
    if not os.path.exists(out_dir):
        os.makedirs(out_dir, exist_ok=True)

# Starting population: POPSIZE solutions of LINKS values, uniform in [0, 1).
population = np.random.rand(POPSIZE, LINKS)

# Both measure dimensions span [-sum(link_lengths), +sum(link_lengths)].
reach = np.sum(link_lengths)
measure_ranges = [(-reach, reach), (-reach, reach)]

result_archive = CVTArchive(
    solution_dim=LINKS,
    cells=10000,
    qd_score_offset=-10,  # offset used when calculating the QD score
    ranges=measure_ranges,
    use_kd_tree=True,
)

if METHOD == "ME":
    # One single-solution Gaussian emitter per population slot, so a single
    # ask() returns POPSIZE candidates.
    emitters = []
    for _ in range(POPSIZE):
        emitters.append(
            GaussianEmitter(
                archive=result_archive,
                x0=np.random.rand(LINKS),
                bounds=[(-np.pi, np.pi)] * LINKS,
                sigma=0.1,
                batch_size=1,
            )
        )

    scheduler = Scheduler(result_archive, emitters)
    # The scheduler's first batch replaces the random population above.
    sols = scheduler.ask()
    population = sols


def _save_archive_heatmap(archive, title, path):
    """Draw ``archive`` as a CVT heatmap titled ``title`` and save it to ``path``.

    The colour scale is pinned with vmin=-2 / vmax=0 so that images from
    different generations and runs use the same scale.
    """
    plt.figure(figsize=(8, 6))
    cvt_archive_heatmap(archive, plot_centroids=False, lw=0.1, vmax=0, vmin=-2)
    plt.title(title)
    plt.savefig(path)
    plt.clf()
    plt.close()


# Histories collected across repeats (one entry per repeat).
all_fits_data = []
all_meas_data = []
all_div_data = []
all_qd_scores = []

# Fail fast: any other method would build no offspring, leaving an empty
# population after the first generation.
if METHOD not in ("lex", "ME"):
    raise ValueError(f"unknown --method {METHOD!r}; expected 'lex' or 'ME'")

# Output cadence in generations (loop-invariant, so defined once).
graph_interval = 500
save_interval = 10

for itr in trange(0, REPEATS, desc="repeats", file=sys.stdout):
    # Per-generation histories for this repeat.
    all_fits = []
    all_meas = []
    all_div = []
    for i in trange(0, ITERS, desc="Iterations", file=sys.stdout):
        # Evaluate the current population: the aggregate objective and
        # measures for the archive, plus `level` objectives for selection.
        obj, measures = get_qd_measures(population, link_lengths)
        objs = get_objectives(population, link_lengths, num_objectives=level, deagg=deagg)

        div = 0  # TODO: measure diversity?
        all_scores = np.array(objs)

        # NOTE(review): under METHOD == "ME" the scheduler.tell() below is
        # given the same archive and presumably inserts these solutions again
        # - confirm this explicit add is intended for that method.
        result_archive.add(population, obj, measures)

        if METHOD == "lex":
            selected = select_from_scores(all_scores, selection="lex", elitism=ELITISM)

            if ELITISM:
                # The first selected individual is carried over unmutated;
                # the rest receive Gaussian mutation clipped to [-pi, pi].
                # The index is named `parent` (not `i`) so it cannot clobber
                # the generation counter used for graphing/saving below.
                newPop = [population[selected[0]]]
                newPop.extend(
                    np.clip(np.random.normal(population[parent], 0.1), -np.pi, np.pi)
                    for parent in selected[int(ELITISM):]
                )
            else:
                # Vectorised path (faster): mutate every selected parent at once.
                noise = np.random.normal(loc=np.zeros_like(population), scale=0.1, size=population.shape)
                newPop = np.clip(population[selected] + noise, -np.pi, np.pi)

        elif METHOD == "ME":
            # Report the evaluations, then request the next batch.
            scheduler.tell(obj, measures)
            sols = scheduler.ask()
            newPop = sols

        all_fits.append(np.sum(objs, axis=1))
        all_meas.append(measures)
        all_div.append(div)
        population = np.array(newPop)

        if i % graph_interval == 0:
            _save_archive_heatmap(
                result_archive,
                f"Arm Archive, Iteration {i}, Pop {POPSIZE}, type: {METHOD}, level: {level}",
                f"{dir_name}/spread_{level}/{itr}_{i//graph_interval}.png",
            )

        if i % save_interval == 0:
            # Snapshot the archive ...
            adf = result_archive.as_pandas(include_solutions=True)
            adf.to_csv(f"{dir_name}/spread_{level}/archive_{itr}_{i//save_interval}.csv")

            # ... and the population. `population` was already replaced above,
            # so this is the offspring that will be evaluated next generation.
            np.save(f"{dir_name}/spread_{level}/population_{itr}_{i//save_interval}.npy", population)

    all_fits = np.array(all_fits)
    all_meas = np.array(all_meas)
    all_fits_data.append(all_fits)
    all_meas_data.append(all_meas)
    all_div_data.append(all_div)

    # Final state of the archive for this repeat.
    _save_archive_heatmap(
        result_archive,
        f"Arm Archive, FINAL, Pop {POPSIZE}, type: {METHOD}, level: {level}",
        f"{dir_name}/spread_{level}/{itr}_final.png",
    )

    # QD score as reported by the archive's statistics.
    qd_score = result_archive.stats.qd_score
    all_qd_scores.append(qd_score)

    # Reset state for the next repeat: fresh random population and archive.
    population = np.random.rand(POPSIZE, LINKS)

    result_archive = CVTArchive(
        solution_dim=LINKS,
        cells=10000,
        qd_score_offset=-10,  # offset used when calculating the QD score
        ranges=[(-np.sum(link_lengths), np.sum(link_lengths)), (-np.sum(link_lengths), np.sum(link_lengths))],
        use_kd_tree=True,
    )

    if METHOD == "ME":
        # Emitters hold a reference to their archive, so they and the
        # scheduler are rebuilt against the new one.
        emitters = [
            GaussianEmitter(
                archive=result_archive,
                x0=np.random.rand(LINKS),
                sigma=0.1,
                bounds=[(-np.pi, np.pi)] * LINKS,
                batch_size=1,
            )
            for _ in range(POPSIZE)
        ]

        scheduler = Scheduler(result_archive, emitters)
        sols = scheduler.ask()
        population = np.array(sols)

# Stack the per-repeat histories into arrays.
all_fits_data, all_div_data, all_meas_data = (
    np.array(all_fits_data),
    np.array(all_div_data),
    np.array(all_meas_data),
)

# Echo the raw QD scores (still a plain list here) before converting them.
print(all_qd_scores)
all_qd_scores = np.array(all_qd_scores)

# Optional debugging plot of the measures; disabled by default.
PLOT_MEASURES = False

if PLOT_MEASURES:
    peak_measures = np.max(all_meas_data, axis=2)
    measure_curves = (
        (np.mean, "Avg Sum of measures"),
        (np.min, "Worst Sum of measures"),
        (np.max, "Best Sum of measures"),
    )
    for reduce_fn, curve_label in measure_curves:
        plt.plot(reduce_fn(peak_measures, axis=0), label=curve_label)

# Persist the fitness, diversity and QD-score histories for this level.
np.save(f"{dir_name}/all_fits_{level}.npy", all_fits_data)
np.save(f"{dir_name}/all_div_{level}.npy", all_div_data)
np.save(f"{dir_name}/all_qd_{level}.npy", all_qd_scores)


# Summary figure: maximum of the fitness data along axis 2 per generation,
# drawn as the mean across repeats with a 5th-95th percentile band.
best_per_generation = np.max(all_fits_data, axis=2)
mean_curve = np.mean(best_per_generation, axis=0)
band_low = np.percentile(best_per_generation, 5, axis=0)
band_high = np.percentile(best_per_generation, 95, axis=0)

plt.plot(mean_curve, label="tiles Covered")
plt.fill_between(range(ITERS), band_low, band_high, alpha=0.5)
plt.legend()
plt.ylabel("Best True Fitness")
plt.xlabel("Generation")
plt.title(f"""{METHOD}
          performance of best elite each generation,
          across many runs""")
plt.savefig(f"{dir_name}/all_fits_{level}.png")
plt.clf()


