import torch.nn as nn
import numpy as np
import torch
from torch.autograd import Variable


class CNN(nn.Module):
    """Small convolutional network: one 3x3 conv layer, one hidden layer, one output layer.

    ``config`` must provide:

    * ``'input_size'`` -- flattened length of a square 3-channel image, i.e.
      ``3 * side * side``; ``side`` is recovered with a square root.
    * ``'output_size'`` -- number of output units.
    * ``'hidden_size'`` -- width of the fully connected hidden layer.

    Numpy inputs are expected channels-last, either a single image
    ``(H, W, C)`` or a batch ``(N, H, W, C)``; ``to_variable`` converts them
    to the ``(N, C, H, W)`` layout consumed by ``nn.Conv2d``.
    """

    def __init__(self, config):
        super().__init__()

        self.in_channels = 3
        self.input_size = int(np.sqrt(config['input_size'] / self.in_channels))
        self.output_size = config['output_size']
        self.hidden_size = config['hidden_size']

        conv_channels, kernel_size, stride = 8, 3, 1
        self.image_conv = nn.Conv2d(self.in_channels, conv_channels,
                                    kernel_size=kernel_size, stride=stride)

        def size_linear_unit(size, kernel_size=kernel_size, stride=stride):
            # Spatial side length after an unpadded convolution.
            return (size - (kernel_size - 1) - 1) // stride + 1

        conv_side = size_linear_unit(self.input_size)
        if conv_side <= 0:
            # Without this check a too-small image silently yields a Linear
            # layer with zero (or nonsensical) input width.
            raise ValueError(
                f"image side {self.input_size} is smaller than the "
                f"{kernel_size}x{kernel_size} conv kernel "
                f"(config['input_size']={config['input_size']!r})")
        num_linear_units = conv_side * conv_side * conv_channels
        self.fc_hidden = nn.Linear(in_features=num_linear_units, out_features=self.hidden_size)
        self.output = nn.Linear(in_features=self.hidden_size, out_features=self.output_size)

    def _hidden_pre_activation(self, input):
        """Run conv -> ReLU -> flatten -> ``fc_hidden`` and return the raw hidden output."""
        x = self.to_variable(input)
        x = nn.functional.relu(self.image_conv(x))
        return self.fc_hidden(x.view(x.size(0), -1))

    @staticmethod
    def _to_feature_vector(x):
        """Flatten a tensor to a numpy vector and append a constant 1 (bias feature)."""
        flat = x.cpu().detach().numpy().flatten()
        return np.concatenate((flat, np.array([1])))

    def forward(self, input):
        """Return the output-layer values for ``input`` (tanh-activated hidden layer)."""
        hidden = torch.tanh(self._hidden_pre_activation(input))
        return self.output(hidden)

    def get_features_binary(self, input):
        """Return sigmoid(ReLU(hidden)) as a flat numpy vector with a trailing 1.

        NOTE(review): sigmoid of a ReLU output is always in [0.5, 1), and the
        hidden activation here (ReLU) differs from the tanh used in
        ``forward`` -- confirm both are intended.
        All rows of a batched input are flattened into one vector; presumably
        this is meant for a single observation.
        """
        hidden = nn.functional.relu(self._hidden_pre_activation(input))
        return self._to_feature_vector(torch.sigmoid(hidden))

    def get_features(self, input):
        """Return ReLU(hidden) as a flat numpy vector with a trailing 1.

        NOTE(review): uses ReLU whereas ``forward`` uses tanh on the same
        layer -- confirm intended. All rows of a batched input are flattened
        into one vector; presumably this is meant for a single observation.
        """
        hidden = nn.functional.relu(self._hidden_pre_activation(input))
        return self._to_feature_vector(hidden)

    def to_variable(self, input, dtype=torch.FloatTensor):
        """Convert a channels-last numpy image (or batch) to an NCHW tensor.

        Tensors are returned unchanged. A rank-3 array ``(H, W, C)`` is
        treated as a single image and gets a leading batch axis. The result
        is contiguous (so negative-stride views such as ``img[::-1]`` are
        accepted) and lives on the same device as the model's parameters.

        Args:
            input: tensor, numpy array, or array-like.
            dtype: ``torch.FloatTensor`` or ``torch.LongTensor``.

        Raises:
            TypeError: if ``dtype`` is not one of the supported types.
        """
        if isinstance(input, torch.Tensor):
            return input
        array = np.asarray(input)
        if array.ndim == 3:  # unbatched (H, W, C) image
            array = np.expand_dims(array, 0)
        # (N, H, W, C) -> (N, C, W, H) -> (N, C, H, W)
        array = array.swapaxes(1, 3).swapaxes(2, 3)
        # torch.from_numpy rejects negative strides; copy only when needed.
        array = np.ascontiguousarray(self.adjust_type(array, dtype))
        return torch.from_numpy(array).to(self.image_conv.weight.device)

    def adjust_type(self, input, torch_type):
        """Cast ``input`` to the numpy dtype matching ``torch_type``.

        Raises:
            TypeError: if ``torch_type`` is neither ``torch.FloatTensor`` nor
                ``torch.LongTensor``.
        """
        if torch_type == torch.FloatTensor:
            return np.asarray(input, dtype=np.float32)
        if torch_type == torch.LongTensor:
            return np.asarray(input, dtype=np.int64)
        raise TypeError(
            f"unsupported torch_type {torch_type!r}; "
            "expected torch.FloatTensor or torch.LongTensor")

    def reset_output_weights(self, i_replace):
        """Re-initialise selected rows (output units) of the output layer.

        Selected weight rows are redrawn from N(0, sqrt(2 / (hidden_size +
        output_size))) -- the Glorot/Xavier-normal standard deviation -- and
        the matching biases are set to zero.

        Args:
            i_replace: 1-D list, numpy array or tensor of output-unit indices.
        """
        std = (2.0 / (self.hidden_size + self.output_size)) ** 0.5
        weight = self.output.weight
        with torch.no_grad():
            fresh = torch.normal(0.0, std, size=(len(i_replace), self.hidden_size))
            # Match the parameter's dtype/device so this also works on GPU.
            weight[i_replace, :] = fresh.to(weight)
            self.output.bias[i_replace] = 0

    def _shared_parameters(self):
        """Conv and hidden-layer parameters (everything but the output layer), in a fixed order."""
        return (self.image_conv.weight, self.image_conv.bias,
                self.fc_hidden.weight, self.fc_hidden.bias)

    def get_shared_gradient(self):
        """Concatenate the gradients of the conv and hidden-layer parameters into one flat tensor.

        Requires a backward pass to have populated ``.grad``; otherwise the
        gradient is ``None`` and an ``AttributeError`` is raised.
        """
        return torch.cat([param.grad.flatten() for param in self._shared_parameters()])

    def get_shared_weight_size(self):
        """Total number of scalars in the conv and hidden-layer parameters."""
        return sum(param.numel() for param in self._shared_parameters())

    def reset_aux_weights(self, i_replace):
        """No-op: this network defines no auxiliary output weights.

        Presumably kept so callers can treat this model like sibling models
        that do have auxiliary heads -- verify against callers.
        """
        return


