import torch.nn as nn
import numpy as np
import torch
from torch.autograd import Variable


class Network(nn.Module):
    """Fully connected network with a configurable number of hidden layers.

    The network consists of ``layer_number`` hidden ``nn.Linear`` layers
    (registered as ``fc_hidden_1`` ... ``fc_hidden_<layer_number>``) followed
    by a linear output layer ``fc_output``. Every hidden layer except the last
    uses ReLU; the last hidden layer uses the configurable "head" activation.
    Inputs are rescaled to ``[-1, 1]`` using ``mins``/``maxes`` before being
    fed to the first layer.

    Expected ``config`` keys:
        layer_number: number of hidden layers (integer >= 1).
        input_size: number of input features.
        output_size: number of output units of ``fc_output``.
        num_actions: row-block size used by ``reset_aux_weights``.
        hidden_size: width of every hidden layer.
        mins, maxes: per-feature bounds used by ``normalize_input``; they are
            sliced with ``[:input_size]`` (presumably numpy arrays -- confirm
            against the caller).
        head_activation: ``'tanh'`` or ``'relu'``.

    Raises:
        ValueError: if ``layer_number`` is not an integer >= 1 or
            ``head_activation`` is not a supported name.
    """

    # Supported activations for the last hidden layer, keyed by config name.
    _HEAD_ACTIVATIONS = {'tanh': torch.tanh, 'relu': torch.relu}

    def __init__(self, config):
        super(Network, self).__init__()

        layer_number = config['layer_number']
        if layer_number != int(layer_number) or layer_number < 1:
            raise ValueError(
                'layer_number must be an integer >= 1, got {!r}'.format(layer_number))
        self.layer_number = int(layer_number)
        self.input_size = config['input_size']
        self.output_size = config['output_size']
        self.num_actions = config['num_actions']
        self.hidden_size = config['hidden_size']
        self.mins = config['mins']
        self.maxes = config['maxes']

        # Layers are created in the order fc_hidden_1, fc_hidden_2, ...,
        # fc_output so that attribute names (and therefore state_dict keys),
        # children() order and RNG consumption stay stable.
        self.fc_hidden_1 = nn.Linear(self.input_size, self.hidden_size)
        for index in range(2, self.layer_number + 1):
            setattr(self, 'fc_hidden_{}'.format(index),
                    nn.Linear(self.hidden_size, self.hidden_size))
        self.fc_output = nn.Linear(self.hidden_size, self.output_size)

        head_name = config['head_activation']
        if head_name not in self._HEAD_ACTIVATIONS:
            # Fail at construction instead of with an AttributeError on the
            # first forward pass.
            raise ValueError(
                'unsupported head_activation {!r}; expected one of {}'.format(
                    head_name, sorted(self._HEAD_ACTIVATIONS)))
        self.head_activation = self._HEAD_ACTIVATIONS[head_name]

        self.xavier_init()

    def forward(self, input):
        """Normalize ``input``, run it through the hidden stack and the output layer.

        Args:
            input: raw (un-normalized) features; anything accepted by
                ``normalize_input`` and ``to_variable``.

        Returns:
            The output of ``fc_output`` (no activation applied to it).
        """
        output = self.to_variable(self.normalize_input(input))
        hidden_layers = [getattr(self, 'fc_hidden_{}'.format(index))
                         for index in range(1, self.layer_number + 1)]
        # ReLU on all hidden layers but the last one ...
        for layer in hidden_layers[:-1]:
            output = nn.functional.relu(layer(output))
        # ... which uses the configured head activation instead.
        output = self.head_activation(hidden_layers[-1](output))
        return self.fc_output(output)

    def xavier_init(self):
        """Xavier-uniform initialize all weights and zero all biases."""
        for layer in self.children():
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def to_variable(self, input, dtype=torch.FloatTensor):
        """Convert array-like ``input`` to a tensor of type ``dtype``.

        Tensors are returned unchanged (no dtype conversion is applied to
        them); anything else goes through ``adjust_type`` and
        ``torch.from_numpy``.
        """
        # torch.autograd.Variable is a deprecated alias of torch.Tensor, so a
        # plain tensor check/return is equivalent to the old Variable wrapping.
        if isinstance(input, torch.Tensor):
            return input
        return dtype(torch.from_numpy(self.adjust_type(input, dtype)))

    def tensor(self, input, dtype=torch.FloatTensor):
        """Convert array-like ``input`` to a new tensor of type ``dtype``."""
        return dtype(torch.from_numpy(self.adjust_type(input, dtype)))

    def adjust_type(self, input, torch_type):
        """Return ``input`` as a numpy array whose dtype matches ``torch_type``.

        Args:
            input: array-like data.
            torch_type: ``torch.FloatTensor`` or ``torch.LongTensor``.

        Raises:
            ValueError: if ``torch_type`` is neither of the supported types.
        """
        if torch_type == torch.FloatTensor:
            return np.asarray(input, dtype=np.float32)
        if torch_type == torch.LongTensor:
            return np.asarray(input, dtype=np.int64)
        raise ValueError('unsupported tensor type: {!r}'.format(torch_type))

    def normalize_input(self, input):
        """Linearly rescale ``input`` from ``[mins, maxes]`` to ``[-1, 1]``.

        Only the first ``input_size`` entries of ``mins``/``maxes`` are used.
        NOTE(review): a feature with ``maxes == mins`` divides by zero here.
        """
        lower = self.mins[:self.input_size]
        upper = self.maxes[:self.input_size]
        return 2 * (input - lower) / (upper - lower) - 1

    def reset_output_weights(self, i_replace):
        """Re-initialize the rows ``i_replace`` of the output layer.

        The selected weight rows are redrawn from a zero-mean normal with the
        Xavier standard deviation ``sqrt(2 / (hidden_size + output_size))``
        and the matching biases are set to zero.

        Args:
            i_replace: index array/tensor of output rows; must expose
                ``.shape[0]``.
        """
        with torch.no_grad():
            # Plain Python float so the (mean, std, size) overload of
            # torch.normal is selected explicitly.
            std = (2.0 / (self.hidden_size + self.output_size)) ** 0.5
            self.fc_output.weight[i_replace, :] = torch.normal(
                0, std, size=(i_replace.shape[0], self.hidden_size))
            self.fc_output.bias[i_replace] = 0

    def reset_aux_weights(self, i_replace):
        """Xavier re-initialize blocks of ``num_actions`` output weight rows.

        For each ``i`` in ``i_replace`` the rows
        ``num_actions * (i + 1) : num_actions * (i + 2)`` of ``fc_output`` are
        re-initialized in place (the first ``num_actions`` rows are therefore
        never touched). Biases are left unchanged.
        """
        weight = self.fc_output.weight.detach()
        for i in i_replace:
            start = self.num_actions * (i + 1)
            stop = self.num_actions * (i + 2)
            # The slice is a view, so the in-place init writes into fc_output.
            nn.init.xavier_uniform_(weight[start:stop, :])

    def get_shared_gradient(self):
        """Return the flattened gradient of ``fc_hidden_1`` (weights, then bias).

        Requires that a backward pass has populated ``.grad``; otherwise the
        gradients are ``None`` and this raises ``AttributeError``.
        """
        first_layer = self.fc_hidden_1
        return torch.cat((first_layer.weight.grad.flatten(),
                          first_layer.bias.grad.flatten()))

    def get_shared_weight_size(self):
        """Return the number of parameters (weights + biases) in ``fc_hidden_1``."""
        return self.input_size * self.hidden_size + self.hidden_size

    def get_stable_rank(self):
        """Return the stable rank of the ``fc_hidden_1`` weight matrix.

        Computed as squared Frobenius norm divided by squared spectral norm.
        """
        mat = self.fc_hidden_1.weight.detach().cpu().numpy()
        return np.linalg.norm(mat) ** 2 / np.linalg.norm(mat, 2) ** 2