import numpy as np
import argparse
#import torch

from ribs.archives import GridArchive, CVTArchive
from discreteEmitter import discreteEmitter
from ribs.emitters import GaussianEmitter
from ribs.visualize import grid_archive_heatmap, cvt_archive_heatmap
from discreteEmitter import discreteEmitter
from ribs.schedulers import Scheduler

import matplotlib.pyplot as plt

from board import Board
from evaluation import get_hand_metrics_and_obs, get_pop_diversity, get_subaggregation
from lexicase import select_from_scores
from solution import Population
from maze_domain import get_objectives, get_qd_measures, mutate

import seaborn as sns

import sys
import os

from tqdm import trange

def _str2bool(value):
    """Parse a boolean command-line value.

    argparse's ``type=bool`` simply calls ``bool(<string>)``, which is True for
    every non-empty string, so ``--elitism False`` used to turn elitism ON.

    Args:
        value: raw string from the command line (or an already-bool default).

    Returns:
        True for "true"/"t"/"yes"/"y"/"1", False for "false"/"f"/"no"/"n"/"0"
        or the empty string (case-insensitive, surrounding spaces ignored).

    Raises:
        argparse.ArgumentTypeError: for any other spelling, which argparse
            reports as a usage error.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in {"true", "t", "yes", "y", "1"}:
        return True
    if text in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser()
# These four are mandatory, so argparse never falls back to a default for them.
parser.add_argument("--level", required=True, type=int)
parser.add_argument("--popsize", required=True, type=int)
parser.add_argument("--iters", required=True, type=int)
parser.add_argument("--repeats", required=True, type=int)

parser.add_argument("--name", default="lex", required=False)
parser.add_argument("--elitism", required=False, default=False, type=_str2bool)
# `choices` makes argparse reject unsupported methods with a usage message,
# replacing the old manual check that raised a bare Exception.
parser.add_argument("--method", required=True, type=str, choices=["lex", "ME"])

args = parser.parse_args()
N = 8  # NOTE(review): not referenced anywhere in the visible code
POPSIZE = args.popsize
ITERS = args.iters
REPEATS = args.repeats
ELITISM = args.elitism
level = args.level
METHOD = args.method


# Output layout: maze_data/<iters>_<popsize>/<name>/spread_<level>/
dir_name = f"maze_data/{ITERS}_{POPSIZE}/{args.name}"

# Create the run directory first, then the per-level snapshot directory
# inside it; existing directories are left untouched.
for _out_dir in (dir_name, f"{dir_name}/spread_{level}/"):
    if not os.path.exists(_out_dir):
        os.makedirs(_out_dir, exist_ok=True)

def _new_run():
    """Build the starting state for one repeat.

    Returns:
        (population, archive, scheduler):
        * population -- (POPSIZE, 512) genomes with values in [0, 5); for "ME"
          this is replaced by the scheduler's first ``ask()`` batch.
        * archive -- an empty 32x32 GridArchive over the [0, 32) x [0, 32)
          measure space.
        * scheduler -- a pyribs Scheduler for "ME", otherwise None.
    """
    pop = np.random.randint(5, size=(POPSIZE, 512))
    archive = GridArchive(
        solution_dim=512,
        dims=[32, 32],
        ranges=[(0, 32), (0, 32)],
    )
    sched = None
    if METHOD == "ME":
        # One single-solution emitter per population slot, each seeded with
        # its own random genome.
        emitters = [
            discreteEmitter(
                archive=archive,
                x0=np.random.randint(5, size=(512)),
                batch_size=1,
                domain="maze",
                bounds=[(0, 5) for _ in range(512)],
            )
            for _ in range(POPSIZE)
        ]
        sched = Scheduler(archive, emitters)
        # The scheduler expects its next call to be tell() for this batch.
        pop = np.array(sched.ask())
    return pop, archive, sched


# Snapshot cadence, in iterations, for the archive heatmap and for the
# archive-CSV / population dumps.
graph_interval = 500
save_interval = 10

all_fits_data = []
all_meas_data = []
all_div_data = []
all_qd_scores = []
for itr in trange(0, REPEATS, desc="repeats", file=sys.stdout):
    # Each repeat starts from a fresh random population and an empty archive
    # (and, for ME, fresh emitters). Building it here instead of at the end of
    # the previous repeat avoids a wasted setup after the final repeat.
    population, result_archive, scheduler = _new_run()

    all_fits = []
    all_meas = []
    all_div = []
    for i in trange(0, ITERS, desc="Iterations", file=sys.stdout):
        obj, measures = get_qd_measures(population)
        objs = get_objectives(population, n=level)
        div = get_pop_diversity(np.array(population))

        newPop = []
        if METHOD == "lex":
            # No scheduler in lex mode: the archive is only a record of every
            # evaluated individual, so insert the batch explicitly.
            result_archive.add(population, obj, measures)
            selected = select_from_scores(np.array(objs), selection="lex", elitism=ELITISM)

            if ELITISM:
                newPop.append(population[selected[0]])
            # Use a distinct loop name: reusing `i` here clobbered the
            # iteration counter used by the graph/save checks below.
            for idx in selected[int(ELITISM):]:  # the elite (if any) is not mutated
                newPop.append(mutate(population[idx]))

        elif METHOD == "ME":
            # Scheduler.tell() inserts the last ask() batch into result_archive
            # itself. Adding the batch beforehand as well double-inserted it,
            # so the add statuses forwarded to the emitters no longer said
            # whether the solutions were new.
            scheduler.tell(obj, measures)
            newPop = scheduler.ask()

        all_fits.append(np.sum(objs, axis=1))
        all_meas.append(measures)
        all_div.append(div)
        population = np.array(newPop)

        if i % graph_interval == 0:
            # Heatmap snapshot of the archive so far.
            plt.figure(figsize=(8, 6))
            grid_archive_heatmap(result_archive)
            plt.title(f"Maze Archive, Iteration {i}, Pop {POPSIZE}, type: {METHOD}, level: {level}")
            plt.savefig(f"{dir_name}/spread_{level}/{itr}_{i//graph_interval}.png")
            plt.clf()
            plt.close()

        if i % save_interval == 0:
            # Save the archive contents (including solutions) to CSV.
            adf = result_archive.as_pandas(include_solutions=True)
            adf.to_csv(f"{dir_name}/spread_{level}/archive_{itr}_{i//save_interval}.csv")

            # NOTE: `population` has already been replaced by the next
            # generation at this point.
            np.save(f"{dir_name}/spread_{level}/population_{itr}_{i//save_interval}.npy", population)

    all_fits = np.array(all_fits)
    all_meas = np.array(all_meas)
    all_fits_data.append(all_fits)
    all_meas_data.append(all_meas)
    all_div_data.append(all_div)

    # Final heatmap of this repeat's archive.
    plt.figure(figsize=(8, 6))
    grid_archive_heatmap(result_archive)
    plt.title(f"Maze Archive, FINAL, Pop {POPSIZE}, type: {METHOD}, level: {level}")
    plt.savefig(f"{dir_name}/spread_{level}/{itr}_final.png")
    plt.clf()
    plt.close()

    # QD score: the archive's qd_score statistic for this repeat.
    qd_score = result_archive.stats.qd_score
    all_qd_scores.append(qd_score)

# Collapse the per-repeat Python lists into NumPy arrays (one leading entry
# per repeat).
all_fits_data, all_div_data, all_meas_data = (
    np.array(all_fits_data),
    np.array(all_div_data),
    np.array(all_meas_data),
)
# Echo the raw per-repeat QD scores (still a list) before converting them.
print(all_qd_scores)
all_qd_scores = np.array(all_qd_scores)

# Persist the aggregated results in the run directory.
# NOTE(review): all_meas_data is converted above but never written out —
# confirm whether it should be saved as well.
for _out_path, _payload in (
    (f"{dir_name}/all_fits_{level}.npy", all_fits_data),
    (f"{dir_name}/all_div_{level}.npy", all_div_data),
    (f"{dir_name}/all_qd_{level}.npy", all_qd_scores),
):
    np.save(_out_path, _payload)


# Per repeat and generation, take the maximum summed objective over the
# population. all_fits_data stacks one (iterations, individuals) array per
# repeat. NOTE(review): np.max assumes larger is better; confirm against the
# "distance to goal" label.
best = np.max(all_fits_data, axis=2)
generations = np.arange(best.shape[1])

# Mean across repeats, with a shaded 5th-95th percentile band.
plt.plot(generations, np.mean(best, axis=0), label="distance to goal")
plt.fill_between(
    generations,
    np.percentile(best, 5, axis=0),
    np.percentile(best, 95, axis=0),
    alpha=0.5,
)
plt.legend()
plt.ylabel("Best True Fitness")
plt.xlabel("Generation")
# Explicit newlines: the previous triple-quoted title leaked the source
# indentation into the rendered title.
plt.title(f"{METHOD}\nperformance of best elite each generation,\nacross many runs")
plt.savefig(f"{dir_name}/all_fits_{level}.png")
plt.clf()
plt.close()


