import torch.nn as nn
import numpy as np
import torch
from torch.autograd import Variable


class MasterSlaveCNN(nn.Module):
    def __init__(self, config):
        super(MasterSlaveCNN, self).__init__()
        self.feature_main_task_direct = True

        self.in_channels = 3
        self.input_size = int(np.sqrt(config['input_size']/3))
        self.num_actions = config['num_actions']
        self.num_aux_tasks = config['num_aux_tasks']
        self.hidden_size = config['hidden_size']

        n = self.input_size
        def size_linear_unit(size, kernel_size=3, stride=1):
            return (size - (kernel_size - 1) - 1) // stride + 1
        num_linear_units = size_linear_unit(n) * size_linear_unit(n) * 8

        self.feature_per_aux = int(self.hidden_size / (self.num_aux_tasks+1))

        self.input_conv = []
        self.conv_hidden = []
        self.hidden_output = []

        for i in np.arange(self.num_aux_tasks+1):
            self.input_conv.append(nn.Conv2d(3, 8, kernel_size=3, stride=1))
            self.conv_hidden.append(nn.Linear(in_features=num_linear_units, out_features=self.feature_per_aux))
            self.hidden_output.append(nn.Linear(self.hidden_size, self.num_actions))

        self.input_conv = nn.ModuleList(self.input_conv)
        self.conv_hidden = nn.ModuleList(self.conv_hidden)
        self.hidden_output = nn.ModuleList(self.hidden_output)

    def forward(self, features, i):
        output = self.hidden_output[i](features)
        return output


    def get_features(self, input):
        x = self.to_variable(input)
        hidden_list = []
        for i in np.arange(self.num_aux_tasks+1):
            hidden_list.append(torch.tanh(self.conv_hidden[i](torch.relu(self.input_conv[i](x)).view(x.size(0), -1))))
        return torch.cat(hidden_list, dim=1)


    def to_variable(self, input, dtype=torch.FloatTensor):
        if isinstance(input, Variable):
            return input
        output = input
        if len(output.shape) == self.in_channels:
            output = np.expand_dims(output, 0)
        output = output.swapaxes(1, 3).swapaxes(2, 3)
        output = self.adjust_type(output, dtype)
        output = dtype(torch.from_numpy(output))
        return Variable(output)

    def adjust_type(self, input, torch_type):
        if torch_type == torch.FloatTensor:
            return np.asarray(input, dtype=np.float32)
        if torch_type == torch.LongTensor:
            return np.asarray(input, dtype=np.int64)
        print('error!')

    def reset_aux_weights(self, i_replace):
        for i in i_replace:
            nn.init.xavier_uniform_(self.input_conv[i].weight.data)
            nn.init.xavier_uniform_(self.conv_hidden[i].weight.data)
            nn.init.xavier_uniform_(self.hidden_output.weight.data)
            for o in np.arange(self.num_aux_tasks+1):
                nn.init.xavier_uniform_(self.hidden_output[o].weight.data[:, (i+1) * self.feature_per_aux: (i+2) * self.feature_per_aux])


    def zero_grad_preserved_features(self, preserved_features):
        return
