import torch.nn as nn
import numpy as np
import torch
from torch.autograd import Variable


class MasterSlaveNetwork(nn.Module):
    """Multi-head network with a shared input layer and one output head per task.

    The network has ``num_aux_tasks + 1`` heads. ``reset_aux_weights`` maps
    auxiliary task ``i`` to head ``i + 1``, so head 0 is presumably the main
    ("master") task -- confirm against the training code.

    With ``layer_number == 1`` the features are the activated output of the
    shared input layer. With ``layer_number == 2`` the ReLU output of the
    input layer is cut into ``num_aux_tasks + 1`` contiguous blocks of
    ``feature_per_aux`` units, and each block goes through its own
    ``feature_per_aux x feature_per_aux`` linear layer before the blocks are
    concatenated again.

    Expected ``config`` keys:
        layer_number: 1 or 2 (see above).
        input_size: number of input features.
        num_actions: output width of every head.
        num_aux_tasks: number of auxiliary heads.
        hidden_size: width of the shared input layer.
        mins, maxes: per-feature bounds used by ``normalize_input``; they are
            sliced and used in arithmetic, so presumably numpy arrays.
        head_activation: ``'tanh'`` or ``'relu'``.

    NOTE(review): every output head takes ``hidden_size`` inputs, but the
    2-layer features have width ``feature_per_aux * (num_aux_tasks + 1)``;
    these only agree when ``hidden_size`` is divisible by the head count.
    """

    def __init__(self, config):
        """Build all layers from ``config`` and initialise them.

        Raises:
            KeyError: if a required key is missing from ``config``.
            ValueError: if ``config['head_activation']`` is not supported.
        """
        super(MasterSlaveNetwork, self).__init__()

        self.layer_number = config['layer_number']
        self.input_size = config['input_size']
        self.num_actions = config['num_actions']
        self.num_aux_tasks = config['num_aux_tasks']
        self.hidden_size = config['hidden_size']
        self.mins = config['mins']
        self.maxes = config['maxes']

        num_heads = int(self.num_aux_tasks) + 1
        self.feature_per_aux = int(self.hidden_size // num_heads)

        self.input_hidden = nn.Linear(self.input_size, self.hidden_size)

        # Layers are created in an interleaved order (block, head, block,
        # head, ...) so the number and order of random draws stays stable
        # for seeded runs.
        two_layers = self.layer_number == 2
        hidden_blocks = []
        output_heads = []
        for _ in range(num_heads):
            if two_layers:
                hidden_blocks.append(
                    nn.Linear(self.feature_per_aux, self.feature_per_aux))
            output_heads.append(nn.Linear(self.hidden_size, self.num_actions))
        if two_layers:
            self.hidden_hidden = nn.ModuleList(hidden_blocks)
        self.hidden_output = nn.ModuleList(output_heads)

        self.xavier_init()

        # Fail at construction time instead of leaving ``head_activation``
        # undefined and crashing later inside ``get_features``.
        activations = {'tanh': torch.tanh, 'relu': torch.relu}
        activation_name = config['head_activation']
        try:
            self.head_activation = activations[activation_name]
        except KeyError:
            raise ValueError(
                "unsupported head_activation {!r}; expected one of {}".format(
                    activation_name, sorted(activations)))

    def forward(self, features, i):
        """Apply output head ``i`` to ``features`` (see ``get_features``)."""
        return self.hidden_output[i](features)

    def get_features(self, input):
        """Normalise ``input`` and compute the feature vector fed to the heads.

        Args:
            input: raw observation(s); the last dimension must hold the
                ``input_size`` features.

        Returns:
            Tensor whose last dimension holds the features.

        Raises:
            ValueError: if ``layer_number`` is neither 1 nor 2.
        """
        normalized = self.to_variable(self.normalize_input(input))

        if self.layer_number == 1:
            return self.head_activation(self.input_hidden(normalized))

        if self.layer_number == 2:
            shared = torch.relu(self.input_hidden(normalized))
            width = self.feature_per_aux
            # Slice and concatenate along the last (feature) dimension so
            # the same code handles single samples and batches of any rank.
            blocks = [
                self.head_activation(
                    layer(shared[..., k * width:(k + 1) * width]))
                for k, layer in enumerate(self.hidden_hidden)
            ]
            return torch.cat(blocks, dim=-1)

        raise ValueError(
            "unsupported layer_number {!r}; expected 1 or 2".format(
                self.layer_number))

    def xavier_init(self):
        """Xavier-uniform initialise every weight and zero every bias."""
        # Order (input layer, output heads, hidden blocks) is kept fixed so
        # seeded initialisation is reproducible.
        layers = [self.input_hidden] + list(self.hidden_output)
        if self.layer_number == 2:
            layers += list(self.hidden_hidden)
        for layer in layers:
            nn.init.xavier_uniform_(layer.weight.data)
            nn.init.zeros_(layer.bias.data)

    def to_variable(self, input, dtype=torch.FloatTensor):
        """Return ``input`` as a ``Variable``; existing Variables pass through.

        Args:
            input: a ``Variable`` or anything ``np.asarray`` accepts.
            dtype: ``torch.FloatTensor`` or ``torch.LongTensor``.

        Raises:
            TypeError: if ``dtype`` is not supported (see ``adjust_type``).
        """
        if isinstance(input, Variable):
            return input
        return Variable(self.tensor(input, dtype))

    def tensor(self, input, dtype=torch.FloatTensor):
        """Convert array-like ``input`` to a tensor of type ``dtype``.

        Raises:
            TypeError: if ``dtype`` is not supported (see ``adjust_type``).
        """
        return dtype(torch.from_numpy(self.adjust_type(input, dtype)))

    def adjust_type(self, input, torch_type):
        """Return ``input`` as a numpy array whose dtype matches ``torch_type``.

        Raises:
            TypeError: if ``torch_type`` is neither ``torch.FloatTensor`` nor
                ``torch.LongTensor``.
        """
        if torch_type == torch.FloatTensor:
            return np.asarray(input, dtype=np.float32)
        if torch_type == torch.LongTensor:
            return np.asarray(input, dtype=np.int64)
        # Raise here rather than returning None, which previously only
        # failed later inside ``torch.from_numpy``.
        raise TypeError(
            "unsupported torch type {!r}; expected torch.FloatTensor or "
            "torch.LongTensor".format(torch_type))

    def normalize_input(self, input):
        """Linearly rescale ``input`` so that ``mins`` maps to -1 and ``maxes`` to 1.

        Only the first ``input_size`` entries of ``mins``/``maxes`` are used.
        A feature with ``maxes == mins`` results in a division by zero.
        """
        lower = self.mins[:self.input_size]
        upper = self.maxes[:self.input_size]
        return 2 * (input - lower) / (upper - lower) - 1

    def reset_aux_weights(self, i_replace):
        """Re-initialise the weights belonging to the given auxiliary tasks.

        For each auxiliary index ``i`` in ``i_replace`` (which maps to head
        and feature block ``i + 1``) the following are redrawn with Xavier
        uniform initialisation: the rows of the input layer producing that
        feature block, the whole weight matrix of that head, that head's
        hidden block (2-layer networks only), and the columns of every
        output head that read the feature block. Biases are left untouched.

        Args:
            i_replace: iterable of auxiliary-task indices.
        """
        width = self.feature_per_aux
        for i in i_replace:
            start, stop = (i + 1) * width, (i + 2) * width
            # Slices of ``.data`` are views, so the in-place init writes
            # straight into the parameters.
            nn.init.xavier_uniform_(
                self.input_hidden.weight.data[start:stop, :])
            nn.init.xavier_uniform_(self.hidden_output[i + 1].weight.data)
            if self.layer_number == 2:
                nn.init.xavier_uniform_(self.hidden_hidden[i + 1].weight.data)
            for head in self.hidden_output:
                nn.init.xavier_uniform_(head.weight.data[:, start:stop])

    def get_stable_rank(self):
        """Return the stable rank of the input-layer weight matrix.

        Computed as the squared Frobenius norm divided by the squared
        spectral (2-) norm.
        """
        weights = self.input_hidden.weight.data.cpu().detach().numpy()
        frobenius_sq = np.linalg.norm(weights) ** 2
        spectral_sq = np.linalg.norm(weights, 2) ** 2
        return frobenius_sq / spectral_sq

    # def zero_grad_preserved_features(self, preserved_features):
    #     if preserved_features is None:
    #         return
    #     self.input_hidden.weight.grad[preserved_features, :] *= 0