import os
import shutil
import subprocess
import copy
import tempfile
import numpy as np
import torch
import torch_geometric
from optimal_agents.morphology import Morphology
from optimal_agents.algs.ea_base import EvoAlg, Individual
from optimal_agents.utils.loader import get_env, load_from_name
from optimal_agents.policies.pruning_models import NodeMorphologyVF

import random

class BasicEA(EvoAlg):
    """Evolutionary algorithm that trains each individual's policy in a subprocess.

    Individuals are trained in batches of ``num_cores // cpus_per_ind``
    concurrent jobs. Each job is pinned to its own CPUs with ``taskset`` and
    runs ``ea_subproc.py``, located next to this file. A job's results are
    read back from ``<self.path>/gen_<g>/ind_<index>/fitness.tmp``.
    """

    def __init__(self, *args, num_cores=4, cpus_per_ind=1, save_freq=1, nge_mutation=False, mutate_structure_freq=1,
                    log_smoothing=False, eval_ep=10, **kwargs):
        """
        Args:
            num_cores: number of CPUs shared by concurrently running training jobs.
            cpus_per_ind: number of CPUs pinned to each training job.
            save_freq: used by ``_clean``. Generation 0 and every generation g
                with ``(g + 1) % save_freq == 0`` keep their folder on disk.
            nge_mutation: if True, mutate with the morphology's ``mutate_nge``
                instead of ``mutate``.
            mutate_structure_freq: limb mutation is enabled only after training
                generations whose index is a multiple of this value.
            log_smoothing: forwarded to the subprocess as ``--log-smoothing``.
            eval_ep: forwarded to the subprocess as ``--eval-ep``.
            *args, **kwargs: forwarded to ``EvoAlg.__init__``.
        """
        super(BasicEA, self).__init__(*args, **kwargs)
        # Save the extra args for basic EA
        self.num_cores = num_cores
        self.mutate_structure = True
        self.mutate_structure_freq = mutate_structure_freq
        self.cpus_per_ind = cpus_per_ind
        self.save_freq = save_freq  # consumed by _clean
        self.nge_mutation = nge_mutation
        self.eval_ep = eval_ep
        self.log_smoothing = log_smoothing

    def _mutate(self, individual):
        """Return a new Individual built from a mutation of ``individual``'s morphology.

        Limb mutation is requested only when ``self.mutate_structure`` is True.
        Other mutation settings come from ``self.mutation_kwargs``.
        """
        morphology = individual.morphology
        mutate_fn = morphology.mutate_nge if self.nge_mutation else morphology.mutate
        return Individual(mutate_fn(**self.mutation_kwargs, mutate_limbs=self.mutate_structure))

    def _train_policies(self, gen_idx):
        """Train every individual in the population for generation ``gen_idx``.

        Jobs are launched in batches of ``num_cores // cpus_per_ind``. Each
        batch finishes before the next one starts. Afterwards, each job's
        fitness is read from disk and applied with ``Individual.update``.

        Raises:
            ValueError: if ``cpus_per_ind`` exceeds ``num_cores`` (no job fits).
        """
        batch_size = self.num_cores // self.cpus_per_ind
        if batch_size < 1:
            raise ValueError("cpus_per_ind (%d) must not exceed num_cores (%d)"
                             % (self.cpus_per_ind, self.num_cores))
        run_path = os.path.join(os.path.dirname(__file__), 'ea_subproc.py')

        for i in range(0, len(self.population), batch_size):
            jobs = []
            temp_paths = []
            for j, individual in enumerate(self.population[i:i + batch_size]):
                # Job j of the batch gets the CPU range [j*cpus_per_ind, (j+1)*cpus_per_ind).
                first_cpu = self.cpus_per_ind * j
                cpus = ",".join(str(cpu) for cpu in range(first_cpu, first_cpu + self.cpus_per_ind))
                proc, paths = self._spawn_training_job(individual, gen_idx, cpus, run_path)
                jobs.append((individual, proc))
                temp_paths.extend(paths)

            print("Waiting for completion of ", len(jobs), "training jobs.")
            for individual, proc in jobs:
                returncode = proc.wait()
                if returncode != 0:
                    print("Warning: training job for individual", individual.index,
                          "exited with code", returncode)

            # The jobs have exited, so their temporary input files are no longer needed.
            for temp_path in temp_paths:
                try:
                    os.remove(temp_path)
                except OSError:
                    pass  # best-effort cleanup of temp files

        self._load_generation_results(gen_idx)

        # Structure (limb) mutation is allowed only after every
        # mutate_structure_freq-th generation; the flag is cleared otherwise.
        self.mutate_structure = gen_idx % self.mutate_structure_freq == 0

    def _spawn_training_job(self, individual, gen_idx, cpus, run_path):
        """Write one individual's inputs to temp files and launch its training subprocess.

        Args:
            individual: population member to train.
            gen_idx: index of the current generation.
            cpus: comma-separated CPU ids passed to ``taskset -c``.
            run_path: path of the training script.

        Returns:
            ``(process, temp_paths)``: the ``subprocess.Popen`` handle and the
            temporary params/morphology files the process reads.
        """
        params_fd, individual_params_path = tempfile.mkstemp(text=True, prefix='params', suffix='.json')
        os.close(params_fd)  # mkstemp returns an open descriptor; only the path is used.
        # NOTE(review): copy.copy is shallow. Assigning the keys below leaves
        # self.params untouched only if copy.copy duplicates its item storage
        # (e.g. a dict subclass) -- confirm against the params class.
        individual_params = copy.copy(self.params)
        if gen_idx == 0 and not self.retrain:
            individual_params['timesteps'] *= 2  # Train double length on gen 0
        individual_params['name'] = os.path.join("gen_" + str(gen_idx), "ind_" + str(individual.index))
        individual_params.save(individual_params_path)

        morphology_fd, morphology_path = tempfile.mkstemp(text=False, prefix='morphology', suffix='.pkl')
        os.close(morphology_fd)
        individual.morphology.save(morphology_path)

        cmd_args = ["taskset", "-c", cpus,
                    "python", run_path, "--params", individual_params_path, "--base-path", self.path,
                    '--eval-ep', str(self.eval_ep),
                    "--morphology", morphology_path]
        if self.log_smoothing:
            cmd_args.append("--log-smoothing")
        # Continue from the previous policy unless retraining from scratch.
        if individual.model is not None and not self.retrain:
            cmd_args.extend(["--model-path", individual.model])
        cmd_args.extend(self._get_pruning_cmd_args(individual))

        return subprocess.Popen(cmd_args), [individual_params_path, morphology_path]

    def _load_generation_results(self, gen_idx):
        """Read each ``ind_<index>/fitness.tmp`` under ``gen_<gen_idx>`` and update that individual.

        Individuals are matched by their ``index`` attribute, which is the
        same value used to name their result folder.
        """
        generation_folder = os.path.join(self.path, 'gen_' + str(gen_idx))
        by_index = {individual.index: individual for individual in self.population}
        num_updates = 0
        for output in os.listdir(generation_folder):
            model_path = os.path.join(generation_folder, output)
            # Only per-individual result folders ("ind_<index>") carry a fitness.
            if not (output.startswith('ind_') and os.path.isdir(model_path)):
                continue
            idx = int(output.split('_')[1])  # The second item in the name gives the index.
            with open(os.path.join(model_path, 'fitness.tmp'), 'r') as f:
                fitness = float(f.read())
            by_index[idx].update(fitness, model_path)
            num_updates += 1
        assert num_updates == len(self.population), "Did not update once per individual in population"

    def _get_pruning_cmd_args(self, individual):
        """Return extra command-line arguments for ``individual``'s training job.

        Returns none here; presumably overridden by subclasses that need
        extra arguments.
        """
        return []

    def _clean(self, gen_idx):
        """Delete the on-disk folder of generation ``gen_idx - 1`` unless it should be kept.

        The current generation is never removed because the next generation
        references its policies. Generation 0 and every generation ``g`` with
        ``(g + 1) % save_freq == 0`` are kept. A folder that is already
        missing is ignored.
        """
        gen_idx_to_remove = gen_idx - 1
        if gen_idx_to_remove < 1 or (gen_idx_to_remove + 1) % self.save_freq == 0:
            return  # that generation is one we keep
        remove_gen_folder = os.path.join(self.path, 'gen_' + str(gen_idx_to_remove))
        if os.path.isdir(remove_gen_folder):
            shutil.rmtree(remove_gen_folder)  # Delete the entire folder, cleaning up gb of data
