import gym
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.distributions import Categorical
from torch.utils.data import Dataset

"""
    Softmax tabular with phi(s,a) = v_a^T s, which is linear phi.
    Input: a numpy list of states s, which will be transformed into Variable
    Output: a probability vector pi(.|s)
"""
class PolicyNet(nn.Module):
    """Softmax policy over a single linear layer.

    The action preferences are the outputs of ``nn.Linear(state_dim,
    n_action)`` and the policy is their softmax.

    Args:
        state_dim: Length of one state vector.
        n_action: Number of discrete actions.
        eps: Weight of the uniform distribution mixed into the policy by
            ``get_pi_vec``. ``get_action`` does NOT read this attribute; it
            uses its own ``eps`` argument.
    """

    def __init__(self, state_dim, n_action, eps=0):
        super(PolicyNet, self).__init__()
        self.n_state = state_dim
        self.n_action = n_action
        self.eps = eps
        self.fc1 = nn.Linear(state_dim, n_action)

    def forward(self, x):
        """Map a [b, state_dim] float tensor to [b, n_action] probabilities."""
        return torch.softmax(self.fc1(x), dim=1)

    def _mix_with_uniform(self, probs, eps):
        """Return (1 - eps) * probs + eps * uniform, broadcast over all rows."""
        return (1 - eps) * probs + eps / self.n_action

    def get_pi_vec(self, states):
        """Return the eps-smoothed policy for a batch of states.

        Args:
            states: A [b, s_dim] numpy array.

        Returns:
            A [b, nA] tensor of action probabilities, detached from autograd.
        """
        with torch.no_grad():
            probs = self(torch.from_numpy(states).float())
            # One vectorized expression replaces the former per-row loop;
            # the arithmetic applied to every element is unchanged.
            probs = self._mix_with_uniform(probs, self.eps)
        return probs

    def get_action(self, state, eps=0, sample=True):
        """Evaluate the policy for ONE state (used in PG).

        Args:
            state: Array-like with ``n_state`` entries; it is reshaped to
                [1, n_state] before being fed to the network.
            eps: Weight of the uniform distribution mixed into the policy for
                this call (independent of ``self.eps``).
            sample: If True, draw an action from the policy; otherwise return
                the probability vector itself.

        Returns:
            A numpy integer action index when ``sample`` is True, else a numpy
            array of length ``n_action`` holding the action probabilities.
        """
        with torch.no_grad():
            batch = torch.from_numpy(np.reshape(state, [1, self.n_state])).float()
            probs = self._mix_with_uniform(self(batch), eps)
        if sample:
            action = Categorical(probs).sample()
            return action.numpy().astype(int)[0]
        return probs[0].numpy()

# Same architecture as https://github.com/pytorch/tutorials/blob/master/intermediate_source/reinforcement_q_learning.py
class QVNet(nn.Module):
    """Three conv + batch-norm stages followed by one linear head.

    Every convolution uses kernel_size=5 and stride=2; the first one takes
    3 input channels. ``h`` and ``w`` are the input image height and width
    used to size the head, and ``outputs`` is the head's output width.
    """

    def __init__(self, h, w, outputs):
        super(QVNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=2)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)
        self.bn3 = nn.BatchNorm2d(32)

        # The head's input width is fixed by the spatial extent left after
        # the three unpadded convolutions, so derive it from (h, w).
        def shrink(extent, kernel_size=5, stride=2):
            return (extent - kernel_size) // stride + 1

        out_w, out_h = w, h
        for _ in range(3):
            out_w = shrink(out_w)
            out_h = shrink(out_h)
        self.head = nn.Linear(out_w * out_h * 32, outputs)

    def forward(self, x):
        """Run the conv stack and return a [b, outputs] tensor."""
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
        )
        for conv, norm in stages:
            x = F.relu(norm(conv(x)))
        flat = x.view(x.shape[0], -1)
        return self.head(flat)

    def get_integral_vecs(self, states):
        """Evaluate the network on a numpy batch without tracking gradients.

        Args:
            states: A numpy array; conv1 needs it shaped [b, 3, h, w].

        Returns:
            A [b, outputs] tensor.
        """
        batch = torch.from_numpy(states).float()
        with torch.no_grad():
            return self(batch)
        
class QVRNet(nn.Module):
    """Three conv + batch-norm stages (kernel 3, stride 1) and a linear head.

    Args:
        h: Height of the input images.
        w: Width of the input images.
        outputs: Output width of the linear head.

    Raises:
        ValueError: If (h, w) is too small to survive the three unpadded
            convolutions.
    """

    def __init__(self, h, w, outputs):
        super(QVRNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=3, stride=1)
        self.bn3 = nn.BatchNorm2d(32)

        # Number of Linear input connections depends on output of conv2d layers
        # and therefore the input image size, so compute it.
        def conv2d_size_out(size, kernel_size, stride):
            return (size - (kernel_size - 1) - 1) // stride + 1

        # Read kernel/stride from the layers themselves. The previous helper
        # hard-coded kernel_size=5, stride=2, which does not match these
        # 3x3 / stride-1 convolutions and produced a head whose in_features
        # disagreed with the flattened conv output.
        convh, convw = h, w
        for conv in (self.conv1, self.conv2, self.conv3):
            convh = conv2d_size_out(convh, conv.kernel_size[0], conv.stride[0])
            convw = conv2d_size_out(convw, conv.kernel_size[1], conv.stride[1])
        if convh <= 0 or convw <= 0:
            raise ValueError(
                "input size (h=%s, w=%s) is too small for the conv stack" % (h, w)
            )
        linear_input_size = convw * convh * 32
        self.head = nn.Linear(linear_input_size, outputs)

    def forward(self, x):
        """Run the conv stack and return a [b, outputs] tensor."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        return self.head(x.view(x.size(0), -1))

    def get_integral_vecs(self, states):
        """Evaluate the network on a numpy batch without tracking gradients.

        Args:
            states: A numpy array; conv1 needs it shaped [b, 3, h, w].

        Returns:
            A [b, outputs] tensor.
        """
        with torch.no_grad():
            result = self(torch.from_numpy(states).float())
        return result

def get_flat_grads_from(model):
    """Concatenate the gradients of ``model``'s parameters into one vector.

    Parameters are visited in ``model.parameters()`` order and each gradient
    is flattened before concatenation.

    A parameter without a gradient no longer raises AttributeError:
      * if it requires grad, it contributes zeros so the vector keeps one
        slot per trainable entry;
      * if it does not require grad, it is skipped.

    Args:
        model: An ``nn.Module``.

    Returns:
        A 1-D numpy array; empty (float32) when there is nothing to flatten.
    """
    grads = []
    for param in model.parameters():
        if param.grad is not None:
            # reshape (not view) also copes with non-contiguous gradients.
            grads.append(param.grad.detach().reshape(-1))
        elif param.requires_grad:
            grads.append(
                torch.zeros(param.numel(), dtype=param.dtype, device=param.device)
            )
    if not grads:
        # torch.cat rejects an empty list.
        return np.zeros(0, dtype=np.float32)
    # .cpu() is a no-op for CPU tensors and makes .numpy() valid otherwise.
    return torch.cat(grads).cpu().numpy()
    
def net_param_num(model):
    """Return the total number of scalar entries in ``model``'s parameters
    that have ``requires_grad`` set."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

class CustomDataset(Dataset):
    """Dataset of (sample, action, label) triples.

    The three inputs are converted to tensors with ``torch.tensor`` at
    construction time. Indexing returns the sample and label cast to float
    and the action with its stored dtype.
    """

    def __init__(self, samples, actions, labels):
        super().__init__()
        self.samples, self.actions, self.labels = (
            torch.tensor(column) for column in (samples, actions, labels)
        )

    def __len__(self):
        # The dataset length follows the samples tensor.
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx].float()
        action = self.actions[idx]
        label = self.labels[idx].float()
        return sample, action, label