from dataclasses import dataclass
import os
import numpy as np
import itertools
import torch
import torch_geometric

from optimal_agents.morphology import Morphology
from optimal_agents.utils.loader import get_env, get_morphology
from optimal_agents.utils.trainer import run_train
from optimal_agents.utils.tester import eval_policy
from optimal_agents.policies import random_policies
from optimal_agents.policies import predictive_models

import random
import time

@dataclass
class Individual:
    """A single member of the evolving population.

    The pair ``start_index``/``end_index`` is a half-open slice into whichever
    buffer is used for fitness evaluation (the held-out ``eval_buffer`` or the
    training ``buffer``), selecting the samples collected for this morphology.
    """

    morphology: Morphology  # body plan being evaluated
    fitness: float  # created as -inf, overwritten by update_population
    start_index: int  # first sample of this individual (inclusive)
    end_index: int  # one past the last sample of this individual
    index: int  # population length at the time it was appended (insertion id)

class EmpowermentEA_Base(object):
    """Evolutionary morphology search scored by an action classifier.

    Every candidate morphology is rolled out for ``eval_ep`` episodes with an
    open-loop action sequence from a random policy. The state reached at the end
    of each episode, combined with the morphology observation, becomes one sample
    labelled with the action label returned by the policy. A classifier is trained
    on all samples, and an individual's fitness is ``log(num_actions)`` minus the
    classifier's cross-entropy on that individual's samples (optionally scaled, see
    ``update_population``): morphologies whose outcomes reveal which random action
    was taken score higher.

    Subclasses implement ``_build_model`` (which must set ``model``, ``criterion``,
    ``optim`` and ``noise_mult``), ``_get_action_dim``, ``_get_state`` and
    ``_preprocess_data``.
    """

    def __init__(self, params, eval_ep=8, nge_mutation=False, save_freq=10, 
                       pruning_lr=0.005, pruning_batch_size=128, pruning_n_epochs=10,
                       global_state=False, num_freqs=2, num_phases=2, sample_freq=50,
                       state_noise=0.0, keep_percent=0.0, reset_freq=-1, eval_noise=False, seperate_eval=True,
                       matching_noise=False, random_policy="CosinePolicy", pruning_arch=None, classifier=None,
                       include_segments=False, include_start_state=False, include_end=False, action_ent_coef=0.0):
        """
        Args:
            params: experiment parameters. Indexed with the keys 'mutation_args',
                'env', 'env_wrapper', 'env_wrapper_args' (and 'env_args', 'arena'
                in ``learn``), and persisted with ``params.save(path)``.
            eval_ep: random-policy episodes collected per morphology.
            nge_mutation: mutate with ``mutate_nge`` instead of ``mutate``.
            save_freq: save the population every ``save_freq`` generations.
            pruning_lr, pruning_batch_size, pruning_n_epochs: Adam learning rate,
                minibatch size and epochs per generation for the classifier.
            global_state: passed to the subclass ``_get_state`` to choose a single
                global state instead of per-body states.
            num_freqs, num_phases, sample_freq: forwarded to the random policy.
            state_noise: scale of the Gaussian noise added to state features.
            keep_percent: fraction of the ranked population used as parents once
                that slice holds more than 5 individuals (see ``learn``).
            reset_freq: if > 0, periodically rebuild the classifier from scratch.
            eval_noise: also add state noise when computing fitness.
            seperate_eval: hold out roughly the last third of each morphology's
                samples for fitness instead of scoring on training samples.
            matching_noise: add one shared xyz offset per sample instead of
                independent noise per feature.
            random_policy: class name looked up in ``random_policies``.
            pruning_arch: ``net_arch`` passed to the classifier; defaults to
                ``[192, 192, 192]``.
            classifier: class name looked up in ``predictive_models`` (required;
                ``None`` raises KeyError).
            include_segments, include_start_state, include_end: extra input
                features, interpreted by subclasses.
            action_ent_coef: exponent applied to ``log(0.5 * num_joints)`` in
                the fitness.
        """
        # Save the model parameters and mutation parameters
        self.params = params
        self.mutation_kwargs = params['mutation_args']
        self.nge_mutation = nge_mutation
        self.keep_percent = keep_percent
        self.action_ent_coef = action_ent_coef

        # Save information for data collection
        self.eval_ep = eval_ep
        self.save_freq = save_freq
        self.matching_noise = matching_noise
        self.state_noise = state_noise
        self.eval_noise = eval_noise
        self.global_state = global_state
        self.seperate_eval = seperate_eval
        self.include_segments = include_segments
        self.include_start_state = include_start_state
        self.include_end = include_end

        # Save info for model training
        self.lr = pruning_lr
        self.batch_size = pruning_batch_size
        self.n_epochs = pruning_n_epochs
        self.reset_freq = reset_freq

        # Save the classifier type. A fresh default list per instance avoids the
        # shared-mutable-default pitfall of a literal list in the signature.
        self.classifier_cls = vars(predictive_models)[classifier]
        self.pruning_arch = [192, 192, 192] if pruning_arch is None else pruning_arch

        # Construct the random data-collection policy, sized for the largest morphology.
        self.max_nodes = self.mutation_kwargs['max_nodes'] if 'max_nodes' in self.mutation_kwargs else 12
        random_policy_kwargs = {'sample_freq': sample_freq, 'num_freqs': num_freqs, 'num_phases': num_phases}
        random_policy_cls = vars(random_policies)[random_policy]
        self.policy = random_policy_cls(self.max_nodes, **random_policy_kwargs)

        # A space-separated 'env' lists several tasks: evolution runs on the first,
        # the final policy is trained on each of them (see ``learn``).
        if ' ' in self.params['env']:
            self.eval_envs = self.params['env'].split()  # tolerate repeated spaces
            self.params['env'] = self.eval_envs[0]
        else:
            self.eval_envs = [self.params['env']]

        # Clear the env wrapper from the params; it is restored for policy learning
        # in ``learn``. Subclasses may set their own in _build_model.
        self.policy_learning_env_wrapper = self.params['env_wrapper']
        self.policy_learning_env_wrapper_args = self.params['env_wrapper_args']
        self.params['env_wrapper'] = None
        self.params['env_wrapper_args'] = dict()

        self.buffer = []
        self.eval_buffer = []
        self.population = []

        self._build_model()

    def _build_model(self):
        """Create ``model``, ``criterion``, ``optim`` and ``noise_mult`` (subclass hook)."""
        raise NotImplementedError

    def preprocess_batch(self, data, noise=True):
        """Collate ``data`` into a ``torch_geometric`` Batch, optionally adding state noise.

        The first ``3 * self.noise_mult`` columns of each ``x`` are the state
        features (xyz triples) and are the only ones perturbed.

        Args:
            data: list of ``torch_geometric.data.Data`` samples.
            noise: when True and ``self.state_noise > 0``, perturb the state columns.

        Returns:
            A ``torch_geometric.data.Batch``.
        """
        if not (noise and self.state_noise > 0):
            return torch_geometric.data.Batch.from_data_list(data)

        num_state_cols = 3 * self.noise_mult
        if self.matching_noise:
            # One xyz offset per sample, repeated over every triple and added to every
            # row of that sample, so all state features shift consistently.
            noisy_data = []
            for data_pt in data:
                x = data_pt.x.clone()
                offset = self.state_noise * torch.randn(1, 3)
                offset = offset.repeat(1, self.noise_mult)
                x[:, :num_state_cols] += offset
                noisy_data.append(torch_geometric.data.Data(x=x, edge_index=data_pt.edge_index, edge_attr=data_pt.edge_attr, y=data_pt.y))
            return torch_geometric.data.Batch.from_data_list(noisy_data)

        # Independent noise for every row and every state column.
        batch = torch_geometric.data.Batch.from_data_list(data)
        x = batch.x.clone()
        x[:, :num_state_cols] += self.state_noise * torch.randn(x.shape[0], num_state_cols)
        batch.x = x
        return batch

    def update_model(self):
        """Train the classifier for ``self.n_epochs`` shuffled passes over ``self.buffer``.

        Prints the last minibatch loss and the epoch's training accuracy. Does
        nothing (besides a message) when the buffer is empty.
        """
        if not self.buffer:
            # Would otherwise fail on an undefined loss and a zero division.
            print("Skipping model update: training buffer is empty")
            return

        for epoch in range(self.n_epochs):
            num_data_pts = 0
            num_correct = 0
            perm = np.random.permutation(len(self.buffer))
            # Full minibatches followed by one partial batch (if any remain).
            for start in range(0, len(perm), self.batch_size):
                inds = perm[start:start + self.batch_size]
                batch = self.preprocess_batch([self.buffer[ind] for ind in inds])
                self.optim.zero_grad()
                pred = self.model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
                loss = self.criterion(pred, batch.y)
                loss.backward()
                self.optim.step()
                with torch.no_grad():
                    pred_labels = torch.argmax(pred, dim=1)
                    num_correct += torch.sum(pred_labels == batch.y).item()
                    num_data_pts += pred.shape[0]
            print("Epoch", epoch, "Loss", loss.item(), "Acc", num_correct / num_data_pts)

    def update_population(self):
        """Recompute every individual's fitness with the current classifier.

        Each individual's samples (from ``eval_buffer`` when ``seperate_eval``,
        otherwise from ``buffer``) are scored as::

            log(0.5 * num_joints) ** action_ent_coef * (log(num_actions) - cross_entropy)

        With the default ``action_ent_coef == 0`` the first factor is 1.
        """
        source = self.eval_buffer if self.seperate_eval else self.buffer
        for individual in self.population:
            data = source[individual.start_index:individual.end_index]
            # State noise is applied here only when eval_noise is set.
            batch = self.preprocess_batch(data, noise=self.eval_noise)
            with torch.no_grad():
                pred = self.model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
                fitness = -1*self.criterion(pred, batch.y).item()
                # NOTE(review): for num_joints == 1 the log is negative, so a
                # non-integer action_ent_coef yields NaN fitness.
                fitness = np.power(np.log(0.5*individual.morphology.num_joints), self.action_ent_coef) * (np.log(self.policy.num_actions) + fitness)
            individual.fitness = fitness

    def learn(self, path, population_size, num_generations):
        """Run the evolutionary search, then render and train a policy for the best morphology.

        The population is never pruned: each generation appends up to
        ``population_size`` mutated offspring to it.

        Args:
            path: output directory for rankings, checkpoints and GIFs.
            population_size: offspring per generation; the initial population
                samples ``2 * population_size`` random morphologies.
            num_generations: number of train / evaluate / rank rounds.

        Raises:
            RuntimeError: if no initial morphology could be evaluated.
        """
        os.makedirs(path, exist_ok=True)
        start_time = time.time()

        # Generate the initial population (failed morphologies are skipped).
        for _ in range(population_size*2):
            self.add_morphology(get_morphology(self.params))
        if not self.population:
            raise RuntimeError("No initial morphology could be evaluated; see the messages printed by add_morphology")

        for gen_idx in range(num_generations):
            self.update_model()
            self.update_population()
            self.population.sort(key=lambda individual: individual.fitness, reverse=True)

            # Ranking of the whole population, one "<fitness> <index>" line each.
            with open(os.path.join(path, "gen" + str(gen_idx) + ".txt"), "w+") as f:
                for individual in self.population:
                    f.write(str(individual.fitness) + " " + str(individual.index) + "\n")

            if (gen_idx + 1) % self.save_freq == 0:
                gen_path = os.path.join(path, 'gen_' + str(gen_idx))
                os.makedirs(gen_path, exist_ok=True)
                self.params.save(gen_path)
                for i, individual in enumerate(self.population):
                    individual.morphology.save(os.path.join(gen_path, str(i) + '.morphology.pkl'))

            # If we're at the end, exit so we don't add extra morphologies.
            if gen_idx + 1 == num_generations:
                break

            # Offspring appended below are unranked; only the first num_ranked are sorted.
            num_ranked = len(self.population)
            for i in range(population_size):
                if i >= 3 and self.keep_percent > 0 and int(self.keep_percent*len(self.population)) > 5:
                    morphology = random.choice(self.population[:max(int(len(self.population)*self.keep_percent),1)]).morphology
                else:
                    # Mutate the i-th best; wrap around when failed data collection
                    # left fewer ranked individuals than population_size.
                    morphology = self.population[i % num_ranked].morphology
                if self.nge_mutation:
                    new_morphology = morphology.mutate_nge(**self.mutation_kwargs)
                else:
                    new_morphology = morphology.mutate(**self.mutation_kwargs)

                self.add_morphology(new_morphology)

            # If reset freq, reset the model
            if self.reset_freq > 0 and gen_idx % self.reset_freq == 0 and gen_idx > 1:
                self._build_model()

            print("Finished Gen", gen_idx)

        print("=====================================")
        print("TIME:", time.time() - start_time)
        print("=====================================")

        assert self.population[0].fitness == max([individual.fitness for individual in self.population]), "Error, didn't get max fitness individual"

        # Roll the best morphology out with a random action sequence and save a GIF.
        morphology = self.population[0].morphology
        env = get_env(self.params, morphology=morphology)
        env.reset()
        frames = []
        action_dim = self._get_action_dim(morphology)
        # NOTE(review): assumes the episode terminates within 400 policy steps.
        actions, _ = self.policy.step(400)
        action_idx, done = 0, False
        while not done:
            frames.append(env.render(mode='rgb_array'))
            ac = actions[action_idx, :action_dim]
            action_idx += 1
            _, _, done, _ = env.step(ac)

        import imageio  # only needed for GIF export
        imageio.mimsave(os.path.join(path, 'best.gif'), frames[::3], subrectangles=True, duration=0.05)
        del env

        # Train a policy for the best morphology on each task, without the time
        # limit and terrain arena used during evolution.
        if 'time_limit' in self.params['env_args']:
            del self.params['env_args']['time_limit']
        if self.params['arena'] is not None and 'Terrain' in self.params['arena']:
            self.params['arena'] = None

        self.params['env_wrapper'] = self.policy_learning_env_wrapper
        self.params['env_wrapper_args'] = self.policy_learning_env_wrapper_args
        for env_name in self.eval_envs:
            self.params['env'] = env_name
            model, _ = run_train(self.params, morphology=morphology, path=path)
            env = get_env(self.params, morphology=morphology)
            _, frames = eval_policy(model, env, num_ep=1, deterministic=True, verbose=1, gif=True, render=True)
            imageio.mimsave(os.path.join(path, 'best_trained_' + env_name + '.gif'), frames[::3], subrectangles=True, duration=0.05)

    def add_morphology(self, morphology):
        """Collect random-policy samples for ``morphology`` and append it to the population.

        Runs ``self.eval_ep`` episodes, records the final (and optionally start)
        state of each, converts them with ``_preprocess_data`` and stores them in
        ``buffer`` / ``eval_buffer``. Best effort: a morphology that yields no
        samples or whose simulation raises is skipped, and the error is printed.
        """
        try:
            eval_env = get_env(self.params, morphology=morphology)
            morphology_obs = eval_env.get_morphology_obs(include_segments=self.include_segments)

            states = []
            labels = []
            action_dim = self._get_action_dim(morphology)
            for _ in range(self.eval_ep):
                actions, label = self.policy.step(500) # Arbitrarily set number of actions.
                done = False
                action_idx = 0
                eval_env.reset()
                if self.include_start_state:
                    start_state = self._get_state(eval_env)
                while not done:
                    _, _, done, _ = eval_env.step(actions[action_idx, :action_dim])
                    action_idx += 1
                # Get the final state and morphology as data
                if self.include_start_state:
                    states.append(np.concatenate((start_state, self._get_state(eval_env)), axis=1))
                else:
                    states.append(self._get_state(eval_env))
                labels.append(label[:action_dim])

            data = self._preprocess_data(morphology_obs, states, labels)
            if len(data) == 0:
                return

            if self.seperate_eval:
                # First ~2/3 of the samples train the classifier, the rest score fitness.
                eval_idx = int(2/3 * len(data))
                train_data = data[:eval_idx]
                eval_data = data[eval_idx:]

                self.buffer.extend(train_data)
                data_start_idx = len(self.eval_buffer)
                self.eval_buffer.extend(eval_data)
                data_end_idx = len(self.eval_buffer)
            else:
                data_start_idx = len(self.buffer)
                self.buffer.extend(data)
                data_end_idx = len(self.buffer)

            individual = Individual(morphology, -np.inf, data_start_idx, data_end_idx, len(self.population))
            self.population.append(individual)
        except Exception as err:
            # Keep the best-effort behaviour, but let KeyboardInterrupt/SystemExit
            # through and report the reason instead of hiding systematic bugs.
            print("Skipping morphology, data collection failed:", repr(err))
            return

    def _get_action_dim(self, morphology):
        """Return how many action columns ``morphology`` consumes (subclass hook)."""
        raise NotImplementedError

    def _get_state(self, env):
        """Return the state features of ``env`` used as classifier input (subclass hook)."""
        raise NotImplementedError

    def _preprocess_data(self, morphology_obs, states, labels):
        """Convert raw states/labels into a list of ``torch_geometric`` samples (subclass hook)."""
        raise NotImplementedError

class EmpowermentEA_FC(EmpowermentEA_Base):
    """Variant that flattens state and morphology features into one row for a dense classifier."""

    def _build_model(self):
        """(Re)create classifier, loss and optimizer for flat inputs.

        Also sets ``self.noise_mult``, the number of leading xyz triples in each
        input row that ``preprocess_batch`` treats as state.
        """
        for attr_name in ("model", "optim", "criterion"):
            if hasattr(self, attr_name):
                delattr(self, attr_name)

        # xyz triples per state: body position (+ end site), doubled when the
        # start state is prepended.
        triples = (2 if self.include_end else 1) * (2 if self.include_start_state else 1)
        if not self.global_state:
            # One state per node slot.
            triples *= self.max_nodes
        self.noise_mult = triples

        # Morphology part: max_nodes node features (5, or 11 with segments) plus a
        # 6-wide edge feature for all but one node (the graph is a tree).
        node_width = 5 + (6 if self.include_segments else 0)
        in_dim = self.max_nodes * node_width + (self.max_nodes - 1) * 6 + 3 * self.noise_mult

        # Number of prediction heads is taken from a freshly sampled morphology.
        num_joints = get_morphology(self.params).num_joints

        self.model = self.classifier_cls(in_dim, self.policy.num_actions, num_joints, net_arch=self.pruning_arch)
        self.multi_head = self.model.num_heads > 1
        self.criterion = torch.nn.CrossEntropyLoss(ignore_index=self.policy.num_actions)
        self.optim = torch.optim.Adam(self.model.parameters(), lr=self.lr)

    def _get_action_dim(self, morphology):
        """One action column per joint."""
        return morphology.num_joints

    def _get_state(self, env):
        """Return the final state as a single flattened ``(1, k)`` row."""
        physics_data = env._physics.data
        morphology = env._morphology
        n_bodies = len(morphology)
        if self.global_state:
            # Position of the first of the morphology's bodies (rows taken from the end of xpos).
            features = physics_data.xpos[-n_bodies].copy()
            if self.include_end:
                # The first end index is the root.
                root_end = physics_data.site_xpos[morphology.end_site_indices[0]].copy()
                features = np.concatenate((features, root_end), axis=0)
        else:
            features = physics_data.xpos[-n_bodies:].copy()
            if self.include_end:
                end_sites = physics_data.site_xpos[morphology.end_site_indices].copy()
                features = np.concatenate((features, end_sites), axis=1)
        return np.expand_dims(features.flatten(), axis=0)

    def _preprocess_data(self, morphology_obs, states, labels):
        """Append the flattened morphology to each state row and attach the target."""
        morph_row = np.concatenate(
            (morphology_obs['x'].flatten(), morphology_obs['edge_attr'].flatten()), axis=0
        )[np.newaxis, :]

        samples = []
        for state, label in zip(states, labels):
            # Multi-head models predict every joint's label; otherwise only the first.
            target = (torch.from_numpy(np.expand_dims(label, axis=0)) if self.multi_head
                      else torch.from_numpy(np.array([label[0]])))
            features = np.concatenate((state, morph_row), axis=1).astype(np.float32)
            samples.append(torch_geometric.data.Data(x=torch.from_numpy(features), y=target))
        return samples

class EmpowermentEA_Node(EmpowermentEA_Base):
    """Graph variant: each body is a node carrying its own state and morphology features."""

    def _build_model(self):
        """(Re)create classifier, loss and optimizer and select the node env wrapper.

        Sets ``self.noise_mult`` to the number of leading xyz state triples per node.
        The root node's label is masked out of the loss via ``ignore_index``.
        """
        if hasattr(self, "model"):
            del self.model
        if hasattr(self, "optim"):
            del self.optim
        if hasattr(self, "criterion"):
            del self.criterion

        # Set to node wrapper
        self.params['env_wrapper'] = "NodeWrapper"

        self.noise_mult = 1 # This is just for xpos
        if self.include_end:
            self.noise_mult += 1 # If we include the end
        if self.include_start_state:
            self.noise_mult *= 2 # If we include the start state

        in_dim = 11 # Morphology is length 11
        if self.include_segments: # Add 6 to each morphology node
            in_dim += 6
        in_dim += 3 * self.noise_mult # Add 3 per state component.

        self.model = self.classifier_cls(in_dim, self.policy.num_actions, net_arch=self.pruning_arch)
        self.criterion = torch.nn.CrossEntropyLoss(ignore_index=self.policy.num_actions)
        self.optim = torch.optim.Adam(self.model.parameters(), lr=self.lr)

    def _get_action_dim(self, morphology):
        """One action column per body (node), including the root."""
        return len(morphology)

    def _get_state(self, env):
        """Return per-node state rows, read from the wrapped env (``env.env``)."""
        if self.global_state:
            if self.include_end:
                state = np.concatenate((
                    env.env._physics.data.xpos[-len(env.env._morphology)].copy(),
                    env.env._physics.data.site_xpos[env.env._morphology.end_site_indices[0]].copy() # The first end index is the root.
                ), axis=0)
            else:
                state = env.env._physics.data.xpos[-len(env.env._morphology)].copy()
            # Same global state repeated for every node.
            return np.tile(np.expand_dims(state, axis=0), (len(env.env._morphology), 1))
        else:
            if self.include_end:
                return np.concatenate((
                    env.env._physics.data.xpos[-len(env.env._morphology):].copy(),
                    env.env._physics.data.site_xpos[env.env._morphology.end_site_indices].copy()
                ), axis=1)
            else:
                return env.env._physics.data.xpos[-len(env.env._morphology):].copy()

    def _preprocess_data(self, morphology_obs, states, labels):
        """Build one undirected graph sample per (state, label) pair."""
        # The edge list is the same for every sample of this morphology. Rolling a
        # (num_edges, 2) array along axis 1 swaps each pair, adding reverse edges.
        edges = morphology_obs['edge_index']
        both_directions = np.concatenate((edges, np.roll(edges, 1, axis=1)), axis=0)

        data = []
        for state, label in zip(states, labels):
            # Copy before writing: ``label`` may be a view into the policy's output,
            # and torch.from_numpy would share (and mutate) that memory.
            y = torch.from_numpy(np.array(label, copy=True))
            y[0] = self.policy.num_actions # Ignored by the loss: the root node is not predicted.
            x = torch.from_numpy(np.concatenate((state, morphology_obs['x']), axis=1).astype(np.float32))
            edge_index = torch.from_numpy(both_directions).t().contiguous()
            data.append(torch_geometric.data.Data(x=x, edge_index=edge_index, y=y))
        return data

class EmpowermentEA_Meta(EmpowermentEA_Base):
    """Graph variant with per-node states and bidirectional, direction-flagged edges."""

    def _build_model(self):
        """(Re)create the node/edge classifier, loss and optimizer.

        Sets ``self.noise_mult`` to the number of leading xyz state triples per node.
        """
        for attr_name in ("model", "optim", "criterion"):
            if hasattr(self, attr_name):
                delattr(self, attr_name)

        # Body position (+ end site), doubled when the start state is prepended.
        self.noise_mult = (2 if self.include_end else 1) * (2 if self.include_start_state else 1)

        node_in_dim = 5 + 3 * self.noise_mult + (6 if self.include_segments else 0)
        edge_in_dim = 7  # 6 joint features + 1 direction flag

        self.model = self.classifier_cls(node_in_dim, edge_in_dim, self.policy.num_actions, net_arch=self.pruning_arch)
        self.criterion = torch.nn.CrossEntropyLoss()
        self.optim = torch.optim.Adam(self.model.parameters(), lr=self.lr)

    def _get_action_dim(self, morphology):
        """One action column per joint."""
        return morphology.num_joints

    def _get_state(self, env):
        """Return one state row per body (a tiled global state when ``global_state``)."""
        physics_data = env._physics.data
        morphology = env._morphology
        n_bodies = len(morphology)

        if not self.global_state:
            body_pos = physics_data.xpos[-n_bodies:].copy()
            if not self.include_end:
                return body_pos
            end_sites = physics_data.site_xpos[morphology.end_site_indices].copy()
            return np.concatenate((body_pos, end_sites), axis=1)

        shared = physics_data.xpos[-n_bodies].copy()
        if self.include_end:
            # The first end index is the root.
            root_end = physics_data.site_xpos[morphology.end_site_indices[0]].copy()
            shared = np.concatenate((shared, root_end), axis=0)
        return np.tile(shared[np.newaxis, :], (n_bodies, 1))

    def _preprocess_data(self, morphology_obs, states, labels):
        """Build one bidirectional graph sample per (state, label) pair."""

        def graph_tensors():
            # Forward edges followed by their reverses (columns swapped by the roll).
            edges = morphology_obs['edge_index']
            directed = np.concatenate((edges, np.roll(edges, 1, axis=1)), axis=0)
            joint_feats = np.tile(morphology_obs['edge_attr'].astype(np.float32), (2, 1))
            # Extra feature: 1.0 where the source index exceeds the target index.
            src_gt_dst = (directed[:, 0:1] > directed[:, 1:2]).astype(np.float32)
            edge_attr = torch.from_numpy(np.concatenate((joint_feats, src_gt_dst), axis=1))
            edge_index = torch.from_numpy(directed).t().contiguous()
            return edge_index, edge_attr

        samples = []
        for state, label in zip(states, labels):
            edge_index, edge_attr = graph_tensors()
            node_feats = np.concatenate((state, morphology_obs['x']), axis=1).astype(np.float32)
            samples.append(torch_geometric.data.Data(
                x=torch.from_numpy(node_feats),
                edge_index=edge_index,
                edge_attr=edge_attr,
                # Joint labels repeated so every directed edge has a target.
                y=torch.from_numpy(np.tile(label, 2)),
            ))
        return samples

class EmpowermentEA_Line(EmpowermentEA_Base):
    """Line-graph variant: each joint is a node; joints whose edges share a body are linked."""

    def _build_model(self):
        """(Re)create classifier, loss and optimizer for joint-node inputs.

        Sets ``self.noise_mult`` to the number of leading xyz state triples in each
        joint row (parent and child state, hence the final doubling).
        """
        if hasattr(self, "model"):
            del self.model
        if hasattr(self, "optim"):
            del self.optim
        if hasattr(self, "criterion"):
            del self.criterion

        self.noise_mult = 1 # This is just for xpos
        if self.include_end:
            self.noise_mult += 1 # If we include the end
        if self.include_start_state:
            self.noise_mult *= 2 # If we include the start state
        self.noise_mult *= 2 # Parent and child state in every joint row

        in_dim = 2*5 + 6 # Morph Node: 5 (parent and child), Morph Edge: 6
        if self.include_segments:
            in_dim += 2*6
        in_dim += 3*self.noise_mult

        self.model = self.classifier_cls(in_dim, self.policy.num_actions, net_arch=self.pruning_arch)
        self.criterion = torch.nn.CrossEntropyLoss()
        self.optim = torch.optim.Adam(self.model.parameters(), lr=self.lr)

    def _get_action_dim(self, morphology):
        """One action column per joint."""
        return morphology.num_joints

    def _get_state(self, env):
        """Return one state row per body (a tiled global state when ``global_state``)."""
        if self.global_state:
            if self.include_end:
                state = np.concatenate((
                    env._physics.data.xpos[-len(env._morphology)].copy(),
                    env._physics.data.site_xpos[env._morphology.end_site_indices[0]].copy() # The first end index is the root.
                ), axis=0)
            else:
                state = env._physics.data.xpos[-len(env._morphology)].copy()
            return np.tile(np.expand_dims(state, axis=0), (len(env._morphology), 1))
        else:
            if self.include_end:
                return np.concatenate((
                    env._physics.data.xpos[-len(env._morphology):].copy(),
                    env._physics.data.site_xpos[env._morphology.end_site_indices].copy()
                ), axis=1)
            else:
                return env._physics.data.xpos[-len(env._morphology):].copy()

    def _preprocess_data(self, morphology_obs, states, labels):
        """Build one line-graph sample per (state, label) pair.

        Joint ``child_id - 1`` stands for edge ``(parent_id, child_id)``. Its feature
        row is ``[state[parent], state[child], x[parent], x[child], edge_attr[joint]]``.
        Returns ``[]`` when no two joints share a body (e.g. a single joint), which
        makes ``add_morphology`` skip the morphology.
        """
        edges = morphology_obs['edge_index']
        line_edges = []
        # Same pair order as a nested i < j loop over the edges.
        for edge_a, edge_b in itertools.combinations(edges, 2):
            if set(edge_a) & set(edge_b):
                # These "body parts" share a connection; link the joints (child id - 1).
                joint_a = edge_a[1] - 1
                joint_b = edge_b[1] - 1
                line_edges.append([joint_a, joint_b])
                line_edges.append([joint_b, joint_a]) # Add both for undirected graph
        if len(line_edges) == 0:
            return []

        # np.long was removed in NumPy 1.24 (and is C long, 32-bit on Windows, in
        # NumPy 2); edge indices must be int64 for torch.
        line_edge_array = np.array(line_edges, dtype=np.int64)
        x = morphology_obs['x']
        data = []
        for state, y in zip(states, labels):
            attr = np.zeros((edges.shape[0], 2*state.shape[1] + 2*x.shape[1] + morphology_obs['edge_attr'].shape[1]))
            for parent_id, child_id in edges:
                attr[child_id - 1] = np.concatenate((
                    state[parent_id], state[child_id],
                    x[parent_id], x[child_id],
                    morphology_obs['edge_attr'][child_id - 1] # joint is associated with the child.
                ), axis=0)

            data_pt = torch_geometric.data.Data(x=torch.from_numpy(attr.astype(np.float32)),
                                                edge_index=torch.from_numpy(line_edge_array).t().contiguous(),
                                                y=torch.from_numpy(y), num_nodes=len(x)-1)
            data.append(data_pt)

        return data

class EmpowermentEA_Node_Arm(EmpowermentEA_Node):
    """Node variant for arms: the global state follows the last body and last end site."""

    def _get_state(self, env):
        """Return per-node state rows read from the wrapped env (``env.env``)."""
        inner = env.env
        physics_data = inner._physics.data
        morphology = inner._morphology
        n_bodies = len(morphology)

        if self.global_state:
            tip = physics_data.xpos[-1].copy()
            if self.include_end:
                # The last end-site index is the end.
                last_end = physics_data.site_xpos[morphology.end_site_indices[-1]].copy()
                tip = np.concatenate((tip, last_end), axis=0)
            return np.tile(tip[np.newaxis, :], (n_bodies, 1))

        per_body = physics_data.xpos[-n_bodies:].copy()
        if self.include_end:
            end_sites = physics_data.site_xpos[morphology.end_site_indices].copy()
            per_body = np.concatenate((per_body, end_sites), axis=1)
        return per_body


class EmpowermentEA_Line_Arm(EmpowermentEA_Line):
    """Line-graph variant for arms; ``include_end`` replaces body positions with end sites."""

    def _build_model(self):
        """(Re)create classifier, loss and optimizer.

        With ``include_end`` the end-site position *replaces* the body position, so
        it does not widen the state (unlike the parent class).
        """
        for attr_name in ("model", "optim", "criterion"):
            if hasattr(self, attr_name):
                delattr(self, attr_name)

        # One triple per state, doubled for the start state, and doubled again for
        # the parent/child pair in every joint row.
        self.noise_mult = (2 if self.include_start_state else 1) * 2

        segment_width = 2 * 6 if self.include_segments else 0
        in_dim = 2 * 5 + 6 + segment_width + 3 * self.noise_mult  # Morph Node: 5, Morph Edge: 6

        self.model = self.classifier_cls(in_dim, self.policy.num_actions, net_arch=self.pruning_arch)
        self.criterion = torch.nn.CrossEntropyLoss()
        self.optim = torch.optim.Adam(self.model.parameters(), lr=self.lr)

    def _get_state(self, env):
        """Return one xyz row per body, or the tiled global row when ``global_state``."""
        physics_data = env._physics.data
        morphology = env._morphology

        if self.global_state:
            if self.include_end:
                tracked = physics_data.site_xpos[morphology.end_site_indices[-1]].copy()
            else:
                tracked = physics_data.xpos[-1].copy()
            return np.tile(np.expand_dims(tracked, axis=0), (len(morphology), 1))

        if self.include_end:
            return physics_data.site_xpos[morphology.end_site_indices].copy()
        return physics_data.xpos[-len(morphology):].copy()

class EmpowermentEA_Line_Boxes(EmpowermentEA_Line):
    """Line-graph variant whose global state is the position of bodies 1 and 2 (presumably two boxes)."""

    def _get_state(self, env):
        """Return the tiled global state; only ``global_state=True`` is supported.

        Raises:
            NotImplementedError: if ``self.global_state`` is False.
        """
        if not self.global_state:
            # Previously this *returned* the exception class, which only failed
            # later with an unrelated error while building samples.
            raise NotImplementedError("EmpowermentEA_Line_Boxes only supports global_state=True")

        if self.include_end:
            state = np.concatenate((
                env._physics.data.xpos[1].copy(),
                env._physics.data.xpos[2].copy(),
                env._physics.data.site_xpos[env._morphology.end_site_indices[-1]].copy()
            ), axis=0)
        else:
            state = np.concatenate((
                env._physics.data.xpos[1].copy(),
                env._physics.data.xpos[2].copy(),
            ), axis=0)
        return np.tile(np.expand_dims(state, axis=0), (len(env._morphology), 1))

    def _build_model(self):
        """(Re)create classifier, loss and optimizer.

        Raises:
            ValueError: if ``include_start_state`` is set (not implemented).
        """
        if hasattr(self, "model"):
            del self.model
        if hasattr(self, "optim"):
            del self.optim
        if hasattr(self, "criterion"):
            del self.criterion

        self.noise_mult = 2 # This is for the two boxes
        if self.include_end:
            self.noise_mult += 1 # If we include the end
        if self.include_start_state:
            raise ValueError("Not Implemented")
        self.noise_mult *= 2 # Parent and child state in every joint row

        in_dim = 2*5 + 6 # Morph Node: 5, Morph Edge: 6
        if self.include_segments:
            in_dim += 2*6
        in_dim += 3*self.noise_mult

        self.model = self.classifier_cls(in_dim, self.policy.num_actions, net_arch=self.pruning_arch)
        self.criterion = torch.nn.CrossEntropyLoss()
        self.optim = torch.optim.Adam(self.model.parameters(), lr=self.lr)

