import numpy as np
import argparse
#import torch

from ribs.archives import GridArchive, CVTArchive
from ribs.visualize import grid_archive_heatmap
from discreteEmitter import discreteEmitter
from ribs.schedulers import Scheduler

import matplotlib.pyplot as plt

from knights_domain import get_hand_metrics_and_obs, get_pop_diversity, get_subaggregation, mutate
from lexicase import select_from_scores

import seaborn as sns

import sys
import os

from tqdm import trange

def _parse_bool(value):
    """Convert a command-line string into a real boolean.

    ``type=bool`` does not work with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so ``--shuffle False`` used to
    switch shuffling on.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in {"true", "t", "yes", "y", "1"}:
        return True
    # the empty string stays false, which is what bool("") already produced
    if text in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean (true/false), got {value!r}")


parser = argparse.ArgumentParser()
# required options carry no default: argparse never falls back to one for them
parser.add_argument("--level", required=True, type=int,
                    help="level handed to get_subaggregation; the script accepts 64, 16, 4 or 1")
parser.add_argument("--popsize", required=True, type=int,
                    help="number of individuals in the population")
parser.add_argument("--iters", required=True, type=int,
                    help="iterations per repeat")
parser.add_argument("--repeats", required=True, type=int,
                    help="number of independent repeats")

#shuffle reorders the scores randomly so they don't correspond correctly for lexicase
parser.add_argument("--shuffle", default=False, required=False, type=_parse_bool,
                    help="true/false: randomly permute each individual's scores")
parser.add_argument("--name", default="", required=False,
                    help="label used in the output directory name")
parser.add_argument("--elitism", required=False, default=False, type=_parse_bool,
                    help="true/false: forwarded to select_from_scores; the first selected individual is kept unmutated")
parser.add_argument("--method", required=True, type=str,
                    help='search method; the script handles "lex" and "ME"')
parser.add_argument("--deagg", required=False, default="space", type=str,
                    help="deaggregation mode handed to get_subaggregation")

# read the command line once and copy the settings into module-level names
args = parser.parse_args()
N = 8  # not referenced in the visible part of this script
POPSIZE = args.popsize
ITERS = args.iters
REPEATS = args.repeats
ELITISM = args.elitism
print(f"ELITISM {ELITISM}")
level = args.level
METHOD = args.method
deagg = args.deagg

# reject an unsupported level before anything is written to disk
if level not in (64, 16, 4, 1):
    raise ValueError(f"ERROR: LEVEL NOT VALID: got {level}, expected one of 64, 16, 4, 1")

# only lexicase runs carry the deaggregation mode in their directory name
if METHOD == "lex":
    dir_name = f"knight_data/{str(ITERS)}_{str(POPSIZE)}/{args.name}_{str(deagg)}"
else:
    dir_name = f"knight_data/{str(ITERS)}_{str(POPSIZE)}/{args.name}"

# makedirs builds the missing parents (dir_name included) and exist_ok makes
# the call a no-op when the directories are already there
os.makedirs(f"{dir_name}/spread_{level}/", exist_ok=True)

# starting population: POPSIZE rows of 66 integers drawn from 0..7
population = np.random.randint(0, 8, size=(POPSIZE, 66))

def _make_result_archive():
    """Return an empty 8x8 grid archive for 66-dimensional solutions."""
    return GridArchive(
        solution_dim=66,
        dims=[8, 8],
        ranges=[(0, 8), (0, 8)],
    )


def _make_scheduler(archive):
    """Return a scheduler driving POPSIZE single-solution emitters on *archive*."""
    emitters = [
        discreteEmitter(
            archive=archive,
            x0=np.random.randint(8, size=(66)),
            domain="knights",
            batch_size=1,
        )
        for _ in range(POPSIZE)
    ]
    return Scheduler(archive, emitters)


def _save_archive_heatmap(archive, title, path):
    """Draw *archive* as a heatmap titled *title*, write it to *path*, release the figure."""
    plt.figure(figsize=(8, 6))
    grid_archive_heatmap(archive)
    plt.title(title)
    plt.savefig(path)
    plt.clf()
    # without close() every repeat would leave one more open figure behind
    plt.close()


all_fits_data = []
all_meas_data = []
all_div_data = []
all_qd_scores = []

# loop invariants: how often (in iterations) a heatmap is drawn / a checkpoint is written
graph_interval = 500
save_interval = 10
# a single generator serves every shuffle; it is only used when --shuffle is set
rng = np.random.default_rng()
spread_dir = f"{dir_name}/spread_{level}"

for itr in trange(0, REPEATS, desc="repeats", file=sys.stdout):
    # every repeat starts from an empty archive and a fresh population
    result_archive = _make_result_archive()
    if METHOD == "ME":
        scheduler = _make_scheduler(result_archive)
        # tell() scores the batch handed out by the preceding ask(), so the
        # population evaluated below has to be exactly that batch
        population = scheduler.ask().astype(int)
    else:
        population = np.random.randint(0, 8, size=(POPSIZE, 66))

    all_fits = []
    all_meas = []
    all_div = []
    for iterations in trange(0, ITERS, desc="Iterations", file=sys.stdout):

        all_scores = []
        all_visiteds = []
        fits = []
        meas = []
        endxy = []
        for individual in population:
            # only the first three returned values are used here
            measures, obj, all_visited, *_ = get_hand_metrics_and_obs(individual)
            all_visiteds.append(all_visited)

            scores = np.array(get_subaggregation(all_visited.reshape((8, 8)), level, deagg))

            #shuffling breaks the correspondence between score positions across individuals
            if args.shuffle:
                scores = rng.permuted(scores, axis=0)

            all_scores.append(scores)
            fits.append(obj)
            meas.append(np.sum(measures))
            endxy.append(measures[6:]) #this extracts the end x and y position, for use in the result archive

        div = get_pop_diversity(np.array(all_visiteds))
        all_scores = np.array(all_scores)

        # NOTE(review): in ME mode the scheduler was built on this same archive,
        # so scheduler.tell below presumably inserts the batch a second time;
        # confirm the explicit add is wanted for that method
        result_archive.add(population, fits, endxy)

        newPop = []

        if METHOD == "lex":
            selected = select_from_scores(all_scores, selection="lex", elitism=ELITISM)

            if ELITISM:
                newPop.append(population[selected[0]])
            for i in selected[int(ELITISM):]: #don't mutate elitism
                newPop.append(mutate(population[i]))

        elif METHOD == "ME":
            scheduler.tell(fits, endxy)
            newPop = scheduler.ask().astype(int)

        all_fits.append(fits)
        all_meas.append(meas)
        all_div.append(div)
        population = np.array(newPop)

        if iterations % graph_interval == graph_interval-1:
            #graph the result archive
            _save_archive_heatmap(
                result_archive,
                f"Knight's Archive, Iteration {iterations}, Pop {POPSIZE}, type: {METHOD}, level: {level}",
                f"{spread_dir}/{itr}_{iterations//graph_interval}.png",
            )

        if iterations % save_interval == save_interval-1:
            #save the archive to a file
            print(f"WE ARE SAVING: i={iterations}\n\n\n save_interval={save_interval}")
            print(f"NAME WILL BE {iterations//save_interval}")
            adf = result_archive.as_pandas(include_solutions=True)
            adf.to_csv(f"{spread_dir}/archive_{itr}_{iterations//save_interval}.csv")

            #save the population to a file
            np.save(f"{spread_dir}/population_{itr}_{iterations//save_interval}.npy", population)

    all_fits_data.append(np.array(all_fits))
    all_meas_data.append(np.array(all_meas))
    all_div_data.append(all_div)

    #graph the result archive
    _save_archive_heatmap(
        result_archive,
        f"Knight Archive, FINAL, Pop {POPSIZE}, type: {METHOD}, level: {level}",
        f"{spread_dir}/{itr}_final.png",
    )

    #get QD score == sum of all scores in the archive
    all_qd_scores.append(result_archive.stats.qd_score)

# show the raw per-repeat QD scores before they are turned into an array
print(all_qd_scores)

# freeze the per-repeat histories into arrays
all_fits_data = np.array(all_fits_data)
all_div_data = np.array(all_div_data)
all_meas_data = np.array(all_meas_data)
all_qd_scores = np.array(all_qd_scores)

# optional extra curves for the summed measures; switched off
PLOT_MEASURES = False

if PLOT_MEASURES:
    # best summed measure in the population, per repeat and iteration
    peak_meas = np.max(all_meas_data, axis=2)
    plt.plot(np.mean(peak_meas, axis=0), label="Avg Sum of measures")
    plt.plot(np.min(peak_meas, axis=0), label="Worst Sum of measures")
    plt.plot(np.max(peak_meas, axis=0), label="Best Sum of measures")

# raw results go next to the per-level output directory
np.save(f"{dir_name}/all_fits_{level}.npy", all_fits_data)
np.save(f"{dir_name}/all_div_{level}.npy", all_div_data)
np.save(f"{dir_name}/all_qd_{level}.npy", all_qd_scores)

# best fitness in the population, per repeat and iteration
best = np.max(all_fits_data, axis=2)
mean_best = np.mean(best, axis=0)
band_low = np.percentile(best, 5, axis=0)
band_high = np.percentile(best, 95, axis=0)

# mean across repeats with a 5th-95th percentile band
plt.plot(mean_best, label="Tiles Covered")
plt.fill_between(range(ITERS), band_low, band_high, alpha=0.5)
plt.title(f"""{METHOD}
          performance of best elite each generation,
          across many runs""")
plt.xlabel("Generation")
plt.ylabel("Best True Fitness")
plt.legend()
plt.savefig(f"{dir_name}/all_fits_{level}.png")
plt.clf()


