import numpy as np
import argparse
import jax
#import torch

from ribs.archives import GridArchive, CVTArchive
from discreteEmitter import discreteEmitter
from ribs.emitters import GaussianEmitter
from ribs.visualize import grid_archive_heatmap, cvt_archive_heatmap
from discreteEmitter import discreteEmitter
from ribs.schedulers import Scheduler

import matplotlib.pyplot as plt
from lexicase import select_from_scores
from kheperax_domain import evaluate

import seaborn as sns

import sys
import os

from tqdm import trange

def _str2bool(value):
    """Convert a command-line string such as ``"True"`` or ``"false"`` to a bool.

    ``type=bool`` cannot be used for this: ``bool("False")`` is ``True`` because
    every non-empty string is truthy, so ``--elitism False`` would silently
    enable elitism.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in {"true", "t", "yes", "y", "1"}:
        return True
    if text in {"false", "f", "no", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean (true/false), got {value!r}")


parser = argparse.ArgumentParser()
# These four are mandatory, so argparse never falls back to a default for them;
# the example values in the help text are the ones the script used to list as
# (unreachable) defaults.
parser.add_argument("--level", required=True, type=int, help="value passed to evaluate() as n, e.g. 64")
parser.add_argument("--popsize", required=True, type=int, help="individuals per generation, e.g. 500")
parser.add_argument("--iters", required=True, type=int, help="generations per repeat, e.g. 1000")
parser.add_argument("--repeats", required=True, type=int, help="independent repeats, e.g. 10")

parser.add_argument("--name", default="lex", required=False)
parser.add_argument("--elitism", required=False, default=False, type=_str2bool)
# choices= makes argparse reject anything other than "lex"/"ME" with a usage
# error, replacing the old manual check that raised a bare Exception.
parser.add_argument("--method", required=True, type=str, choices=["lex", "ME"])
parser.add_argument("--deagg", required=False, default="space", type=str)
parser.add_argument("--domain", required=False, default="decep", type=str)

args = parser.parse_args()
N = 8  # NOTE(review): not referenced anywhere else in this script
POPSIZE = args.popsize
ITERS = args.iters
REPEATS = args.repeats
ELITISM = args.elitism
level = args.level
METHOD = args.method
deagg = args.deagg
domain = args.domain

# Output directory for this configuration. Lexicase runs also carry the
# deaggregation mode in the directory name; ME runs do not.
dir_name = f"kheperax_{domain}_data/{ITERS}_{POPSIZE}/{args.name}"
if METHOD == "lex":
    dir_name += f"_{deagg}"

# A single makedirs call creates dir_name and its spread_<level> subdirectory.
# exist_ok=True makes it a no-op when they already exist, so no separate
# os.path.exists checks are needed.
os.makedirs(f"{dir_name}/spread_{level}/", exist_ok=True)

# Initial genomes: POPSIZE random individuals of 66 parameters, each in [0, 1).
population = np.random.rand(POPSIZE, 66)

# 50x50 grid archive over a 2-D measure space with both measures in [0, 1].
result_archive = GridArchive(solution_dim=66, dims=[50, 50], ranges=[(0, 1), (0, 1)])

if METHOD == "ME":
    # MAP-Elites: one single-solution Gaussian emitter per population slot,
    # each starting from its own random point.
    emitters = []
    for _ in range(POPSIZE):
        seed_point = np.random.rand(66)
        emitters.append(
            GaussianEmitter(archive=result_archive, sigma=0.1, x0=seed_point, batch_size=1)
        )

    scheduler = Scheduler(result_archive, emitters)
    # In ME mode the first population evaluated is what the emitters propose.
    sols = scheduler.ask()
    population = sols

# Create the JAX PRNG key. It is split again before every evaluate() call so
# each evaluation receives a fresh subkey.
random_key, subkey = jax.random.split(jax.random.PRNGKey(42))

# Histories collected across all repeats (one entry appended per repeat).
all_fits_data, all_meas_data, all_div_data, all_qd_scores = [], [], [], []
# Heatmaps are rendered every graph_interval iterations; archive and population
# snapshots are written every save_interval iterations. Hoisted out of the loop
# because they never change.
graph_interval = 500
save_interval = 10


def _save_archive_heatmap(archive, title, path):
    """Render ``archive`` as a grid heatmap titled ``title`` and save it to ``path``."""
    plt.figure(figsize=(8, 6))
    grid_archive_heatmap(archive)
    plt.title(title)
    plt.savefig(path)
    plt.clf()
    plt.close()


for itr in trange(0, REPEATS, desc="repeats", file=sys.stdout):
    all_fits = []
    all_meas = []
    all_div = []
    for i in trange(0, ITERS, desc="Iterations", file=sys.stdout):
        random_key, subkey = jax.random.split(random_key)
        objs, obj, measures, info = evaluate(population, subkey, n=level, deaggregation=deagg, domain_type=domain)

        div = 0  # placeholder: diversity is not computed, always recorded as 0
        all_scores = np.array(objs)

        if METHOD == "lex":
            # Only the lexicase path adds to the archive here; in ME mode
            # scheduler.tell() already inserts the batch, and adding it first
            # would make every emitter see "not added" feedback.
            result_archive.add(population, obj, measures)

            # NOTE(review): assumes select_from_scores returns POPSIZE parent
            # indices with the elite at position 0 when elitism is on — confirm.
            selected = select_from_scores(all_scores, selection="lex", elitism=ELITISM)
            noise = np.random.normal(loc=np.zeros_like(population), scale=0.1, size=population.shape)
            if ELITISM:
                noise[0] = 0.0  # no mutation on the elite
            newPop = population[selected] + noise

        elif METHOD == "ME":
            scheduler.tell(obj, measures)
            newPop = scheduler.ask()

        all_fits.append(obj)
        all_meas.append(measures)
        all_div.append(div)
        population = np.array(newPop)

        if i % graph_interval == graph_interval - 1:
            _save_archive_heatmap(
                result_archive,
                f"Kheperax Archive, Iteration {i}, Pop {POPSIZE}, type: {METHOD}, level: {level}",
                f"{dir_name}/spread_{level}/{itr}_{i//graph_interval}.png",
            )

        if i % save_interval == save_interval - 1:
            snapshot = i // save_interval
            # Save the archive (including solutions) and the next population.
            adf = result_archive.as_pandas(include_solutions=True)
            adf.to_csv(f"{dir_name}/spread_{level}/archive_{itr}_{snapshot}.csv")
            np.save(f"{dir_name}/spread_{level}/population_{itr}_{snapshot}.npy", population)

    all_fits_data.append(np.array(all_fits))
    all_meas_data.append(np.array(all_meas))
    all_div_data.append(all_div)

    _save_archive_heatmap(
        result_archive,
        f"Maze Archive, FINAL, Pop {POPSIZE}, type: {METHOD}, level: {level}",
        f"{dir_name}/spread_{level}/{itr}_final.png",
    )

    # QD score of this repeat's final archive.
    qd_score = result_archive.stats.qd_score
    all_qd_scores.append(qd_score)

    # Fresh population, archive and (for ME) emitters for the next repeat.
    population = np.random.rand(POPSIZE, 66)
    result_archive = GridArchive(
        solution_dim=66,
        dims=[50, 50],
        ranges=[(0, 1), (0, 1)],
    )

    if METHOD == "ME":
        emitters = [
            GaussianEmitter(
                archive=result_archive,
                sigma=0.1,
                x0=np.random.rand(66),
                batch_size=1,
            )
            for _ in range(POPSIZE)
        ]
        scheduler = Scheduler(result_archive, emitters)
        sols = scheduler.ask()
        population = np.array(sols)

# Stack the per-repeat histories. all_fits_data is indexed
# [repeat, iteration, ...]; presumably (REPEATS, ITERS, POPSIZE) if evaluate()
# returns one objective per individual — TODO confirm.
all_fits_data = np.array(all_fits_data)
all_div_data = np.array(all_div_data)
all_meas_data = np.array(all_meas_data)
print(all_qd_scores)
all_qd_scores = np.array(all_qd_scores)

np.save(f"{dir_name}/all_fits_{level}.npy", all_fits_data)
np.save(f"{dir_name}/all_div_{level}.npy", all_div_data)
# Measures were accumulated every iteration but never written out; save them
# alongside the other histories.
np.save(f"{dir_name}/all_meas_{level}.npy", all_meas_data)
np.save(f"{dir_name}/all_qd_{level}.npy", all_qd_scores)

# Best objective per (repeat, iteration), then the mean over repeats with a
# 5th–95th percentile band.
best = np.max(all_fits_data, axis=2)
plt.figure()
plt.plot(np.mean(best, axis=0), label="distance to goal")
plt.fill_between(range(ITERS), np.percentile(best, 5, axis=0), np.percentile(best, 95, axis=0), alpha=0.5)
plt.legend()
plt.ylabel("Best True Fitness")
plt.xlabel("Generation")
plt.title(f"""{METHOD}
          performance of best elite each generation,
          across many runs""")
plt.savefig(f"{dir_name}/all_fits_{level}.png")
plt.clf()
plt.close()


