import torch
import functools
import inspect    
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

from utils.snippets import get_same_padding_size, conv_output_shape

def get_state(s, device=None):
    """Convert one observation into a batched, channels-first float tensor.

    ``s`` is assumed to be a 3-D channels-last array ``(H, W, C)``; it is
    permuted to ``(C, H, W)``, given a leading batch axis of size 1 and cast
    to float, giving shape ``(1, C, H, W)``.

    Args:
        s: array-like accepted by ``torch.tensor`` with exactly 3 dimensions.
        device: target device. When omitted, a module-level ``device``
            variable is used if one has been set on this module; otherwise
            torch's default device is used.

    Returns:
        torch.FloatTensor of shape ``(1, C, H, W)``.
    """
    if device is None:
        # The original implementation relied on a global ``device`` that this
        # file never defines (NameError on every call). Honour it if some
        # caller injected it, but no longer require it.
        device = globals().get("device")
    return torch.tensor(s, device=device).permute(2, 0, 1).unsqueeze(0).float()


class Flatten(torch.nn.Module):
    """Module that keeps the leading (batch) dimension and merges the rest.

    Uses ``Tensor.view``, so the input must be viewable with that shape
    (e.g. contiguous).
    """

    def forward(self, x):
        return x.view(x.shape[0], -1)

class PreprocessAtari(nn.Module):
    """Reorder a 4-D batch from ``(N, H, W, C)`` to ``(N, C, H, W)`` and scale by 1/255."""

    def forward(self, x):
        # permute only changes strides; make it contiguous before dividing
        channels_first = x.permute(0, 3, 1, 2)
        return channels_first.contiguous() / 255.

def create_atari_q_network(ob_dim, num_actions):
    """Build a convolutional Q-network as an ``nn.Sequential``.

    The first module (``PreprocessAtari``) turns a channels-last batch
    ``(N, H, W, C)`` into channels-first and divides by 255; the first conv
    layer then expects ``C == 4``.

    Args:
        ob_dim: accepted but not used by this function.
        num_actions: number of outputs of the final linear layer.

    Returns:
        nn.Sequential containing, in order: PreprocessAtari, three
        Conv2d + ReLU pairs, Flatten, Linear(3136, 512), ReLU,
        Linear(512, num_actions).
    """
    # (in_channels, out_channels, kernel_size, stride) of each conv layer.
    conv_specs = (
        (4, 32, 8, 4),
        (32, 64, 4, 2),
        (64, 64, 3, 1),
    )
    modules = [PreprocessAtari()]
    for c_in, c_out, kernel, step in conv_specs:
        modules.append(
            nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=kernel, stride=step)
        )
        modules.append(nn.ReLU())
    # 3136 = 64 * 7 * 7, the flattened conv output for an 84x84 input
    # (84 -> 20 -> 9 -> 7); other image sizes would need a different value.
    modules.extend([
        Flatten(),
        nn.Linear(3136, 512),
        nn.ReLU(),
        nn.Linear(512, num_actions),
    ])
    return nn.Sequential(*modules)

class AtariModel(nn.Module):
    """Convolutional Q-network: three conv layers followed by two linear layers.

    NOTE: this architecture can be further enhanced for different problems.
    It has previously been used for the Pong game and worked properly.

    Input goes straight into ``conv1`` (no permute or rescaling), so it must
    already be channels-first ``(N, 4, H, W)``. Each conv's padding comes from
    ``get_same_padding_size`` (intended as "same" padding), and
    ``conv_output_shape`` tracks the spatial size so that ``fc1`` can be
    sized from ``img_input``.

    Args:
        img_input: image size passed to ``conv_output_shape``; its exact form
            is defined by that helper (presumably ``(H, W)`` -- confirm).
        num_actions: number of Q-value outputs.
    """

    def __init__(self, img_input, num_actions):
        super(AtariModel, self).__init__()

        # Using same padding; each layer is initialised right after creation
        # so the RNG consumption order is unchanged.
        padding_size = get_same_padding_size(kernel_size=8, stride=4)
        self.conv1 = self._xavier_init(nn.Conv2d(4, 32, 8, stride=4, padding=padding_size))

        out_size = conv_output_shape(img_input, 8, 4, padding_size)
        padding_size = get_same_padding_size(kernel_size=4, stride=2)
        self.conv2 = self._xavier_init(nn.Conv2d(32, 64, 4, stride=2, padding=padding_size))

        out_size = conv_output_shape(out_size, 4, 2, padding_size)
        padding_size = get_same_padding_size(kernel_size=3, stride=1)
        self.conv3 = self._xavier_init(nn.Conv2d(64, 64, 3, stride=1, padding=padding_size))

        # Features per sample after conv3 (spatial area * 64 channels), as a
        # plain Python int rather than a numpy scalar.
        self.conv_out_size = int(np.prod(conv_output_shape(out_size, 3, 1, padding_size)) * 64)

        self.fc1 = self._xavier_init(nn.Linear(self.conv_out_size, 512))
        self.fc2 = self._xavier_init(nn.Linear(512, num_actions))

    @staticmethod
    def _xavier_init(layer):
        """Xavier-uniform ``layer.weight``, zero ``layer.bias``, and return ``layer``."""
        nn.init.xavier_uniform_(layer.weight)
        nn.init.zeros_(layer.bias)
        return layer

    def forward(self, x):
        """Return Q-values of shape ``(N, num_actions)`` for a ``(N, 4, H, W)`` batch."""
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        # Flatten per sample. ``view(-1, conv_out_size)`` would silently
        # regroup elements across samples (changing the batch size) whenever
        # the actual conv output differs from the size computed in __init__;
        # keeping the batch dim makes fc1 raise instead, and also works for
        # an empty batch.
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

class MLP(nn.Module):
    """Fully connected network with ``n_layers`` hidden layers of width ``size``.

    ``activation`` is applied after every hidden linear layer. ``output_activation``,
    if not None, is applied after the final linear layer. All weights are
    Xavier-uniform initialised and all biases are zeroed.
    """

    def __init__(self, input_size, output_size, n_layers, 
            size, activation=torch.tanh, output_activation=None):
        super(MLP, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.activation = activation
        self.size = size
        self.n_layers = n_layers
        self.output_activation = output_activation

        widths = [input_size, *([size] * n_layers), output_size]
        # Build every layer first, then initialise (same RNG order as before).
        self.layers = nn.ModuleList(
            nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(widths[:-1], widths[1:])
        )
        for linear in self.layers:
            nn.init.xavier_uniform_(linear.weight)
            nn.init.zeros_(linear.bias)

    def forward(self, x):
        *hidden_layers, final_layer = self.layers
        for hidden in hidden_layers:
            x = self.activation(hidden(x))
        x = final_layer(x)
        return x if self.output_activation is None else self.output_activation(x)

################################################################################################################
# class MinAtarQNetwork
#
# One hidden 2D conv layer with a variable number of input channels.  We use 16 filters, a quarter of the 64
# used in the original DQN paper.  One hidden fully connected linear layer with 128 rectified units, a quarter
# of the original DQN paper's 512.  Finally, the output layer is a fully connected linear layer with a single
# output for each valid action.
#
################################################################################################################
class MinAtarQNetwork(nn.Module):
    """Small Q-network: one 3x3 conv (16 filters), one 128-unit hidden layer, linear output.

    Args:
        input_shape: spatial size of an observation; only ``input_shape[0]``
            and ``input_shape[1]`` are used, to size the hidden layer.
        num_actions: number of Q-value outputs.
        in_channels: number of input channels seen by the conv layer.
    """

    def __init__(self, input_shape, num_actions, in_channels):

        super(MinAtarQNetwork, self).__init__()

        # Channels-last -> channels-first and scaling by 1/255.
        # NOTE(review): MinAtar observations may be 0/1 planes; the /255 here
        # shrinks them a lot -- confirm this is intended.
        self.preprocess = PreprocessAtari()

        # 3x3 conv, stride 1, no padding, 16 output filters.
        self.conv = nn.Conv2d(in_channels, 16, kernel_size=3, stride=1)

        # Flattened conv output: (conv height) * (conv width) * 16 filters.
        conv_height = self._conv_out_len(input_shape[0])
        conv_width = self._conv_out_len(input_shape[1])
        self.fc_hidden = nn.Linear(in_features=conv_height * conv_width * 16, out_features=128)

        # One output per action.
        self.output = nn.Linear(in_features=128, out_features=num_actions)

    @staticmethod
    def _conv_out_len(size, kernel_size=3, stride=1):
        """Output length along one axis of an unpadded convolution."""
        return (size - kernel_size) // stride + 1

    def forward(self, x):
        """Map a batch of observations to Q-values of shape ``(N, num_actions)``."""
        features = F.relu(self.conv(self.preprocess(x)))
        flat = features.view(features.size(0), -1)
        hidden = F.relu(self.fc_hidden(flat))
        return self.output(hidden)

class LFANet(nn.Module):
    """Stack of ``nn.Linear`` layers applied with no nonlinearity in between.

    Layer widths are ``input_size -> size (x n_layers) -> output_size``.
    Weights are Xavier-uniform initialised and biases zeroed. The
    ``activation`` and ``output_activation`` arguments are accepted but
    ignored (never stored or applied).
    """

    def __init__(self, input_size, output_size, n_layers, 
            size, activation=None, output_activation=None):
        super(LFANet, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.size = size
        self.n_layers = n_layers

        dims = [input_size] + [size] * n_layers + [output_size]
        # All layers are created before any is re-initialised (same RNG order).
        self.layers = nn.ModuleList(
            [nn.Linear(d_in, d_out) for d_in, d_out in zip(dims, dims[1:])]
        )
        for linear in self.layers:
            nn.init.xavier_uniform_(linear.weight)
            nn.init.zeros_(linear.bias)

    def forward(self, x):
        # Feed x through every layer in order.
        return functools.reduce(lambda acc, layer: layer(acc), self.layers, x)
