from dataclasses import dataclass
import os
import numpy as np
import itertools
import torch
import torch_geometric

from optimal_agents.morphology import Morphology
from optimal_agents.utils.loader import get_env, get_morphology
from optimal_agents.utils.trainer import run_train
from optimal_agents.utils.tester import eval_policy
from optimal_agents.policies import random_policies
from optimal_agents.policies import predictive_models

import random

@dataclass
class Individual:
    """Record for one evaluated morphology stored in the search population."""
    # The candidate morphology that was evaluated.
    morphology: Morphology
    # Score used to rank the population; add_morphology sets it to the
    # variance (np.var) of the states collected from its rollouts.
    fitness: float
    # add_morphology always passes 0 for both of these; their intended
    # meaning is not visible in this file.
    start_index: int
    end_index: int
    # Size of the population at the moment this individual was appended.
    index: int

class VarEA_Base(object):
    """Evolutionary morphology search scored by end-state variance.

    Each candidate morphology is driven by open-loop actions from a random
    policy for ``eval_ep`` episodes.  The state returned by ``_get_state`` at
    the end of every episode is collected and the variance over all collected
    values (``np.var``) becomes the candidate's fitness.  Subclasses supply
    ``_get_action_dim`` and ``_get_state``.
    """

    def __init__(self, params, eval_ep=8, nge_mutation=False, save_freq=10,
                       global_state=False, num_freqs=2, num_phases=2, sample_freq=50,
                       keep_percent=0.0, random_policy="CosinePolicy", include_end=False, include_start_state=False):
        """Store the search configuration and build the random policy.

        Args:
            params: Experiment parameters.  Indexed by key here ('env',
                'mutation_args', 'env_wrapper', 'env_wrapper_args', ...) and
                saved with ``params.save``.  NOTE: it is mutated in place.
            eval_ep: Number of rollouts used to score one morphology.
            nge_mutation: If true, children come from ``mutate_nge`` instead
                of ``mutate``.
            save_freq: Morphologies are written to disk every ``save_freq``
                generations.
            global_state: Read by the subclasses' ``_get_state``.
            num_freqs, num_phases, sample_freq: Forwarded to the random
                policy constructor.
            keep_percent: Fraction of the sorted population that parents are
                sampled from (see ``learn``).
            random_policy: Name of the policy class to look up in the
                ``random_policies`` module.
            include_end: Read by the subclasses' ``_get_state``.
            include_start_state: If true, the state right after reset is
                concatenated with the final state of each rollout.
        """
        # Save the model parameters and mutation parameters
        self.params = params
        self.mutation_kwargs = params['mutation_args']
        self.nge_mutation = nge_mutation
        self.keep_percent = keep_percent

        # Save information for data collection
        self.eval_ep = eval_ep
        self.save_freq = save_freq
        self.global_state = global_state
        self.include_start_state = include_start_state
        self.include_end = include_end

        # Construct the Random Policy
        self.max_nodes = self.mutation_kwargs['max_nodes'] if 'max_nodes' in self.mutation_kwargs else 12
        random_policy_kwargs = {'sample_freq': sample_freq, 'num_freqs' : num_freqs, 'num_phases' : num_phases}
        random_policy_cls = vars(random_policies)[random_policy]
        self.policy = random_policy_cls(self.max_nodes, **random_policy_kwargs)

        # Save data for policy training.  A space separated 'env' entry names
        # several environments: the first is used during the search, all of
        # them are trained on at the end of learn().
        if ' ' in self.params['env']:
            self.eval_envs = self.params['env'].split(' ')
            self.params['env'] = self.eval_envs[0]
        else:
            self.eval_envs = [self.params['env']]

        # The env wrapper is only wanted for policy learning, so it is
        # stashed here and removed from params until the end of learn().
        self.policy_learning_env_wrapper = self.params['env_wrapper']
        self.policy_learning_env_wrapper_args = self.params['env_wrapper_args']
        self.params['env_wrapper'] = None
        self.params['env_wrapper_args'] = dict()

        self.buffer = []
        self.eval_buffer = []
        self.population = []

    def learn(self, path, population_size, num_generations):
        """Run the search, then train a policy on the best morphology.

        Writes into ``path``: one ``gen<idx>.txt`` per generation (fitness and
        index of every individual), a ``gen_<idx>`` directory with params and
        pickled morphologies every ``save_freq`` generations, ``best.gif``
        for the best morphology under the random policy, and one
        ``best_trained_<env>.gif`` per evaluation environment.

        Args:
            path: Output directory, created if missing.
            population_size: Number of children produced per generation; twice
                this many random morphologies are tried for the initial
                population.
            num_generations: Number of generations to run.
        """
        # Generate the initial population
        os.makedirs(path, exist_ok=True)

        for _ in range(population_size*2):
            self.add_morphology(get_morphology(self.params))

        for gen_idx in range(num_generations):
            self.population.sort(key=lambda individual: individual.fitness, reverse=True)

            with open(os.path.join(path, "gen" + str(gen_idx) + ".txt"), "w+") as f:
                for individual in self.population:
                    f.write(str(individual.fitness) + " " + str(individual.index) + "\n")

            if (gen_idx + 1) % self.save_freq == 0:
                gen_path = os.path.join(path, 'gen_' + str(gen_idx))
                os.makedirs(gen_path, exist_ok=True)
                self.params.save(gen_path)
                for i, individual in enumerate(self.population):
                    individual.morphology.save(os.path.join(gen_path, str(i) + '.morphology.pkl'))

            # If were at the end, exit so we don't add extra morphologies.
            if gen_idx + 1 == num_generations:
                break

            # Take population size samples from the population to construct new morphologies.
            for i in range(population_size):
                # Recomputed every iteration: add_morphology below grows the
                # population while this loop runs.
                num_keep = int(self.keep_percent*len(self.population))
                if i >= 3 and self.keep_percent > 0 and num_keep > 5:
                    # Sample a parent uniformly from the top slice.  New
                    # children are appended at the end, so the front of the
                    # list is still in sorted order.
                    morphology = random.choice(self.population[:num_keep]).morphology
                else:
                    # NOTE(review): assumes at least population_size
                    # individuals survived evaluation; otherwise this index
                    # can be out of range -- confirm.
                    morphology = self.population[i].morphology # sample a new morphology according to fitness level
                if self.nge_mutation:
                    new_morphology = morphology.mutate_nge(**self.mutation_kwargs)
                else:
                    new_morphology = morphology.mutate(**self.mutation_kwargs)

                self.add_morphology(new_morphology)

            print("Finished Gen", gen_idx)

        assert self.population[0].fitness == max(individual.fitness for individual in self.population), "Error, didn't get max fitness individual"

        # Record the best morphology acting under the random policy.
        morphology = self.population[0].morphology
        env = get_env(self.params, morphology=morphology)
        env.reset()
        frames = []
        action_dim = self._get_action_dim(morphology)
        actions, _ = self.policy.step(400)
        action_idx, done = 0, False
        # Stop at the end of the action table as well: an episode that
        # outlasts it used to raise IndexError here, after the whole search
        # had already finished.
        while not done and action_idx < len(actions):
            frames.append(env.render(mode='rgb_array'))
            ac = actions[action_idx, :action_dim]
            action_idx += 1
            _, _, done, _ = env.step(ac)

        import imageio
        imageio.mimsave(os.path.join(path, 'best.gif'), frames[::3], subrectangles=True, duration=0.05)
        del env

        # Now, train a policy on the given task.
        # Get the best morphology
        del self.params['env_args']['time_limit']
        if self.params['arena'] is not None and 'Terrain' in self.params['arena']:
            self.params['arena'] = None

        # Restore the wrapper that __init__ removed for the search phase.
        self.params['env_wrapper'] = self.policy_learning_env_wrapper
        self.params['env_wrapper_args'] = self.policy_learning_env_wrapper_args
        for env_name in self.eval_envs:
            self.params['env'] = env_name
            model, _ = run_train(self.params, morphology=morphology, path=path)
            env = get_env(self.params, morphology=morphology)
            _, frames = eval_policy(model, env, num_ep=1, deterministic=True, verbose=1, gif=True, render=True)
            imageio.mimsave(os.path.join(path, 'best_trained_' + env_name + '.gif'), frames[::3], subrectangles=True, duration=0.05)

    def add_morphology(self, morphology):
        """Evaluate ``morphology`` and append it to the population if valid.

        Runs ``eval_ep`` rollouts with freshly sampled random-policy actions
        and collects the state from ``_get_state`` at the end of each one.
        The candidate is dropped without being added when element ``[0, 2]``
        of a collected state exceeds 8.0, or when any error occurs during
        the rollouts.

        Args:
            morphology: Candidate morphology to evaluate.
        """
        eval_env = get_env(self.params, morphology=morphology)

        states = []
        action_dim = self._get_action_dim(morphology)
        try:
            for _ in range(self.eval_ep):
                actions, _ = self.policy.step(500) # Arbitrarily set number of actions.
                done = False
                action_idx = 0
                eval_env.reset()
                if self.include_start_state:
                    start_state = self._get_state(eval_env)
                while not done:
                    _, _, done, _ = eval_env.step(actions[action_idx, :action_dim])
                    action_idx += 1
                # Get the final state and morphology as data
                if self.include_start_state:
                    states.append(np.concatenate((start_state, self._get_state(eval_env)), axis=1))
                else:
                    states.append(self._get_state(eval_env))
                if states[-1][0,2] > 8.0: # its probably flying do not include.
                    return
            fitness = np.var(states)
            individual = Individual(morphology, fitness, 0, 0, len(self.population))
            self.population.append(individual)
        except Exception:
            # Best effort: a candidate whose rollout fails (e.g. running past
            # the end of the action table) is simply discarded.  Only
            # Exception is caught so that KeyboardInterrupt and SystemExit
            # still stop the run.
            return

    def _get_action_dim(self, morphology):
        """Return how many action columns to feed the env for ``morphology``."""
        raise NotImplementedError

    def _get_state(self, env):
        """Return the state array recorded from ``env`` for fitness scoring.

        add_morphology reads element ``[0, 2]`` of the result, so subclasses
        must return an array with at least two dimensions.
        """
        raise NotImplementedError

class VarEA_FC(VarEA_Base):
    """Variant whose recorded state is flattened into a single row."""

    def _get_action_dim(self, morphology):
        """Use one action column per joint."""
        return morphology.num_joints

    def _get_state(self, env):
        """Return the selected positions flattened to an array with one row.

        With ``global_state`` only the body at index ``-len(morphology)`` is
        read (plus the first end site when ``include_end`` is set); otherwise
        the last ``len(morphology)`` bodies are read (plus all end sites).
        """
        sim = env._physics.data
        morph = env._morphology
        num_nodes = len(morph)

        if self.global_state:
            features = sim.xpos[-num_nodes].copy()
            if self.include_end:
                # The first end index is the root.
                root_site = sim.site_xpos[morph.end_site_indices[0]].copy()
                features = np.concatenate((features, root_site), axis=0)
        else:
            features = sim.xpos[-num_nodes:].copy()
            if self.include_end:
                site_positions = sim.site_xpos[morph.end_site_indices].copy()
                features = np.concatenate((features, site_positions), axis=1)

        # Every branch ends the same way: flatten, then add a leading axis.
        return np.expand_dims(features.flatten(), axis=0)

class VarEA_Node(VarEA_Base):
    """Variant with one action and one state row per morphology node.

    The environment passed to ``_get_state`` is unwrapped once through its
    ``env`` attribute before the physics and morphology are read.
    """

    def _get_action_dim(self, morphology):
        """Use one action column per node (``len(morphology)``)."""
        return len(morphology)

    def _get_state(self, env):
        """Return a state array with ``len(morphology)`` rows."""
        inner = env.env
        sim = inner._physics.data
        morph = inner._morphology

        if not self.global_state:
            # Per-node state: positions of the last len(morphology) bodies,
            # optionally extended column-wise with the end-site positions.
            per_node = sim.xpos[-len(morph):].copy()
            if not self.include_end:
                return per_node
            end_sites = sim.site_xpos[morph.end_site_indices].copy()
            return np.concatenate((per_node, end_sites), axis=1)

        # Global state: one vector, repeated for every node.
        shared = sim.xpos[-len(morph)].copy()
        if self.include_end:
            # The first end index is the root.
            root_site = sim.site_xpos[morph.end_site_indices[0]].copy()
            shared = np.concatenate((shared, root_site), axis=0)
        return np.tile(np.expand_dims(shared, axis=0), (len(morph), 1))

class VarEA_Meta(VarEA_Base):
    """Variant with per-joint actions and a state of one row per node."""

    def _get_action_dim(self, morphology):
        """Use one action column per joint."""
        return morphology.num_joints

    def _get_state(self, env):
        """Return a state array with ``len(morphology)`` rows."""
        sim = env._physics.data
        morph = env._morphology

        if self.global_state:
            # A single vector, tiled so that every node sees the same row.
            vector = sim.xpos[-len(morph)].copy()
            if self.include_end:
                # The first end index is the root.
                first_site = sim.site_xpos[morph.end_site_indices[0]].copy()
                vector = np.concatenate((vector, first_site), axis=0)
            row = np.expand_dims(vector, axis=0)
            return np.tile(row, (len(morph), 1))

        # Per-node positions taken from the last len(morphology) bodies.
        node_positions = sim.xpos[-len(morph):].copy()
        if self.include_end:
            site_positions = sim.site_xpos[morph.end_site_indices].copy()
            return np.concatenate((node_positions, site_positions), axis=1)
        return node_positions

class VarEA_Line(VarEA_Base):
    """Variant with per-joint actions and a state of one row per node."""

    def _get_action_dim(self, morphology):
        """Use one action column per joint."""
        return morphology.num_joints

    def _get_state(self, env):
        """Return a state array with ``len(morphology)`` rows.

        ``global_state`` selects a single vector repeated for every node;
        otherwise each node gets its own position.  ``include_end`` appends
        end-site positions in either mode.
        """
        data = env._physics.data
        morphology = env._morphology

        if not self.global_state:
            positions = data.xpos[-len(morphology):].copy()
            if self.include_end:
                end_positions = data.site_xpos[morphology.end_site_indices].copy()
                positions = np.concatenate((positions, end_positions), axis=1)
            return positions

        base = data.xpos[-len(morphology)].copy()
        if self.include_end:
            # The first end index is the root.
            root_end = data.site_xpos[morphology.end_site_indices[0]].copy()
            base = np.concatenate((base, root_end), axis=0)
        return np.tile(np.expand_dims(base, axis=0), (len(morphology), 1))

class VarEA_Node_Arm(VarEA_Node):
    """Node variant whose global state is read from the last body and last end site."""

    def _get_state(self, env):
        """Return a state array with ``len(morphology)`` rows.

        Differs from the parent only in the ``global_state`` branch, which
        reads body ``-1`` and end site ``-1`` instead of the first ones.
        """
        wrapped = env.env
        data = wrapped._physics.data
        morph = wrapped._morphology

        if self.global_state:
            tip = data.xpos[-1].copy()
            if self.include_end:
                # The Last index is the end
                last_site = data.site_xpos[morph.end_site_indices[-1]].copy()
                tip = np.concatenate((tip, last_site), axis=0)
            return np.tile(np.expand_dims(tip, axis=0), (len(morph), 1))

        node_rows = data.xpos[-len(morph):].copy()
        if not self.include_end:
            return node_rows
        site_rows = data.site_xpos[morph.end_site_indices].copy()
        return np.concatenate((node_rows, site_rows), axis=1)


class VarEA_Line_Arm(VarEA_Line):
    """Line variant that records end-site positions alone when ``include_end`` is set."""

    def _get_state(self, env):
        """Return a state array for the arm.

        With ``include_end`` the state consists of end-site positions only;
        body positions are not concatenated in front of them as they are in
        the parent class.
        """
        data = env._physics.data
        morph = env._morphology

        if self.global_state:
            # One vector repeated for every node: either the last end site
            # or the last body.
            if self.include_end:
                anchor = data.site_xpos[morph.end_site_indices[-1]].copy()
            else:
                anchor = data.xpos[-1].copy()
            return np.tile(np.expand_dims(anchor, axis=0), (len(morph), 1))

        if self.include_end:
            return data.site_xpos[morph.end_site_indices].copy()
        return data.xpos[-len(morph):].copy()

class VarEA_Line_Boxes(VarEA_Line):
    """Line variant whose global state is built from bodies 1 and 2.

    NOTE(review): presumably bodies 1 and 2 are the boxes the class is named
    after -- confirm against the environment definition.
    """

    def _get_state(self, env):
        """Return the positions of bodies 1 and 2, tiled once per node.

        When ``include_end`` is set the last end-site position is appended
        to the vector before tiling.

        Raises:
            NotImplementedError: if ``global_state`` is false; only the
                global state is defined for this variant.
        """
        if not self.global_state:
            # The original returned the NotImplementedError class here instead
            # of raising it, so the caller failed later with an unrelated
            # TypeError when it tried to index the "state".
            raise NotImplementedError("VarEA_Line_Boxes only supports global_state=True")

        data = env._physics.data
        parts = [data.xpos[1].copy(), data.xpos[2].copy()]
        if self.include_end:
            parts.append(data.site_xpos[env._morphology.end_site_indices[-1]].copy())
        state = np.concatenate(parts, axis=0)
        return np.tile(np.expand_dims(state, axis=0), (len(env._morphology), 1))
