import json
import numpy as np
import os
import time

from dataclasses import asdict
from mpi4py import MPI
from scipy.interpolate import interp1d

from core.utils.csv_utils import csv_write
from torch.utils.tensorboard import SummaryWriter


# Module-wide MPI context, resolved once at import time and shared by the
# Logger below to decide which process writes logs and how files are named.
comm = MPI.COMM_WORLD
num_workers = comm.Get_size()  # number of processes in COMM_WORLD
rank = comm.Get_rank()  # index of the current process within COMM_WORLD


class Logger:
    """Experiment logger for a quality-diversity training run.

    On MPI rank 0 the logger creates the run directory, dumps the
    configuration, and periodically writes TensorBoard scalars/images and a
    CSV row. Every rank can additionally save its own agent parameters via
    ``save_population_params``.

    Only rank 0 owns ``self.writer``; the methods that use it are guarded by
    the rank check in ``log``.
    """

    def __init__(self, trainer, config):
        """Store the run settings and, on rank 0, prepare the output directory.

        Args:
            trainer: Object whose ``qd_strategy``, ``config`` and
                ``generation_duration`` attributes are read when logging.
            config: Dataclass instance (it is passed to
                ``dataclasses.asdict``) holding the experiment configuration.

        Raises:
            FileExistsError: On rank 0, if the run directory already exists.
        """
        self.start_time = time.time()
        self.trainer = trainer
        self.config = config

        # Get params
        xp_params = config.xp_params
        self.dir = xp_params.dir
        self.save_models = xp_params.save_models
        self.save_models_every = xp_params.save_models_every
        self.log_freq = xp_params.log_freq
        self.evolution_strategy_name = config.evolution_strategy_name

        agent_name = config.agent_name + "_" + config.name
        if rank == 0:
            # Create directories (this also creates self.writer)
            self._create_dirs(config.env_name, config.qd_strategy_name, agent_name,
                              xp_params.seed, xp_params.save_models)
            # Serialise the config once: dump it to disk and to TensorBoard
            config_str = json.dumps(asdict(config))
            with open(self.agent_dir + "/" + "config.json", "w") as cfg:
                cfg.write(config_str)
            self.writer.add_text('config', config_str, 0)

        # Paths are set on every rank (each rank may save its own agent).
        # From here on self.agent_dir carries a trailing "/".
        run_dir = self._run_dir(config.env_name, config.qd_strategy_name, agent_name, xp_params.seed)
        self.save_models_path = run_dir + "/models"
        self.agent_dir = run_dir + "/"

        self.list_log_behavior = []

        self.max_fitnesses = []

        if rank == 0:
            self.print_infos(config)

    def _run_dir(self, env_name, qd_strategy_name, agent_name, seed):
        """Return ``<dir>/out/<env>/<qd_strategy>/<agent>/run_<seed>`` (no trailing slash)."""
        return "/".join([
            self.dir + "/out",
            env_name,
            qd_strategy_name.lower(),
            agent_name.lower(),
            "run_{}".format(seed),
        ])

    def print_infos(self, config):
        """Print a short summary of the experiment configuration to stdout."""
        separator = "---------------------------------"
        xp_params = config.xp_params
        qd_params = config.qd_strategy_params
        lines = [
            separator,
            "Env: {}".format(config.env_name),
            "Agent: {}".format(config.agent_name),
            "Selection strategy: {}".format(config.evolution_strategy_name),
            "Novelty method: {}".format(config.novelty_method_params.name),
            "Behavior extraction method: {}".format(config.behavior_extraction_method),
            "Nb workers: {}".format(xp_params.num_workers),
            "Novelty size: {}".format(xp_params.novelty_size),
            "Seed: {}".format(xp_params.seed),
            separator,
            "lr: {}".format(qd_params.quality_lr),
            "layers_dim: {}".format(config.rl_agent_params.layers_dim),
            "gradient_steps_ratio: {}".format(qd_params.gradient_steps_ratio),
            separator,
        ]
        print("\n".join(lines), flush=True)

    def _create_dirs(self, env_name, qd_strategy_name, agent_name, seed, save_models):
        """Create the run directory tree and the TensorBoard writer.

        Sets ``self.agent_dir`` (without trailing slash) and ``self.writer``.

        Raises:
            FileExistsError: If the run directory already exists, so an
                earlier run with the same seed is never silently overwritten.
        """
        self.agent_dir = self._run_dir(env_name, qd_strategy_name, agent_name, seed)

        # Parent directories may already exist (other seeds / agents);
        # exist_ok avoids the check-then-create race of os.path.exists.
        os.makedirs(os.path.dirname(self.agent_dir), exist_ok=True)
        # The run directory itself must be new: no exist_ok on purpose.
        os.makedirs(self.agent_dir)

        self.writer = SummaryWriter(self.agent_dir)

        if save_models:
            os.makedirs(self.agent_dir + "/models", exist_ok=True)

    @staticmethod
    def _cell_is_filled(cell, archive_kind):
        """Tell whether an archive cell holds at least one individual.

        "deep_map_elites" cells are lists (empty list == empty cell); for any
        other archive an empty cell is ``None``.
        """
        if archive_kind == "deep_map_elites":
            return cell is not None and cell != []
        return cell is not None

    def _map_elites_log(self, steps):
        """Log MAP-Elites specific data: behaviors, coverage and grid parameters.

        Args:
            steps: Global step used as the TensorBoard x-axis value.
        """
        qd = self.trainer.qd_strategy
        cells = qd.archive.cells
        archive_kind = self.config.archive_params.archive

        # log behavior, gen, steps in one array (whole history is re-saved each time)
        self.list_log_behavior.append(qd.log_behavior())
        np.save(self.agent_dir + "behaviors", np.vstack(self.list_log_behavior))
        self.writer.add_scalar("map_elites_fitness_evaluation", np.mean(qd.map_elites_eval_fitness), global_step=steps)

        # Behavior space coverage, using the emptiness test matching the archive type
        filled = [self._cell_is_filled(cell, archive_kind) for cell in cells]
        percentage_filled_grid = sum(filled) / len(cells)
        if archive_kind in ("map_elites", "deep_map_elites"):
            self.writer.add_scalar("map_elites_percentage_filled", percentage_filled_grid, global_step=steps)
        self.writer.add_scalar("grid_percentage_filled", percentage_filled_grid, global_step=steps)

        # Coverage image: red channel is 1 where the cell is filled.
        # Cells are read row-major as a (nb x nb) grid; extra cells are ignored.
        nb_cells_per_dimension = self.trainer.config.map_elites_params.nb_cells_per_dimension
        nb_grid_cells = nb_cells_per_dimension * nb_cells_per_dimension
        img = np.zeros((3, nb_cells_per_dimension, nb_cells_per_dimension))
        img[0] = np.asarray(filled[:nb_grid_cells], dtype=float).reshape(
            nb_cells_per_dimension, nb_cells_per_dimension)
        self.writer.add_image('behavior_space_coverage', img, steps)

        # Save grid params
        # NOTE(review): save_model_steps is also reset by save_population_params;
        # whichever runs first consumes the trigger. Confirm the call order in the trainer.
        if self.save_models and qd.save_model_steps >= self.save_models_every:
            container_params = qd.archive.container["params"]
            # NOTE(review): assumes each non-empty cell is a single index into the
            # container; "deep_map_elites" cells are lists. Verify before relying on it.
            grid_params = [container_params[ind] if ind is not None else None for ind in cells]
            # Mixing None with parameter vectors is a ragged input, which recent
            # NumPy only accepts with an explicit object dtype.
            dtype = object if any(params is None for params in grid_params) else None
            grid_params = np.array(grid_params, dtype=dtype)
            np.save(self.save_models_path + "/grid_params_steps_{}".format(qd.total_steps), grid_params)
            qd.save_model_steps = 0

    def log(self):
        """Write one log entry (TensorBoard + CSV) if enough steps have elapsed.

        Does nothing on ranks other than 0 or while
        ``qd_strategy.log_steps < log_freq``. Resets ``log_steps`` afterwards.
        """
        if rank != 0:
            return
        qd = self.trainer.qd_strategy
        if qd.log_steps < self.log_freq:
            return

        # Get various logs
        steps = qd.total_steps
        if "mapelites" in self.evolution_strategy_name.lower():
            self._map_elites_log(steps)
        curr_time = round(time.time() - self.start_time, 2)
        fitness_pop_novelty, fitness_pop_quality, fitness_pop = qd.log_fitness()
        novelty_score = qd.novelty_score
        len_archive = len(qd.archive.container["fitnesses"])
        # Number of individuals evaluated so far; guard the ratio against gen == 0
        nb_evaluated = num_workers * qd.gen

        logs = {
            "time": curr_time,
            "steps": steps,
            "fitness_pop_avg": np.mean(fitness_pop),
            "fitness_pop_max": np.max(fitness_pop),
            "fitness_pop_min": np.min(fitness_pop),
            "fitness_pop_quality_avg": np.mean(fitness_pop_quality),
            "fitness_pop_novelty_avg": np.mean(fitness_pop_novelty),
            "novelty_score_avg": np.mean(novelty_score),
            "novelty_score_min": np.min(novelty_score),
            "novelty_score_max": np.max(novelty_score),
            "archive_length": len_archive,
            "archive_percentage_admission": len_archive / nb_evaluated if nb_evaluated else 0.0,
            "generation_duration": self.trainer.generation_duration,
        }

        # Paper metric: best fitness seen over all log calls so far
        self.max_fitnesses.append(logs["fitness_pop_max"])

        # Tensorboard
        scalars = [
            ("fitness/max_fitness_over_whole_run", np.max(self.max_fitnesses)),
            ("fitness/fitness_pop_avg", logs["fitness_pop_avg"]),
            ("fitness/fitness_pop_min", logs["fitness_pop_min"]),
            ("fitness/fitness_pop_max", logs["fitness_pop_max"]),
            ("fitness/fitness_pop_quality_avg", logs["fitness_pop_quality_avg"]),
            # Novelty gets its own tag instead of overwriting the quality series
            ("fitness/fitness_pop_novelty_avg", logs["fitness_pop_novelty_avg"]),
            ("novelty/novelty_score_avg", logs["novelty_score_avg"]),
            ("novelty/novelty_score_min", logs["novelty_score_min"]),
            ("novelty/novelty_score_max", logs["novelty_score_max"]),
            ("archive/archive_length", logs["archive_length"]),
            ("archive/archive_percentage_admission", logs["archive_percentage_admission"]),
            ("buffer/length_buffer", len(qd.buffer.storage)),
            ("computation/generation_duration", logs["generation_duration"]),
        ]
        if self.config.novelty_method_params.name == "euclidian_distance":
            novelty_method = qd.novelty_method
            scalars.append(("novelty_archive/percentage_added_behaviors", novelty_method.percentage_added_behaviors))
            scalars.append(("novelty_archive/length", len(novelty_method.container["behaviors"])))
        for tag, value in scalars:
            self.writer.add_scalar(tag, value, steps)

        # CSV (each field listed once; novelty average included)
        csv_fields = ["time", "steps", "fitness_pop_avg", "fitness_pop_max", "fitness_pop_quality_avg",
                      "fitness_pop_novelty_avg", "novelty_score_avg", "archive_length",
                      "archive_percentage_admission"]
        # self.agent_dir already ends with "/"
        csv_write(path_to_file=self.agent_dir + "logs.csv", logs_dict=logs, fields=csv_fields)

        # Reset the counter
        qd.log_steps = 0

    def save_population_params(self):
        """Save this rank's agent if model saving is enabled and due.

        The file is named after the MPI rank and the current total step
        count; ``save_model_steps`` is reset after saving.
        """
        qd = self.trainer.qd_strategy
        if not self.save_models or qd.save_model_steps < self.save_models_every:
            return
        file_name = "agent_{}_steps_{}".format(rank, qd.total_steps)
        qd.agent.save(file_name, self.save_models_path)
        qd.save_model_steps = 0