from itertools import product

import numpy as np
import torch
import torch.nn as nn
from torch.nn.parallel import parallel_apply
from torch.nn.parallel.replicate import replicate
from torch.nn.parallel.scatter_gather import gather

from hyper.ghn_modules import MLP_GHN, MlpNetwork

def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Initialise ``layer`` in place and return it.

    The weight is filled with an orthogonal matrix scaled by ``std`` and the
    bias, if the layer has one, is filled with ``bias_const``.

    Args:
        layer: module exposing a ``weight`` tensor and, optionally, a ``bias``.
        std: gain passed to ``torch.nn.init.orthogonal_``.
        bias_const: constant written to every bias element.

    Returns:
        The same ``layer`` object, so the call can wrap a constructor inline.
    """
    torch.nn.init.orthogonal_(layer.weight, gain=std)
    # Layers built with ``bias=False`` expose ``bias`` as None; skip them
    # instead of crashing inside ``constant_``.
    bias = getattr(layer, "bias", None)
    if bias is not None:
        torch.nn.init.constant_(bias, bias_const)
    return layer


class hyperActor(nn.Module):
    """Actor that evaluates a batch of MLP policies paired with a GHN.

    The constructor enumerates the candidate hidden-layer architectures,
    builds one ``MlpNetwork`` per architecture together with a padded
    "shape indicator" table, and creates an ``MLP_GHN``.  ``change_graph``
    and ``set_graph`` select a list of architectures and call the GHN on the
    corresponding networks; ``forward`` then splits a batch of states evenly
    across the selected networks.

    NOTE(review): how the GHN writes predicted weights into the
    ``MlpNetwork`` instances lives in ``hyper.ghn_modules`` and is not
    visible from this file.
    """

    # Maximum number of hidden layers an architecture may have.  The padded
    # widths of ``current_archs`` and of the shape-indicator table are both
    # derived from this single value.
    MAX_HIDDEN_LAYERS = 5

    def __init__(self,
                 act_dim,
                 obs_dim,
                 meta_batch_size=1,
                 device="cpu",
                 architecture_sampling_mode="uniform",
                 multi_gpu=False,
                 ):
        """Build the architecture tables, the candidate networks and the GHN.

        Args:
            act_dim: output size of every candidate network.
            obs_dim: input size of every candidate network.
            meta_batch_size: number of architectures drawn by
                ``sample_arc_indices`` / ``change_graph``.
            device: device used when ``multi_gpu`` is False.  It is ignored
                when ``multi_gpu`` is True (``cuda:0`` is used instead).
            architecture_sampling_mode: only ``"uniform"`` is supported.
            multi_gpu: spread the sampled networks over all visible GPUs.

        Raises:
            NotImplementedError: for any sampling mode other than "uniform".
            RuntimeError: if ``multi_gpu`` is True and no CUDA device is
                visible.
        """
        super().__init__()

        self.act_dim = act_dim
        self.obs_dim = obs_dim
        self.meta_batch_size = meta_batch_size
        self.architecture_sampling_mode = architecture_sampling_mode
        self.multi_gpu = multi_gpu

        # Order matters: the shape-indicator table is created on
        # ``self.device``, which ``_initialize_devices`` sets.
        self._initialize_devices(device)

        # Architecture list, candidate networks and shape indicators.
        self._initialize_shape_arch_inidices()

        # State needed by the architecture sampler.
        self._initialize_architecture_smapling_data()

        # The graph hypernetwork itself.
        self._initialize_ghn(self.obs_dim, self.act_dim)

    def _initialize_architecture_smapling_data(self):
        ''' Initializes all the data required for architecture sampling.
            Only uniform sampling is supported, and it needs no extra data,
            so this only validates the mode and resets ``sampled_indices``.
        '''
        if self.architecture_sampling_mode != "uniform":
            raise NotImplementedError()

        self.sampled_indices = None

    def _initialize_shape_arch_inidices(self):
        ''' Creates:
            1. list_of_arcs: list of candidate architectures (tuples of hidden
               layer widths), sorted by total number of parameters
            2. list_of_shape_inds: padded shape indicators, one row per
               architecture, used as an input to the GHN
            3. list_of_arc_indices: shuffled indices into list_of_arcs, used
               to sample architectures later
            4. all_models: one MlpNetwork per architecture, in list_of_arcs
               order
        '''
        self.list_of_arcs = [
            (256, 256, 256),
        ]
        self.list_of_arcs.sort(key=self.get_params)

        self._initialize_shape_inds()

        self.list_of_arc_indices = np.arange(len(self.list_of_arcs))
        # Kept as a plain list (not nn.ModuleList), exactly as before, so the
        # candidate networks are not registered as sub-modules of the actor.
        self.all_models = [
            MlpNetwork(fc_layers=arc, inp_dim=self.obs_dim, out_dim=self.act_dim)
            for arc in self.list_of_arcs
        ]
        # Shuffle only the index array; all_models stays aligned with
        # list_of_arcs.
        np.random.shuffle(self.list_of_arc_indices)

    def _initialize_shape_inds(self):
        ''' Creates:
            1. list_of_shape_inds: float32 tensor of shape
               (num_architectures, shape_inds_max_len).  A row is
               [0, h1, h1, ..., hk, hk, act_dim, act_dim] padded with -1.
            2. list_of_shape_inds_lenths: unpadded length of every row, needed
               since the shape indicators are not all the same length
            3. arch_max_len / shape_inds_max_len: padding widths

            Raises ValueError if an architecture has more hidden layers than
            MAX_HIDDEN_LAYERS (previously this surfaced as an opaque
            stack/reshape error).
        '''
        self.arch_max_len = self.MAX_HIDDEN_LAYERS
        # One leading 0, two entries per hidden layer, two for the output.
        self.shape_inds_max_len = 1 + 2 * self.arch_max_len + 2

        too_deep = [arc for arc in self.list_of_arcs if len(arc) > self.arch_max_len]
        if too_deep:
            raise ValueError(
                f"Architectures {too_deep} exceed the supported maximum of "
                f"{self.arch_max_len} hidden layers"
            )

        rows = []
        self.list_of_shape_inds_lenths = []
        for arc in self.list_of_arcs:
            row = [0.0]
            for width in arc:
                row.extend((float(width), float(width)))
            row.extend((float(self.act_dim), float(self.act_dim)))
            self.list_of_shape_inds_lenths.append(len(row))
            # Pad with float -1 so the table has a single dtype and does not
            # depend on type promotion when concatenating.
            row.extend([-1.0] * (self.shape_inds_max_len - len(row)))
            rows.append(row)

        self.list_of_shape_inds = torch.tensor(
            rows, dtype=torch.float32, device=self.device
        ).reshape(len(rows), self.shape_inds_max_len)

    def _initialize_devices(self, device):
        ''' Select the device(s) to run on.

            Single device: ``self.device`` is the ``device`` argument.
            Multi GPU: ``self.device`` is cuda:0, ``all_devices`` lists every
            visible GPU and ``device_model_list[i]`` is the GPU assigned to
            the i-th sampled model.

            Raises RuntimeError if multi_gpu is requested without any visible
            CUDA device (previously a ZeroDivisionError).
        '''
        if not self.multi_gpu:
            self.device = device
            return

        num_gpus = torch.cuda.device_count()
        if num_gpus == 0:
            raise RuntimeError("multi_gpu=True requires at least one visible CUDA device")

        self.device = torch.device("cuda:0")
        self.all_devices = [torch.device(f"cuda:{i}") for i in range(num_gpus)]
        # NOTE(review): when meta_batch_size is not a multiple of the number
        # of GPUs the trailing models get no device assigned here; forward()
        # would then fail for them.  Left unchanged to stay compatible.
        self.num_current_models_per_device = self.meta_batch_size // num_gpus
        self.device_model_list = [
            gpu
            for gpu in self.all_devices
            for _ in range(self.num_current_models_per_device)
        ]

    def _initialize_ghn(self, obs_dim, act_dim):
        ''' Initialize the GHN that takes in the shape indicators and outputs
            weights for the corresponding architecture.  The configuration is
            kept on ``self.ghn_config``.
        '''
        config = {
            'max_shape': (512, 512, 1, 1),
            'num_classes': act_dim,  # Changed from 2 * act_dim (for BC, no std needed)
            'num_observations': obs_dim,
            'weight_norm': True,
            've': False,
            'layernorm': True,
            'hid': 16,
        }
        self.ghn_config = config
        self.ghn = MLP_GHN(**config,
                           debug_level=0, device=self.device).to(self.device)

    def get_params(self, net):
        ''' Get the number of parameters (weights and biases) of an MLP with
            hidden layer widths ``net``, input size obs_dim and output size
            act_dim.
        '''
        widths = [self.obs_dim, *net, self.act_dim]
        # Each layer has (fan_in + 1) * fan_out parameters; +1 is the bias.
        return sum((fan_in + 1) * fan_out for fan_in, fan_out in zip(widths, widths[1:]))

    def sample_arc_indices(self, mode='uniform'):
        ''' Sample the indices of the architectures to be used for the current
            models and store them in ``self.sampled_indices``.

            Only 'uniform' is supported: every architecture is repeated
            meta_batch_size // num_architectures times, the remainder is drawn
            without replacement, and the result is shuffled.

            Raises NotImplementedError for any other mode.
        '''
        if mode != 'uniform':
            raise NotImplementedError(f"Sampling mode '{mode}' not supported. Use 'uniform'")

        full_cycles, remainder = divmod(self.meta_batch_size, len(self.list_of_arc_indices))

        sampled = list(self.list_of_arc_indices) * full_cycles
        if remainder > 0:
            sampled.extend(np.random.choice(self.list_of_arc_indices, remainder, replace=False))

        # Shuffle to avoid always having the same order.
        np.random.shuffle(sampled)
        # Pin the dtype so an empty draw is still an integer index array.
        self.sampled_indices = np.array(
            sampled, dtype=np.asarray(self.list_of_arc_indices).dtype
        )

    def _slice_sampled_shape_inds(self):
        ''' Look up the padded and unpadded shape indicators of sampled_indices. '''
        self.current_shape_inds_vec = [self.list_of_shape_inds[index] for index in self.sampled_indices]
        self.list_of_sampled_shape_inds = [
            padded[:self.list_of_shape_inds_lenths[index]]
            for padded, index in zip(self.current_shape_inds_vec, self.sampled_indices)
        ]

    def _select_current_models(self):
        ''' Set current_model and the zero-padded current_archs for sampled_indices. '''
        self.current_model = [self.all_models[i] for i in self.sampled_indices]
        self.current_archs = torch.tensor([
            list(self.list_of_arcs[index]) + [0] * (self.arch_max_len - len(self.list_of_arcs[index]))
            for index in self.sampled_indices
        ]).to(self.device)

    def set_graph(self, indices_vector, shape_ind_vec):
        ''' Select the architectures given by ``indices_vector`` and call the
            GHN on the corresponding models.

            ``shape_ind_vec`` holds the (-1 padded) shape indicators the caller
            expects for those architectures; it is only used as a consistency
            check against the table built at construction time.

            Always uses the single GHN on ``self.device``, also when
            multi_gpu is enabled.
        '''
        self.sampled_indices = indices_vector
        self._slice_sampled_shape_inds()
        self.sampled_shape_inds = torch.cat(self.list_of_sampled_shape_inds).view(-1, 1)

        flat_expected = shape_ind_vec.view(-1)
        expected = flat_expected[flat_expected != -1].unsqueeze(-1)
        # Kept as an assert so the raised exception type is unchanged.
        assert (self.sampled_shape_inds == expected).all(), 'Shape inds do not match'

        # Pads to arch_max_len (was a hard-coded 5) so the width always
        # matches what change_graph produces.
        self._select_current_models()
        self.ghn(self.current_model, return_embeddings=True, shape_ind=self.sampled_shape_inds)

    def change_graph(self, repeat_sample=False):
        ''' Call the GHN for the current models.
            If repeat_sample is True, the GHN is re-run for the same
            architectures (i.e. the current models do not change).
            If repeat_sample is False, new architectures are sampled first.

            Raises RuntimeError if repeat_sample is True but no architecture
            has been sampled or set yet.
        '''
        if not repeat_sample:
            self.sample_arc_indices(mode=self.architecture_sampling_mode)
            self._slice_sampled_shape_inds()
            self._select_current_models()
        elif getattr(self, "list_of_sampled_shape_inds", None) is None:
            raise RuntimeError(
                "change_graph(repeat_sample=True) called before any architecture "
                "was sampled; call change_graph() or set_graph() first"
            )

        if self.multi_gpu:
            # One GHN replica per GPU, each handling its contiguous slice of
            # the sampled models.
            self.multi_ghns = replicate(self.ghn, self.all_devices)
            per_device = self.num_current_models_per_device
            for i, device in enumerate(self.all_devices):
                chunk = slice(i * per_device, (i + 1) * per_device)
                sampled_shape_inds = torch.cat(self.list_of_sampled_shape_inds[chunk]).view(-1, 1)
                self.multi_ghns[i](self.current_model[chunk], return_embeddings=True,
                                   shape_ind=sampled_shape_inds.to(device))
        else:
            self.sampled_shape_inds = torch.cat(self.list_of_sampled_shape_inds).view(-1, 1)
            self.ghn(self.current_model, return_embeddings=True, shape_ind=self.sampled_shape_inds)

    def forward(self, state, track=True):
        ''' Do a forward pass through the current models.  The state batch is
            split into consecutive chunks of batch_per_net =
            len(state) // num_models rows, one chunk per current model; rows
            beyond num_models * batch_per_net are not used.

            track: if True, store per-row copies of the padded shape
                indicators, padded architectures and architecture indices
                (shape_ind_per_state_dim, arch_per_state_dim,
                sampled_indices_per_state_dim), e.g. for architecture
                conditioned value functions.

            Returns the concatenated outputs of all models.
        '''
        num_models = len(self.current_model)
        batch_per_net = state.shape[0] // num_models

        if track:
            self.shape_ind_per_state_dim = torch.cat(
                [shape_ind.repeat(batch_per_net, 1) for shape_ind in self.current_shape_inds_vec])
            self.arch_per_state_dim = torch.cat(
                [arch.repeat(batch_per_net, 1) for arch in self.current_archs])
            self.sampled_indices_per_state_dim = torch.cat(
                [torch.tensor([index]).repeat(batch_per_net) for index in self.sampled_indices])

        chunks = [state[i * batch_per_net:(i + 1) * batch_per_net] for i in range(num_models)]

        if self.multi_gpu:
            inputs = [chunk.to(self.device_model_list[i]) for i, chunk in enumerate(chunks)]
            actions = gather(parallel_apply(self.current_model, inputs), self.device)
        else:
            # Single device: call the networks directly.  parallel_apply only
            # pays off across several GPUs and resolves CUDA devices/streams,
            # which is unnecessary here and may fail on CPU-only installs.
            actions = torch.cat([model(chunk) for model, chunk in zip(self.current_model, chunks)])

        return actions

