import cv2
cv2.ocl.setUseOpenCL(False)

import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
from torch.nn import functional as F

import numpy as np

# Module-level default device: CUDA when it is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ALGO LOGIC: initialize agent here:
class Scale(nn.Module):
    """Parameter-free layer that multiplies its input by a constant factor.

    Typical use is ``Scale(1/255)`` as the first layer of an image network.
    The factor is stored as a plain attribute (not a Parameter or buffer), so
    it is never trained and does not appear in ``state_dict()``.
    """

    def __init__(self, scale):
        super().__init__()
        self.scale = scale

    def forward(self, x):
        factor = self.scale
        return x * factor

def layer_init(layer, std=np.sqrt(2), bias_const=0.0):
    """Orthogonally initialise ``layer.weight`` and fill ``layer.bias``.

    Args:
        layer: a module with a ``weight`` tensor and an optional ``bias``
            (e.g. ``nn.Linear``, ``nn.Conv2d``, ``nn.ConvTranspose2d``).
        std: gain passed to ``torch.nn.init.orthogonal_``.
        bias_const: constant written into ``layer.bias``.

    Returns:
        The same ``layer`` instance, so the call can be used inline when
        building an ``nn.Sequential``.
    """
    torch.nn.init.orthogonal_(layer.weight, std)
    # Layers created with ``bias=False`` have ``bias is None``; there is
    # nothing to fill in that case (constant_ would fail on None).
    if layer.bias is not None:
        torch.nn.init.constant_(layer.bias, bias_const)
    return layer

class AgentSharedNetwork(nn.Module):
    """Actor-critic agent whose policy and value heads share one trunk.

    The trunk rescales inputs by 1/255, applies three conv layers and a
    512-unit linear layer. ``actor`` maps the 512 features to one logit per
    discrete action; ``critic`` maps them to a single value.

    Args:
        envs: object exposing ``action_space.n`` (number of discrete actions).
        frames: number of stacked input channels.
    """

    def __init__(self, envs, frames=4):
        super().__init__()
        trunk_layers = [
            Scale(1/255),
            layer_init(nn.Conv2d(frames, 32, 8, stride=4)),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, 4, stride=2)),
            nn.ReLU(),
            layer_init(nn.Conv2d(64, 64, 3, stride=1)),
            nn.ReLU(),
            nn.Flatten(),
            # 3136 = 64 * 7 * 7 flattened conv features (consistent with 84x84 inputs).
            layer_init(nn.Linear(3136, 512)),
            nn.ReLU(),
        ]
        self.network = nn.Sequential(*trunk_layers)
        # Small gain for the policy logits, unit gain for the value output.
        self.actor = layer_init(nn.Linear(512, envs.action_space.n), std=0.01)
        self.critic = layer_init(nn.Linear(512, 1), std=1)

    def forward(self, x):
        """Return the shared 512-dim trunk features for ``x``."""
        return self.network(x)

    def _policy(self, x):
        """Build the categorical action distribution for ``x``."""
        return Categorical(logits=self.actor(self.forward(x)))

    def get_action(self, x, action=None):
        """Return ``(action, log_prob(action), entropy)``; sample when ``action`` is None."""
        dist = self._policy(x)
        chosen = dist.sample() if action is None else action
        return chosen, dist.log_prob(chosen), dist.entropy()

    def get_value(self, x):
        """Return the critic's value estimate for ``x``."""
        return self.critic(self.forward(x))

    def get_prob(self, x):
        """Return the per-action probabilities for ``x``."""
        return self._policy(x).probs

class Actor(nn.Module):
    """Standalone policy network: conv trunk followed by a linear logits layer.

    Args:
        envs: object exposing ``action_space.n`` (number of discrete actions).
        frames: number of stacked input channels.
    """

    def __init__(self, envs, frames=4):
        super(Actor, self).__init__()
        self.actor = nn.Sequential(
            Scale(1/255),
            layer_init(nn.Conv2d(frames, 32, 8, stride=4)),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, 4, stride=2)),
            nn.ReLU(),
            layer_init(nn.Conv2d(64, 64, 3, stride=1)),
            nn.ReLU(),
            nn.Flatten(),
            layer_init(nn.Linear(3136, 512)),
            nn.ReLU(),
            layer_init(nn.Linear(512, envs.action_space.n), std=0.01)
        )

    def forward(self, x):
        # Calling the module directly is deliberately unsupported. An explicit
        # raise (instead of ``assert False``) still fires under ``python -O``,
        # and keeps the AssertionError type the original produced.
        raise AssertionError("Actor.forward is not supported; use get_action or get_prob")

    def get_action(self, x, action=None, deterministic=False):
        """Return ``(action, log_prob(action), entropy)`` for observations ``x``.

        Args:
            x: observation batch (scaled by 1/255 inside the network).
            action: actions to evaluate; when None, one is chosen.
            deterministic: when choosing, take the most probable action
                instead of sampling (matches ``ActorHead.get_action``).
        """
        logits = self.actor(x)
        probs = Categorical(logits=logits)
        if action is None:
            if deterministic:
                action = probs.probs.argmax(dim=-1, keepdim=False)
            else:
                action = probs.sample()
        return action, probs.log_prob(action), probs.entropy()

    def get_prob(self, x):
        """Return the per-action probabilities for ``x``."""
        logits = self.actor(x)
        probs = Categorical(logits=logits)
        return probs.probs

class Critic(nn.Module):
    """Value network: 1/255 input scaling, three conv layers, two linear layers.

    ``get_value`` can use either the module's own parameters or an external
    ``weights`` dict keyed like this module's ``named_parameters()``
    (``'critic.1.weight'`` ... ``'critic.10.bias'``), e.g. adapted fast weights.

    Args:
        frames: number of stacked input channels.
    """

    def __init__(self, frames=4):
        super(Critic, self).__init__()
        # Sequential indices matter: the functional path in get_value reads
        # weights by these positions (1, 3, 5, 8, 10).
        self.critic = nn.Sequential(
            Scale(1/255),                                    # 0
            layer_init(nn.Conv2d(frames, 32, 8, stride=4)),  # 1
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, 4, stride=2)),      # 3
            nn.ReLU(),
            layer_init(nn.Conv2d(64, 64, 3, stride=1)),      # 5
            nn.ReLU(),
            nn.Flatten(),
            layer_init(nn.Linear(3136, 512)),                # 8
            nn.ReLU(),
            layer_init(nn.Linear(512, 1), std=1)             # 10
        )

    def forward(self, x):
        # Calling the module directly is deliberately unsupported. An explicit
        # raise (instead of ``assert False``) still fires under ``python -O``.
        raise AssertionError("Critic.forward is not supported; use get_value")

    def get_value(self, x, weights = None):
        """Return the value estimate for ``x`` with shape ``(batch, 1)``.

        Args:
            x: observation batch; scaled by 1/255 before the first conv.
            weights: optional mapping from parameter name to tensor. When
                given, the network is evaluated functionally with these
                tensors instead of the module's parameters. Each tensor is
                moved to ``x``'s device. The previous version forced
                ``.cuda()``, which failed on CPU-only machines and for inputs
                on a non-default GPU.

        Raises:
            KeyError: if ``weights`` lacks a required entry.
        """
        if weights is None:
            return self.critic(x)

        dev = x.device

        def param(name):
            return weights[name].to(dev)

        x = x * (1/255)
        x = F.conv2d(x, param('critic.1.weight'), param('critic.1.bias'), 4)
        x = F.threshold(x, 0, 0, inplace=True)  # ReLU-like: values <= 0 become 0
        x = F.conv2d(x, param('critic.3.weight'), param('critic.3.bias'), 2)
        x = F.threshold(x, 0, 0, inplace=True)
        x = F.conv2d(x, param('critic.5.weight'), param('critic.5.bias'), 1)
        x = F.threshold(x, 0, 0, inplace=True)
        x = x.reshape(x.shape[0], -1)
        features = F.linear(x, param('critic.8.weight'), param('critic.8.bias'))
        features = F.threshold(features, 0, 0, inplace=True)
        value = F.linear(features, param('critic.10.weight'), param('critic.10.bias'))
        return value

class AgentTwoNetworks(nn.Module):
    """Actor-critic agent with fully separate policy (``Actor``) and value (``Critic``) networks.

    Args:
        envs: object exposing ``action_space.n``; forwarded to ``Actor``.
        frames: number of stacked input channels for both networks.
    """

    def __init__(self, envs, frames=4):
        super(AgentTwoNetworks, self).__init__()
        self.actor = Actor(envs, frames)
        self.critic = Critic(frames)

    def forward(self, x):
        # Calling the module directly is deliberately unsupported. An explicit
        # raise (instead of ``assert False``) still fires under ``python -O``.
        raise AssertionError(
            "AgentTwoNetworks.forward is not supported; use get_action, get_value or get_prob")

    def get_action(self, x, action=None):
        """Delegate to ``self.actor.get_action``."""
        return self.actor.get_action(x, action)

    def get_value(self, x, weights=None):
        """Delegate to ``self.critic.get_value``.

        ``weights`` (optional) is passed through so the critic can be
        evaluated with external parameter tensors, matching
        ``AgentTwoNetworksHead.get_value``.
        """
        return self.critic.get_value(x, weights)

    def get_prob(self, x):
        """Delegate to ``self.actor.get_prob``."""
        return self.actor.get_prob(x)

class ActorHead(nn.Module):
    """Policy network split into a conv ``features`` trunk and a linear ``head``.

    ``get_action`` can run with the module's own parameters or with an
    external ``weights`` dict keyed like this module's ``named_parameters()``
    (``'features.1.weight'`` ... ``'head.bias'``), e.g. adapted fast weights.

    Args:
        envs: object exposing ``action_space.n`` (number of discrete actions).
        frames: number of stacked input channels.
    """

    def __init__(self, envs, frames=4):
        super(ActorHead, self).__init__()
        # Sequential indices matter: the functional path reads weights by
        # these positions (1, 3, 5, 8).
        self.features = nn.Sequential(
            Scale(1/255),                                    # 0
            layer_init(nn.Conv2d(frames, 32, 8, stride=4)),  # 1
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, 4, stride=2)),      # 3
            nn.ReLU(),
            layer_init(nn.Conv2d(64, 64, 3, stride=1)),      # 5
            nn.ReLU(),
            nn.Flatten(),
            layer_init(nn.Linear(3136, 512)),                # 8
            nn.ReLU()
        )
        self.head = layer_init(nn.Linear(512, envs.action_space.n), std=0.01)

    def forward(self, x):
        # Calling the module directly is deliberately unsupported. An explicit
        # raise (instead of ``assert False``) still fires under ``python -O``.
        raise AssertionError("ActorHead.forward is not supported; use get_action or get_prob")

    @staticmethod
    def _features_from_weights(x, weights):
        """Functionally replay ``self.features`` using tensors taken from ``weights``."""
        dev = x.device

        def param(name):
            return weights[name].to(dev)

        x = x * (1/255)
        x = F.conv2d(x, param('features.1.weight'), param('features.1.bias'), 4)
        x = F.threshold(x, 0, 0, inplace=True)  # ReLU-like: values <= 0 become 0
        x = F.conv2d(x, param('features.3.weight'), param('features.3.bias'), 2)
        x = F.threshold(x, 0, 0, inplace=True)
        x = F.conv2d(x, param('features.5.weight'), param('features.5.bias'), 1)
        x = F.threshold(x, 0, 0, inplace=True)
        x = x.reshape(x.shape[0], -1)
        features = F.linear(x, param('features.8.weight'), param('features.8.bias'))
        return F.threshold(features, 0, 0, inplace=True)

    @staticmethod
    def _select(probs, action, deterministic):
        """Pick an action when none is given and return ``(action, log_prob, entropy)``."""
        if action is None:
            if deterministic:
                action = probs.probs.argmax(dim=-1, keepdim=False)
            else:
                action = probs.sample()
        return action, probs.log_prob(action), probs.entropy()

    def get_action(self, x, action=None, deterministic=False, weights = None, head_only_fast_weights = False):
        """Return ``(action, log_prob(action), entropy)`` for observations ``x``.

        Args:
            x: observation batch (scaled by 1/255 before the first conv).
            action: actions to evaluate; when None, one is chosen.
            deterministic: when choosing, take the most probable action
                instead of sampling.
            weights: optional mapping from parameter name to tensor used in
                place of the module's parameters. Tensors are moved to the
                input's device (previously forced to ``.cuda()``).
            head_only_fast_weights: when True and ``weights`` is given, the
                module's own ``features`` are used and only ``'head.weight'``
                / ``'head.bias'`` are read from ``weights``. This was
                previously accepted but ignored. It has no effect when
                ``weights`` is None.

        Raises:
            KeyError: if ``weights`` lacks a required entry.
        """
        if weights is None:
            logits = self.head(self.features(x))
        else:
            if head_only_fast_weights:
                features = self.features(x)
            else:
                features = self._features_from_weights(x, weights)
            dev = features.device
            logits = F.linear(features, weights['head.weight'].to(dev), weights['head.bias'].to(dev))
        return self._select(Categorical(logits=logits), action, deterministic)

    def get_prob(self, x):
        """Return the per-action probabilities for ``x`` using the module's parameters."""
        features = self.features(x)
        logits = self.head(features)
        probs = Categorical(logits=logits)
        return probs.probs

class AgentTwoNetworksHead(nn.Module):
    """Actor-critic agent with separate networks; the policy is an ``ActorHead``.

    Args:
        envs: object exposing ``action_space.n``; forwarded to ``ActorHead``.
        frames: number of stacked input channels for both networks.
    """

    def __init__(self, envs, frames=4):
        super(AgentTwoNetworksHead, self).__init__()
        self.actor = ActorHead(envs, frames)
        self.critic = Critic(frames)

    def forward(self, x):
        # Calling the module directly is deliberately unsupported. An explicit
        # raise (instead of ``assert False``) still fires under ``python -O``.
        raise AssertionError(
            "AgentTwoNetworksHead.forward is not supported; use get_action, get_value or get_prob")

    def get_actor_features(self, x):
        """Return the actor trunk's features for ``x``."""
        return self.actor.features(x)

    def get_action(self, x, action=None, deterministic = False, weights = None,
                   head_only_fast_weights = False):
        """Delegate to ``self.actor.get_action``.

        ``head_only_fast_weights`` is now forwarded too, so callers of the
        agent can reach the actor's head-only fast-weights mode.
        """
        return self.actor.get_action(x, action, deterministic, weights, head_only_fast_weights)

    def get_value(self, x, weights = None):
        """Delegate to ``self.critic.get_value`` (optionally with external weights)."""
        return self.critic.get_value(x, weights)

    def get_prob(self, x):
        """Delegate to ``self.actor.get_prob``."""
        return self.actor.get_prob(x)

class RNDModel(nn.Module):
    """Predictor/target network pair in the style of Random Network Distillation.

    ``target`` is randomly initialised and frozen (``requires_grad=False``).
    ``predictor`` is trainable and ends in an extra two-layer MLP. Both map
    an observation batch to 512-dim features. Inputs are used as given: there
    is no 1/255 scaling, unlike the agent networks.

    Args:
        envs: accepted for interface compatibility; not used for sizing.
        frames: number of stacked input channels.
    """

    def __init__(self, envs, frames=1):
        super(RNDModel, self).__init__()

        self.predictor = nn.Sequential(
            nn.Conv2d(in_channels=frames, out_channels=32, kernel_size=8, stride=4),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
            nn.LeakyReLU(),
            nn.Flatten(),
            # 3136 = 64 * 7 * 7 flattened conv features (consistent with 84x84 inputs).
            nn.Linear(3136, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 512)
        )

        self.target = nn.Sequential(
            nn.Conv2d(in_channels=frames, out_channels=32, kernel_size=8, stride=4),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
            nn.LeakyReLU(),
            nn.Flatten(),
            nn.Linear(3136, 512)
        )

        # Orthogonal weights (gain sqrt(2)) and zero biases for every conv and
        # linear layer in both networks.
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.orthogonal_(module.weight, np.sqrt(2))
                nn.init.zeros_(module.bias)

        # The target stays fixed; only the predictor is trained.
        self.target.requires_grad_(False)

    def forward(self, next_obs):
        """Return ``(predict_feature, target_feature)`` for ``next_obs``."""
        target_feature = self.target(next_obs)
        predict_feature = self.predictor(next_obs)

        return predict_feature, target_feature

class ICMModel(nn.Module):
    """Curiosity model with a shared encoder, an inverse model and a forward model.

    ``forward((state, next_state, action))`` does three things:

    * encodes both observations with ``self.feature``;
    * predicts the action from the pair of encodings (``inverse_net``);
    * predicts the next-state encoding from the current encoding and the
      action (``forward_net_1`` -> 4 residual pairs -> ``forward_net_2``).

    Args:
        envs: object exposing ``observation_space.shape`` and
            ``action_space.n``; the latter sets the action width.
        frames: number of stacked input channels.
        use_cuda: retained for backward compatibility; no longer used. The
            residual blocks are now registered submodules, so they follow the
            model when it is moved with ``.to(device)`` / ``.cuda()``.
    """

    def __init__(self, envs, frames=4, use_cuda=True):
        super(ICMModel, self).__init__()

        input_size = envs.observation_space.shape
        output_size = envs.action_space.n

        self.input_size = input_size
        self.output_size = output_size

        # 64 channels x 7 x 7 after the conv stack (consistent with 84x84 inputs).
        feature_output = 7 * 7 * 64
        self.feature = nn.Sequential(
            nn.Conv2d(in_channels=frames, out_channels=32, kernel_size=8, stride=4),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
            nn.LeakyReLU(),
            nn.Flatten(),
            nn.Linear(feature_output, 512)
        )

        self.inverse_net = nn.Sequential(
            nn.Linear(512 * 2, 512),
            nn.ReLU(),
            nn.Linear(512, output_size)
        )

        # Eight *independent* residual MLPs held in an nn.ModuleList. The
        # previous ``[block] * 8`` plain list repeated ONE shared module eight
        # times. It also hid that module from parameters() (so an optimiser
        # built from model.parameters() never updated it), state_dict(),
        # .to(), and the init loop below.
        # NOTE: checkpoints saved before this change lack ``residual.*`` keys.
        self.residual = nn.ModuleList(
            nn.Sequential(
                nn.Linear(output_size + 512, 512),
                nn.LeakyReLU(),
                nn.Linear(512, 512),
            )
            for _ in range(8)
        )

        self.forward_net_1 = nn.Sequential(
            nn.Linear(output_size + 512, 512),
            nn.LeakyReLU()
        )
        self.forward_net_2 = nn.Sequential(
            nn.Linear(output_size + 512, 512),
        )

        for p in self.modules():
            if isinstance(p, nn.Conv2d):
                nn.init.kaiming_uniform_(p.weight)
                p.bias.data.zero_()

            if isinstance(p, nn.Linear):
                nn.init.kaiming_uniform_(p.weight, a=1.0)
                p.bias.data.zero_()

    def forward(self, inputs):
        """Return ``(real_next_state_feature, pred_next_state_feature, pred_action)``.

        Args:
            inputs: tuple ``(state, next_state, action)``. ``action`` is
                concatenated with 512-dim features along dim 1, so it must
                have width ``self.output_size`` (presumably a one-hot float
                encoding; confirm with the caller).
        """
        state, next_state, action = inputs

        encode_state = self.feature(state)
        encode_next_state = self.feature(next_state)

        # Inverse model: predict the action from the two encodings.
        pred_action = self.inverse_net(torch.cat((encode_state, encode_next_state), 1))

        # Forward model: predict the next encoding from (encoding, action).
        hidden = self.forward_net_1(torch.cat((encode_state, action), 1))

        # Four residual pairs; the action is re-concatenated before every layer.
        for i in range(4):
            inner = self.residual[i * 2](torch.cat((hidden, action), 1))
            hidden = self.residual[i * 2 + 1](torch.cat((inner, action), 1)) + hidden

        pred_next_state_feature = self.forward_net_2(torch.cat((hidden, action), 1))

        real_next_state_feature = encode_next_state
        return real_next_state_feature, pred_next_state_feature, pred_action

class Reshape(nn.Module):
    """Reshape the input to ``(-1, *shape)``; the leading dimension is inferred.

    Args:
        shape: trailing dimensions (any iterable of ints, e.g. ``(64, 7, 7)``).
    """

    def __init__(self, shape):
        super(Reshape, self).__init__()
        self.shape = shape

    def forward(self, x):
        target = (-1,) + tuple(self.shape)
        return x.reshape(target)

class VAEDensity(nn.Module):
    """Convolutional VAE over stacked frames.

    The encoder scales inputs by 1/255 and applies three conv layers. Two
    linear heads produce ``mu`` and ``logsigma``. The decoder mirrors the
    encoder with transposed convs and maps back to the [0, 255] range via
    ``Sigmoid`` followed by ``Scale(255.)``. The decoder output is 84x84
    (7 -> 9 -> 20 -> 84).

    NOTE(review): the head is named ``enc_logvar``, but its output is
    exponentiated directly as the standard deviation (sigma = exp(logsigma)).
    It therefore acts as log-sigma, not log-variance. Confirm the KL term in
    the loss uses the same convention.

    Args:
        code_dim: latent dimensionality.
        frames: number of input/output channels.
    """

    def __init__(self,
                 code_dim=128,
                 frames=1
                 ):
        super().__init__()

        encoder_layers = [
            Scale(1/255),
            layer_init(nn.Conv2d(frames, 32, 8, stride=4)),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, 4, stride=2)),
            nn.ReLU(),
            layer_init(nn.Conv2d(64, 64, 3, stride=1)),
            nn.ReLU(),
            nn.Flatten(),
        ]
        self.enc = nn.Sequential(*encoder_layers)
        # 3136 = 64 * 7 * 7 flattened encoder features.
        self.enc_mu = nn.Linear(3136, code_dim)
        self.enc_logvar = nn.Linear(3136, code_dim)

        decoder_layers = [
            nn.Linear(code_dim, 3136),
            Reshape((64, 7, 7)),
            layer_init(nn.ConvTranspose2d(64, 64, 3, stride=1)),
            nn.ReLU(),
            layer_init(nn.ConvTranspose2d(64, 32, 4, stride=2)),
            nn.ReLU(),
            layer_init(nn.ConvTranspose2d(32, frames, 8, stride=4)),
            nn.Sigmoid(),
            Scale(255.),
        ]
        self.dec = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Encode, sample a latent with the reparameterisation trick, decode.

        Returns:
            ``(recon_x, mu, logsigma)``.
        """
        hidden = self.enc(x)
        mu = self.enc_mu(hidden)
        logsigma = self.enc_logvar(hidden)
        sigma = torch.exp(logsigma)
        # z = eps * sigma + mu keeps the sample differentiable w.r.t. mu and sigma.
        noise = torch.randn_like(sigma)
        z = noise * sigma + mu
        return self.dec(z), mu, logsigma
