import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical
from utils import layer_init


class QNetwork(nn.Module):
    """Convolutional Q-network with a configurable hidden (representation) layer.

    The model is three conv+ReLU stages, a flatten, one hidden module selected
    by ``args.activation``, and a final linear layer producing one Q-value per
    action. The hidden module is registered both as ``self.hidden`` and as
    ``self.network[-2]`` (the same object).
    """

    def __init__(self, env, args):
        """Build the network.

        Args:
            env: Object exposing ``single_action_space.n``; used as the output
                width of the final linear layer.
            args: Object exposing ``activation`` (name of the hidden-layer
                variant) and ``latent_dim`` (width of the hidden layer).

        Raises:
            ValueError: If ``args.activation`` is not a recognised name.
        """
        super().__init__()

        # 3136 = 64 * 7 * 7, which is what the conv stack below produces for
        # 84x84 inputs (20x20 -> 9x9 -> 7x7 with 64 channels).
        flat_dim = 3136
        latent_dim = args.latent_dim
        name = args.activation

        # Plain "Linear -> activation" hidden layers.
        plain_activations = {
            'tanh': nn.Tanh,
            'sigmoid': nn.Sigmoid,
            'selu': nn.SELU,
            'linear': nn.Identity,
            'relu': nn.ReLU,
        }
        # Hadamard_Module variants, mapped to the activation they use.
        hadamard_activations = {
            'tanh_HR': 'tanh',
            'sigmoid_HR': 'sigmoid',
            'relu_HR': 'relu',
        }

        # The hidden module is created before the conv layers so that the
        # order of parameter initialisation stays the same as before.
        if name in plain_activations:
            self.hidden = nn.Sequential(nn.Linear(flat_dim, latent_dim), plain_activations[name]())
        elif name in hadamard_activations:
            self.hidden = Hadamard_Module(flat_dim, latent_dim, activation=hadamard_activations[name])
        elif name == 'double_HR':
            self.hidden = Double_Hadamard_Module(flat_dim, latent_dim, activation='tanh')
        elif name == 'plustanh':
            self.hidden = Addition_Module(flat_dim, latent_dim, activation='tanh')
        elif name == 'layernorm':
            self.hidden = nn.Sequential(nn.Linear(flat_dim, latent_dim), nn.LayerNorm(latent_dim), nn.Tanh())
        else:
            raise ValueError("Invalid activation function, Please add you own")

        # Define the neural network model
        self.network = nn.Sequential(
            nn.Conv2d(4, 32, 8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
            self.hidden,
            nn.Linear(latent_dim, env.single_action_space.n),
        )

    def forward(self, x):
        """Return ``(q_values, representation)`` for a batch of observations.

        The input is divided by 255.0 before the first conv layer (presumably
        raw pixel frames -- confirm against the caller). ``representation`` is
        the output of the hidden module, i.e. the input to the final linear
        layer.
        """
        # No debug printing here: forward runs on every step of training.
        representation = self.network[:-1](x / 255.0)
        q_values = self.network[-1](representation)
        return q_values, representation

    # The final representation z_{t} is the product of these two representations. This function is for plotting purposes.
    def get_individual_representations(self, x):
        """Return the two branches of a two-branch hidden module separately.

        Returns:
            A tuple ``(gate, main)``: the output of the hidden module's
            ``hadamard`` branch and the output of its ``main_layer`` branch
            (passed through ``main_activation`` when the module defines one).

        Raises:
            AttributeError: If the hidden module does not expose ``hadamard``
                and ``main_layer`` (e.g. the plain ``nn.Sequential`` variants).
        """
        hidden = self.hidden
        if not (hasattr(hidden, 'hadamard') and hasattr(hidden, 'main_layer')):
            raise AttributeError(
                "get_individual_representations requires a hidden module with "
                "'hadamard' and 'main_layer' branches, got "
                f"{type(hidden).__name__}"
            )

        flattened = self.network[:-2](x / 255.0)
        gate_representation = hidden.hadamard(flattened)
        main_representation = hidden.main_layer(flattened)
        # Some hidden modules keep the main activation separate from the
        # linear layer; apply it so both branches are post-activation.
        main_activation = getattr(hidden, 'main_activation', None)
        if main_activation is not None:
            main_representation = main_activation(main_representation)
        return gate_representation, main_representation


class Hadamard_Module(nn.Module):
    """Hidden layer whose output is the element-wise product of two branches.

    Both branches read the same input: a main branch (``main_layer`` followed
    by ``main_activation``) and a ``Hadamard`` branch built with the same
    activation name.
    """

    def __init__(self, input_dim, output_dim, activation='sigmoid'):
        """Create both branches.

        Args:
            input_dim: Size of the input features.
            output_dim: Size of the representation produced by each branch.
            activation: One of ``'sigmoid'``, ``'tanh'`` or ``'relu'``.

        Raises:
            ValueError: If ``activation`` is not one of the supported names.
        """
        super(Hadamard_Module, self).__init__()
        activations = {'sigmoid': nn.Sigmoid, 'tanh': nn.Tanh, 'relu': nn.ReLU}
        # Fail at construction time instead of leaving a None activation that
        # only blows up later inside forward().
        if activation not in activations:
            raise ValueError(
                f"Unsupported activation {activation!r}; expected one of {sorted(activations)}"
            )
        self.main_activation = activations[activation]()
        self.main_layer = nn.Linear(input_dim, output_dim)
        self.hadamard = Hadamard(input_dim=input_dim, output_dim=output_dim, activation=activation)

    def forward(self, x):
        """Return ``main_activation(main_layer(x)) * hadamard(x)``."""
        lan_representation = self.hadamard(x)
        main_representation = self.main_activation(self.main_layer(x))
        return main_representation * lan_representation


class Double_Hadamard_Module(nn.Module):
    """Hidden layer multiplying a main branch by two separate ``Hadamard`` branches.

    All three branches read the same input and use the same activation name.
    """

    def __init__(self, input_dim, output_dim, activation='sigmoid'):
        """Create the main branch and the two extra branches.

        Args:
            input_dim: Size of the input features.
            output_dim: Size of the representation produced by each branch.
            activation: One of ``'sigmoid'``, ``'tanh'`` or ``'relu'``.

        Raises:
            ValueError: If ``activation`` is not one of the supported names.
        """
        super(Double_Hadamard_Module, self).__init__()
        activations = {'sigmoid': nn.Sigmoid, 'tanh': nn.Tanh, 'relu': nn.ReLU}
        # Without this check an unsupported name would leave main_layer
        # undefined and forward() would fail with an AttributeError.
        if activation not in activations:
            raise ValueError(
                f"Unsupported activation {activation!r}; expected one of {sorted(activations)}"
            )
        self.main_layer = nn.Sequential(nn.Linear(input_dim, output_dim),
                                        activations[activation]())
        self.hadamard1 = Hadamard(input_dim=input_dim, output_dim=output_dim, activation=activation)
        self.hadamard2 = Hadamard(input_dim=input_dim, output_dim=output_dim, activation=activation)

    def forward(self, x):
        """Return ``main_layer(x) * hadamard1(x) * hadamard2(x)``."""
        extra_representation1 = self.hadamard1(x)
        extra_representation2 = self.hadamard2(x)
        main_representation = self.main_layer(x)
        return main_representation * extra_representation1 * extra_representation2


class Addition_Module(nn.Module):
    """Hidden layer that adds (rather than multiplies) its two branches.

    Both branches read the same input: a main branch (linear layer followed by
    the activation) and a ``Hadamard`` branch built with the same activation
    name.
    """

    def __init__(self, input_dim, output_dim, activation='sigmoid'):
        """Create both branches.

        Args:
            input_dim: Size of the input features.
            output_dim: Size of the representation produced by each branch.
            activation: One of ``'sigmoid'``, ``'tanh'`` or ``'relu'``.

        Raises:
            ValueError: If ``activation`` is not one of the supported names.
        """
        super(Addition_Module, self).__init__()
        activations = {'sigmoid': nn.Sigmoid, 'tanh': nn.Tanh, 'relu': nn.ReLU}
        # Without this check an unsupported name would leave main_layer
        # undefined and forward() would fail with an AttributeError.
        if activation not in activations:
            raise ValueError(
                f"Unsupported activation {activation!r}; expected one of {sorted(activations)}"
            )
        self.main_layer = nn.Sequential(nn.Linear(input_dim, output_dim),
                                        activations[activation]())
        self.hadamard = Hadamard(input_dim=input_dim, output_dim=output_dim, activation=activation)

    def forward(self, x):
        """Return ``main_layer(x) + hadamard(x)``."""
        extra_representation = self.hadamard(x)
        main_representation = self.main_layer(x)
        # Addition Rather than Multiplication
        return main_representation + extra_representation


class Hadamard(nn.Module):
    """Single branch: a linear layer (``lan``) followed by an activation."""

    def __init__(self, input_dim, output_dim, activation='sigmoid'):
        """Create the linear layer and its activation.

        Args:
            input_dim: Size of the input features.
            output_dim: Size of the produced representation.
            activation: One of ``'sigmoid'``, ``'tanh'`` or ``'relu'``.

        Raises:
            ValueError: If ``activation`` is not one of the supported names.
        """
        super(Hadamard, self).__init__()
        activations = {'sigmoid': nn.Sigmoid, 'tanh': nn.Tanh, 'relu': nn.ReLU}
        # Reject unknown names up front; previously they produced a None
        # activation that only failed when forward() tried to call it.
        if activation not in activations:
            raise ValueError(
                f"Unsupported activation {activation!r}; expected one of {sorted(activations)}"
            )
        self.lan = nn.Linear(input_dim, output_dim)
        self.activation = activations[activation]()

    def forward(self, x):
        """Return ``activation(lan(x))``."""
        return self.activation(self.lan(x))


class PPOAgent(nn.Module):
    """Actor-critic network with a shared conv trunk and one hidden linear layer.

    The trunk is three conv+ReLU stages followed by ``linear1``; the chosen
    activation is applied to the ``linear1`` output, and the result feeds both
    the ``actor`` (action logits) and ``critic`` (scalar value) heads.
    """

    def __init__(self, envs, activation, latent_dim):
        """Build the network.

        Args:
            envs: Object exposing ``single_action_space.n``; used as the
                number of actor logits.
            activation: One of ``'relu'``, ``'tanh'`` or ``'sigmoid'``;
                applied to the hidden layer output.
            latent_dim: Width of the hidden layer.

        Raises:
            ValueError: If ``activation`` is not one of the supported names.
        """
        super().__init__()
        activations = {'relu': nn.ReLU, 'tanh': nn.Tanh, 'sigmoid': nn.Sigmoid}
        # Previously an unknown name left self.activation unset and the error
        # only appeared later as an AttributeError during a forward pass.
        if activation not in activations:
            raise ValueError(
                f"Unsupported activation {activation!r}; expected one of {sorted(activations)}"
            )
        self.activation = activations[activation]()

        self.latent_dim = latent_dim

        self.conv1 = layer_init(nn.Conv2d(4, 32, 8, stride=4))
        self.conv2 = layer_init(nn.Conv2d(32, 64, 4, stride=2))
        self.conv3 = layer_init(nn.Conv2d(64, 64, 3, stride=1))
        self.linear1 = layer_init(nn.Linear(64 * 7 * 7, self.latent_dim))

        self.actor = layer_init(nn.Linear(self.latent_dim, envs.single_action_space.n), std=0.01)
        self.critic = layer_init(nn.Linear(self.latent_dim, 1), std=1)

    def _pre_activation(self, x):
        """Run the conv trunk and ``linear1``; return the hidden pre-activation."""
        # Input is divided by 255.0 (presumably raw pixel frames -- confirm).
        features = torch.relu(self.conv1(x / 255.0))
        features = torch.relu(self.conv2(features))
        features = torch.relu(self.conv3(features))
        return self.linear1(features.flatten(1))

    def _latent(self, x):
        """Return the activated hidden representation shared by both heads."""
        return self.activation(self._pre_activation(x))

    def get_value(self, x):
        """Return the critic output for observations ``x``."""
        return self.critic(self._latent(x))

    def get_pre_activation_value(self, x):
        """Return the ``linear1`` output before the activation is applied."""
        return self._pre_activation(x)

    def get_action_and_value(self, x, action=None):
        """Return ``(action, log_prob, entropy, value)`` for observations ``x``.

        When ``action`` is None an action is sampled from the categorical
        distribution defined by the actor logits; otherwise the given action
        is scored.
        """
        final_latent = self._latent(x)
        probs = Categorical(logits=self.actor(final_latent))
        if action is None:
            action = probs.sample()
        return action, probs.log_prob(action), probs.entropy(), self.critic(final_latent)


class PPO_HR_Agent(nn.Module):
    """Actor-critic network whose hidden representation is a product of two tanh branches.

    A conv trunk (``conv1``..``conv3``) feeds two parallel linear layers,
    ``linear1`` and ``linear1_extra``. Their tanh outputs are multiplied
    element-wise and the product feeds the ``actor`` and ``critic`` heads.
    """

    def __init__(self, envs, latent_dim):
        """Build the network.

        Args:
            envs: Object exposing ``single_action_space.n``; used as the
                number of actor logits.
            latent_dim: Width of each hidden branch.
        """
        super().__init__()
        # NOTE(review): this attribute says 'relu' but the methods below apply
        # torch.tanh to both branches; it is not read anywhere in this class.
        self.activation = 'relu'
        self.latent_dim = latent_dim

        # NOTE(review): `network` is not used by any method of this class. It
        # is kept because dropping it would change the state_dict keys and the
        # order in which layers are initialised -- confirm before removing.
        self.network = nn.Sequential(
            layer_init(nn.Conv2d(4, 32, 8, stride=4)),
            nn.ReLU(),
            layer_init(nn.Conv2d(32, 64, 4, stride=2)),
            nn.ReLU(),
            layer_init(nn.Conv2d(64, 64, 3, stride=1)),
            nn.ReLU(),
            nn.Flatten(),
            layer_init(nn.Linear(64 * 7 * 7, self.latent_dim)),
            nn.ReLU(),
        )

        # Layers actually used by the forward methods.
        self.conv1 = layer_init(nn.Conv2d(4, 32, 8, stride=4))
        self.conv2 = layer_init(nn.Conv2d(32, 64, 4, stride=2))
        self.conv3 = layer_init(nn.Conv2d(64, 64, 3, stride=1))
        self.linear1 = layer_init(nn.Linear(64 * 7 * 7, self.latent_dim))
        self.linear1_extra = layer_init(nn.Linear(64 * 7 * 7, self.latent_dim))

        self.actor = layer_init(nn.Linear(self.latent_dim, envs.single_action_space.n), std=0.01)
        self.critic = layer_init(nn.Linear(self.latent_dim, 1), std=1)

    def _gated_latent(self, x):
        """Return ``tanh(linear1(f)) * tanh(linear1_extra(f))`` for trunk features ``f``."""
        # Input is divided by 255.0 (presumably raw pixel frames -- confirm).
        features = x / 255.0
        for conv in (self.conv1, self.conv2, self.conv3):
            features = torch.relu(conv(features))
        features = features.flatten(1)
        base_branch = torch.tanh(self.linear1(features))
        mask_branch = torch.tanh(self.linear1_extra(features))
        return base_branch * mask_branch

    def get_value(self, x):
        """Return the critic output for observations ``x``."""
        return self.critic(self._gated_latent(x))

    def get_pre_activation_value(self, x):
        """Return the product of the two tanh branches.

        Despite the method name, the returned tensor is taken after the tanh
        activations; this matches the existing behaviour.
        """
        return self._gated_latent(x)

    def get_action_and_value(self, x, action=None):
        """Return ``(action, log_prob, entropy, value)`` for observations ``x``.

        When ``action`` is None an action is sampled from the categorical
        distribution defined by the actor logits; otherwise the given action
        is scored.
        """
        hidden = self._gated_latent(x)
        policy = Categorical(logits=self.actor(hidden))
        if action is None:
            action = policy.sample()
        return action, policy.log_prob(action), policy.entropy(), self.critic(hidden)
