import math
import random

import torch
import torch.nn as nn
import torch.nn.functional as F


def DQN_3D(env, args):
    """
    Build the Q-network selected by the flags in ``args``.

    Parameters
    ----------
    env     not read by this function (the environment is selected by the
            string ``args.env``); kept for call-site compatibility
    args    namespace providing ``c51``, ``dueling``, ``env``, ``noisy``,
            ``sigma_init``, ``device`` and, when ``c51`` is set, also
            ``Vmin``, ``Vmax``, ``num_atoms`` and ``batch_size``

    Returns
    -------
    The selected network, moved to ``args.device``.

    Raises
    ------
    ValueError  if ``args.c51`` and ``args.dueling`` are both set and
                ``args.env`` is neither '3DStatic' nor '3DDynamic' (there is
                no dueling C51 network for any other environment).
    """
    if args.c51:
        c51_args = (args.env, args.noisy, args.sigma_init,
                    args.Vmin, args.Vmax, args.num_atoms, args.batch_size)
        if not args.dueling:
            model = CategoricalDQN(*c51_args)
        elif args.env == '3DStatic':
            model = CategoricalDuelingDQN(*c51_args)
        elif args.env == '3DDynamic':
            model = CategoricalDuelingDQN_Dynamic(*c51_args)
        else:
            # Without this branch ``model`` would be unbound below.
            raise ValueError(
                f"No dueling C51 network for env {args.env!r}; "
                "expected '3DStatic' or '3DDynamic'")
    elif args.dueling:
        model = DuelingDQN(args.env, args.noisy, args.sigma_init)
    else:
        model = DQNBase(args.env, args.noisy, args.sigma_init)

    return model.to(args.device)


class DQNBase(nn.Module):
    """
    Basic DQN + NoisyNet

    Noisy Networks for Exploration
    https://arxiv.org/abs/1706.10295

    parameters
    ----------
    env         environment name string. '3DStatic' selects the 65-wide
                encoded state; any other value (e.g. '3DDynamic') selects
                the 129-wide one (see _feature_size / _encode_state)
    noisy       boolean value for NoisyNet. When True, act() skips
                epsilon-greedy and update_noisy_modules() collects the
                NoisyLinear layers used by sample_noise() and friends.
                NOTE: the fully connected layers are NoisyLinear regardless
                of this flag.
    sigma_init  initial noise scale passed to every NoisyLinear layer
    """

    def __init__(self, env, noisy, sigma_init):
        super(DQNBase, self).__init__()

        self.env = env
        self.num_actions = 8
        self.noisy = noisy

        self.flatten = Flatten()

        # Encoder for the 20x20 plan map ('3DDynamic'): 20 -> 18 -> 16 -> 8
        # -> 6 -> 3 -> 1, i.e. 64 features per sample.
        self.plan_features = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1),
            nn.LeakyReLU(),
            nn.BatchNorm2d(32),

            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1),
            nn.LeakyReLU(),
            nn.MaxPool2d(2),
            nn.BatchNorm2d(64),

            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1),
            nn.LeakyReLU(),
            nn.MaxPool2d(2),
            nn.BatchNorm2d(128),

            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1)
        )

        # Encoder for the 7x7 grid: 7 -> 5 -> 3 -> 1, i.e. 64 features.
        self.features = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
            nn.LeakyReLU()
        )

        self.fc = nn.Sequential(
            NoisyLinear(self._feature_size(), 128, sigma_init),
            nn.ReLU(),
            NoisyLinear(128, 128, sigma_init),
            nn.ReLU(),
            NoisyLinear(128, 128, sigma_init),
            nn.ReLU(),
            NoisyLinear(128, self.num_actions, sigma_init)
        )

    def forward(self, x):
        """Return Q-values of shape (batch, num_actions) for a batch of observations."""
        return self.fc(self._encode_state(x))

    def _encode_state(self, x):
        """Encode raw observations into one flat vector of width _feature_size().

        ``x`` is indexed as (batch, 1, length); the layout below is inferred
        from the slicing used by the dueling C51 variants in this file
        (TODO confirm against the environment wrapper):

        * '3DStatic': all but the last two entries form a 7x7 grid; the
          second-to-last entry is appended -> 64 + 1 = 65 features.
        * otherwise: each sample has 451 entries; 0..48 are the 7x7 grid,
          49 is the extra scalar and 51..450 a 20x20 plan map
          -> 64 + 1 + 64 = 129 features.
        """
        if self.env == '3DStatic':
            grid = x[:, :, :x.shape[2] - 2].view(-1, 1, 7, 7)
            grid_feat = self.flatten(self.features(grid))
            extra = x[:, :, -2:-1].view(-1, 1)
            return torch.cat([grid_feat, extra], dim=1)

        obs = x.view(-1, 1, 451)
        grid = obs[:, :, :49].view(-1, 1, 7, 7)
        plan = obs[:, :, 51:].view(-1, 1, 20, 20)
        grid_feat = self.flatten(self.features(grid))
        plan_feat = self.flatten(self.plan_features(plan))
        # Entry 49 directly follows the 7x7 grid (same slot as in the static
        # layout); entry 50 is unused.
        extra = obs[:, :, 49:50].view(-1, 1)
        return torch.cat([grid_feat, extra, plan_feat], dim=1)

    def _feature_size(self):
        """Width of the vector produced by _encode_state (65 static, 129 otherwise)."""
        if self.env == '3DStatic':
            return 65
        else:
            return 129

    def act(self, state, epsilon):
        """
        Parameters
        ----------
        state       torch.Tensor with appropriate device type; a batch axis
                    is added with unsqueeze(0) before the forward pass
        epsilon     epsilon for epsilon-greedy (ignored when self.noisy)

        Returns
        -------
        int action index in [0, num_actions)
        """
        if random.random() > epsilon or self.noisy:  # NoisyNet does not use e-greedy
            with torch.no_grad():
                state = state.unsqueeze(0)
                q_value = self.forward(state)
                action = q_value.max(1)[1].item()
        else:
            action = random.randrange(self.num_actions)
        return action

    def update_noisy_modules(self):
        """Cache the NoisyLinear sub-modules (only when ``noisy`` is set).

        reset_parameters(), sample_noise() and remove_noise() iterate this
        cache, so it must be built first.
        """
        if self.noisy:
            self.noisy_modules = [module for module in self.modules() if isinstance(module, NoisyLinear)]

    def reset_parameters(self):
        """Re-initialise mu/sigma of every cached NoisyLinear."""
        for module in self.noisy_modules:
            module.reset_parameters()

    def sample_noise(self):
        """Draw fresh noise for every cached NoisyLinear."""
        for module in self.noisy_modules:
            module.sample_noise()

    def remove_noise(self):
        """Zero the noise of every cached NoisyLinear so training mode uses mu only."""
        for module in self.noisy_modules:
            # Zero the epsilon buffers directly so this works with any
            # NoisyLinear, whether or not it exposes a remove_noise() helper;
            # forward() computes mu + sigma * epsilon in training mode.
            module.weight_epsilon.zero_()
            module.bias_epsilon.zero_()

class DuelingDQN(DQNBase):
    """
    Dueling Network Architectures for Deep Reinforcement Learning
    https://arxiv.org/abs/1511.06581

    Q(s, a) = V(s) + A(s, a) - mean_a' A(s, a'). The advantage stream is the
    ``fc`` stack built by DQNBase; the value stream is a separate stack of
    the same shape ending in a single output.
    """

    def __init__(self, env, noisy, sigma_init):
        super(DuelingDQN, self).__init__(env, noisy, sigma_init)

        # Alias, not a copy: ``advantage`` is the same module object as ``fc``.
        self.advantage = self.fc

        self.value = nn.Sequential(
            NoisyLinear(self._feature_size(), 128, sigma_init),
            nn.ReLU(),
            NoisyLinear(128, 128, sigma_init),
            nn.ReLU(),
            NoisyLinear(128, 128, sigma_init),
            nn.ReLU(),
            NoisyLinear(128, 1, sigma_init)
        )

    def _encode_state(self, x):
        """Encode raw observations into one flat vector of width _feature_size().

        ``x`` is indexed as (batch, 1, length); the layout below is inferred
        from the slicing used by the dueling C51 variants in this file
        (TODO confirm against the environment wrapper):

        * '3DStatic': all but the last two entries form a 7x7 grid; the
          second-to-last entry is appended -> 64 + 1 = 65 features.
        * otherwise: each sample has 451 entries; 0..48 are the 7x7 grid,
          49 is the extra scalar and 51..450 a 20x20 plan map
          -> 64 + 1 + 64 = 129 features.
        """
        if self.env == '3DStatic':
            grid = x[:, :, :x.shape[2] - 2].view(-1, 1, 7, 7)
            grid_feat = self.flatten(self.features(grid))
            extra = x[:, :, -2:-1].view(-1, 1)
            return torch.cat([grid_feat, extra], dim=1)

        obs = x.view(-1, 1, 451)
        grid = obs[:, :, :49].view(-1, 1, 7, 7)
        plan = obs[:, :, 51:].view(-1, 1, 20, 20)
        grid_feat = self.flatten(self.features(grid))
        plan_feat = self.flatten(self.plan_features(plan))
        # Entry 49 directly follows the 7x7 grid; entry 50 is unused.
        extra = obs[:, :, 49:50].view(-1, 1)
        return torch.cat([grid_feat, extra, plan_feat], dim=1)

    def forward(self, x):
        """Return Q-values of shape (batch, num_actions)."""
        x = self._encode_state(x)
        advantage = self.advantage(x)
        value = self.value(x)
        return value + advantage - advantage.mean(1, keepdim=True)


class CategoricalDQN(DQNBase):
    """
    A Distributional Perspective on Reinforcement Learning
    https://arxiv.org/abs/1707.06887

    C51 head: for every action the network outputs a probability
    distribution over ``num_atoms`` evenly spaced support values in
    [Vmin, Vmax].
    """

    def __init__(self, env, noisy, sigma_init, Vmin, Vmax, num_atoms, batch_size):
        super(CategoricalDQN, self).__init__(env, noisy, sigma_init)

        support = torch.linspace(Vmin, Vmax, num_atoms)
        # Shape (batch_size, num_atoms); row i holds i * num_atoms. Not read
        # in this class (presumably used by the training-side projection).
        offset = torch.linspace(0, (batch_size - 1) * num_atoms, batch_size).long() \
            .unsqueeze(1).expand(batch_size, num_atoms).clone()

        self.register_buffer('support', support)
        self.register_buffer('offset', offset)
        self.num_atoms = num_atoms

        # Replaces the scalar-Q head built by DQNBase with one that emits
        # num_atoms logits per action.
        self.fc = nn.Sequential(
            NoisyLinear(self._feature_size(), 128, sigma_init),
            nn.ReLU(),
            NoisyLinear(128, 128, sigma_init),
            nn.ReLU(),
            NoisyLinear(128, 128, sigma_init),
            nn.ReLU(),
            NoisyLinear(128, self.num_actions * self.num_atoms, sigma_init)
        )

        self.softmax = nn.Softmax(dim=1)

    def _encode_state(self, x):
        """Encode raw observations into one flat vector of width _feature_size().

        ``x`` is indexed as (batch, 1, length); the layout below is inferred
        from the slicing used by the dueling C51 variants in this file
        (TODO confirm against the environment wrapper):

        * '3DStatic': all but the last two entries form a 7x7 grid; the
          second-to-last entry is appended -> 64 + 1 = 65 features.
        * otherwise: each sample has 451 entries; 0..48 are the 7x7 grid,
          49 is the extra scalar and 51..450 a 20x20 plan map
          -> 64 + 1 + 64 = 129 features.
        """
        if self.env == '3DStatic':
            grid = x[:, :, :x.shape[2] - 2].view(-1, 1, 7, 7)
            grid_feat = self.flatten(self.features(grid))
            extra = x[:, :, -2:-1].view(-1, 1)
            return torch.cat([grid_feat, extra], dim=1)

        obs = x.view(-1, 1, 451)
        grid = obs[:, :, :49].view(-1, 1, 7, 7)
        plan = obs[:, :, 51:].view(-1, 1, 20, 20)
        grid_feat = self.flatten(self.features(grid))
        plan_feat = self.flatten(self.plan_features(plan))
        # Entry 49 directly follows the 7x7 grid; entry 50 is unused.
        extra = obs[:, :, 49:50].view(-1, 1)
        return torch.cat([grid_feat, extra, plan_feat], dim=1)

    def forward(self, x):
        """Return per-action atom probabilities, shape (batch, num_actions, num_atoms)."""
        x = self.fc(self._encode_state(x))
        # Softmax over the atom axis only, one distribution per action.
        x = self.softmax(x.view(-1, self.num_atoms))
        return x.view(-1, self.num_actions, self.num_atoms)

    def act(self, state, epsilon):
        """
        Parameters
        ----------
        state       torch.Tensor with appropriate device type; a batch axis
                    is added with unsqueeze(0) before the forward pass
        epsilon     epsilon for epsilon-greedy (ignored when self.noisy)

        Returns
        -------
        int action whose expected value sum(p * support) is largest
        """
        if random.random() > epsilon or self.noisy:  # NoisyNet does not use e-greedy
            with torch.no_grad():
                state = state.unsqueeze(0)
                q_dist = self.forward(state)
                q_value = (q_dist * self.support).sum(2)
                action = q_value.max(1)[1].item()
        else:
            action = random.randrange(self.num_actions)
        return action


class CategoricalDuelingDQN(CategoricalDQN):
    """
    Dueling C51 network for the '3DStatic' environment.

    The value stream emits ``num_atoms`` logits and the advantage stream
    ``num_atoms`` logits per action; they are combined as V + A - mean_a(A)
    and a softmax over the atom axis yields one distribution per action.
    """

    def __init__(self, env, noisy, sigma_init, Vmin, Vmax, num_atoms, batch_size):
        super().__init__(env, noisy, sigma_init, Vmin, Vmax, num_atoms, batch_size)

        # The advantage stream reuses (aliases) the fc stack built by CategoricalDQN.
        self.advantage = self.fc

        hidden = 128
        self.value = nn.Sequential(
            NoisyLinear(self._feature_size(), hidden, sigma_init), nn.ReLU(),
            NoisyLinear(hidden, hidden, sigma_init), nn.ReLU(),
            NoisyLinear(hidden, hidden, sigma_init), nn.ReLU(),
            NoisyLinear(hidden, self.num_atoms, sigma_init),
        )

    def forward(self, x):
        # Everything but the last two entries is a 7x7 grid; the
        # second-to-last entry is appended to the conv features.
        obs_len = x[0][0].shape[0]
        grid = x[:, :, :obs_len - 2].view(-1, 1, 7, 7)
        grid_feat = self.flatten(self.features(grid))
        extra = x[:, :, -2:-1].view(-1, 1)
        encoded = torch.cat([grid_feat, extra], axis=1)

        adv = self.advantage(encoded).view(-1, self.num_actions, self.num_atoms)
        val = self.value(encoded).view(-1, 1, self.num_atoms)

        logits = val + adv - adv.mean(1, keepdim=True)
        probs = self.softmax(logits.view(-1, self.num_atoms))
        return probs.view(-1, self.num_actions, self.num_atoms)


class CategoricalDQN_Dynamic(DQNBase):
    """
    A Distributional Perspective on Reinforcement Learning
    https://arxiv.org/abs/1707.06887

    C51 network for the '3DDynamic' environment. Each observation is viewed
    as 451 entries: a 7x7 grid (0..48), an extra scalar (49) and a 20x20
    plan map (51..450); entry 50 is unused. (Layout inferred from the
    slicing below; confirm against the environment.)
    """

    def __init__(self, env, noisy, sigma_init, Vmin, Vmax, num_atoms, batch_size):
        super(CategoricalDQN_Dynamic, self).__init__(env, noisy, sigma_init)

        support = torch.linspace(Vmin, Vmax, num_atoms)
        # Shape (batch_size, num_atoms); row i holds i * num_atoms. Not read
        # in this class (presumably used by the training-side projection).
        offset = torch.linspace(0, (batch_size - 1) * num_atoms, batch_size).long() \
            .unsqueeze(1).expand(batch_size, num_atoms).clone()

        self.register_buffer('support', support)
        self.register_buffer('offset', offset)
        self.num_atoms = num_atoms

        self.fc = nn.Sequential(
            NoisyLinear(self._feature_size(), 128, sigma_init),
            nn.ReLU(),
            NoisyLinear(128, 128, sigma_init),
            nn.ReLU(),
            NoisyLinear(128, 128, sigma_init),
            nn.ReLU(),
            NoisyLinear(128, self.num_actions * self.num_atoms, sigma_init)
        )

        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return per-action atom probabilities, shape (batch, num_actions, num_atoms).

        Builds the 129-wide vector expected by ``fc``: 64 grid features,
        the scalar at entry 49, and 64 plan features.
        """
        obs = x.view(-1, 1, 451)
        grid = obs[:, :, :49].view(-1, 1, 7, 7)
        plan = obs[:, :, 51:].view(-1, 1, 20, 20)

        grid_feat = self.flatten(self.features(grid))
        plan_feat = self.flatten(self.plan_features(plan))
        extra = obs[:, :, 49:50].view(-1, 1)
        x = torch.cat([grid_feat, extra, plan_feat], dim=1)

        x = self.fc(x)
        x = self.softmax(x.view(-1, self.num_atoms))
        return x.view(-1, self.num_actions, self.num_atoms)

    def act(self, state, epsilon):
        """
        Parameters
        ----------
        state       torch.Tensor with appropriate device type; passed to
                    forward() as-is (no unsqueeze), which views it as 451
                    entries per sample
        epsilon     epsilon for epsilon-greedy (ignored when self.noisy)

        Returns
        -------
        int action whose expected value sum(p * support) is largest
        """
        if random.random() > epsilon or self.noisy:  # NoisyNet does not use e-greedy
            with torch.no_grad():
                q_dist = self.forward(state)
                q_value = (q_dist * self.support).sum(2)
                action = q_value.max(1)[1].item()
        else:
            action = random.randrange(self.num_actions)
        return action


class CategoricalDuelingDQN_Dynamic(CategoricalDQN_Dynamic):
    """
    Dueling C51 network for the '3DDynamic' environment.

    Each observation is viewed as 451 entries: a 7x7 grid (0..48), an extra
    scalar (49) and a 20x20 plan map (51..450). The grid and plan are
    encoded by separate conv stacks (64 features each) and concatenated
    with the scalar into the 129-wide vector read by both streams.
    """

    def __init__(self, env, noisy, sigma_init, Vmin, Vmax, num_atoms, batch_size):
        super(CategoricalDuelingDQN_Dynamic, self).__init__(env, noisy, sigma_init, Vmin, Vmax, num_atoms, batch_size)

        # Alias, not a copy: the advantage stream is the parent's fc stack.
        self.advantage = self.fc

        self.value = nn.Sequential(
            NoisyLinear(self._feature_size(), 128, sigma_init),
            nn.ReLU(),
            NoisyLinear(128, 128, sigma_init),
            nn.ReLU(),
            NoisyLinear(128, 128, sigma_init),
            nn.ReLU(),
            NoisyLinear(128, self.num_atoms, sigma_init)
        )

    def forward(self, x):
        """Return per-action atom probabilities, shape (batch, num_actions, num_atoms)."""
        obs = x.view(-1, 1, 451)
        plan = obs[:, :, 51:].view(-1, 1, 20, 20)
        grid = obs[:, :, :49].view(-1, 1, 7, 7)

        plan_feat = self.flatten(self.plan_features(plan)).view(-1, 64)
        grid_feat = self.features(grid).view(-1, 64)
        # The scalar sits at entry 49, right after the 7x7 grid. Note that a
        # ``-2:-1`` slice (as used by the static variant) would select
        # entry 449 here, which lies inside the plan region (51..450).
        extra = obs[:, :, 49:50].view(-1, 1)
        x = torch.cat([grid_feat, extra, plan_feat], dim=1)

        advantage = self.advantage(x).view(-1, self.num_actions, self.num_atoms)
        value = self.value(x).view(-1, 1, self.num_atoms)

        x = value + advantage - advantage.mean(1, keepdim=True)
        x = self.softmax(x.view(-1, self.num_atoms))
        return x.view(-1, self.num_actions, self.num_atoms)


class Flatten(nn.Module):
    """Collapse every dimension after the leading (batch) one into a single axis."""

    def forward(self, x):
        batch_size = x.shape[0]
        return x.view(batch_size, -1)


class NoisyLinear(nn.Module):
    """Linear layer with learnable parametric noise (NoisyNet).

    Noisy Networks for Exploration, https://arxiv.org/abs/1706.10295

    In training mode the effective parameters are ``mu + sigma * epsilon``;
    in eval mode only ``mu`` is used. The weight noise is the outer product
    of two vectors passed through f(z) = sign(z) * sqrt(|z|); the bias noise
    is a separate vector drawn the same way. ``epsilon`` changes only when
    sample_noise() or remove_noise() is called.

    parameters
    ----------
    in_features     size of each input sample
    out_features    size of each output sample
    sigma_init      noise scale; weight sigmas start at
                    sigma_init / sqrt(in_features), bias sigmas at
                    sigma_init / sqrt(out_features)
    """

    def __init__(self, in_features, out_features, sigma_init):
        super(NoisyLinear, self).__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.sigma_init = sigma_init

        # Values are filled in by reset_parameters() / sample_noise() below.
        self.weight_mu = nn.Parameter(torch.empty(out_features, in_features))
        self.weight_sigma = nn.Parameter(torch.empty(out_features, in_features))
        self.register_buffer('weight_epsilon', torch.empty(out_features, in_features))

        self.bias_mu = nn.Parameter(torch.empty(out_features))
        self.bias_sigma = nn.Parameter(torch.empty(out_features))
        self.register_buffer('bias_epsilon', torch.empty(out_features))

        # Scratch noise vectors, refilled in place by sample_noise().
        self.register_buffer('sample_weight_in', torch.empty(in_features))
        self.register_buffer('sample_weight_out', torch.empty(out_features))
        self.register_buffer('sample_bias_out', torch.empty(out_features))

        self.reset_parameters()
        self.sample_noise()

    def forward(self, x):
        if self.training:
            weight = self.weight_mu + self.weight_sigma.mul(self.weight_epsilon)
            bias = self.bias_mu + self.bias_sigma.mul(self.bias_epsilon)
        else:
            weight = self.weight_mu
            bias = self.bias_mu

        return F.linear(x, weight, bias)

    def reset_parameters(self):
        """Re-initialise mu uniformly in +-1/sqrt(in_features) and sigma to constants."""
        mu_range = 1 / math.sqrt(self.weight_mu.size(1))

        self.weight_mu.data.uniform_(-mu_range, mu_range)
        self.weight_sigma.data.fill_(self.sigma_init / math.sqrt(self.weight_sigma.size(1)))

        self.bias_mu.data.uniform_(-mu_range, mu_range)
        self.bias_sigma.data.fill_(self.sigma_init / math.sqrt(self.bias_sigma.size(0)))

    def sample_noise(self):
        """Draw fresh noise into the epsilon buffers (all updates are in place)."""
        self._scale_noise(self.sample_weight_in)
        self._scale_noise(self.sample_weight_out)
        self._scale_noise(self.sample_bias_out)

        self.weight_epsilon.copy_(torch.outer(self.sample_weight_out, self.sample_weight_in))
        self.bias_epsilon.copy_(self.sample_bias_out)

    def remove_noise(self):
        """Zero the epsilon buffers so training-mode outputs equal eval-mode ones."""
        self.weight_epsilon.zero_()
        self.bias_epsilon.zero_()

    def _scale_noise(self, x):
        """Refill ``x`` in place with sign(z) * sqrt(|z|), z ~ N(0, 1); return ``x``."""
        x.normal_()
        # Keep the registered buffer object instead of rebinding the attribute.
        return x.copy_(x.sign().mul(x.abs().sqrt()))
