import numpy as np
import torch
import torch.jit as jit
import torch.nn as nn
from torch.nn import utils
import torch.nn.functional as F
from functools import partial
import threading
from threading import Thread
from utils import init
import torchvision

def get_resnet(name, pretrained=False):
    """Build a torchvision ResNet backbone by name.

    Only the requested architecture is instantiated. The lookup table maps
    names to constructors rather than to already-built models.

    Args:
        name: Architecture key, either ``'resnet18'`` or ``'resnet50'``.
        pretrained: Forwarded unchanged to the torchvision constructor.

    Returns:
        The freshly constructed model.

    Raises:
        KeyError: If ``name`` is not one of the supported architectures.
    """
    # Zero-argument factories: building every model up front would allocate
    # (and, with pretrained=True, download) architectures nobody asked for.
    factories = {
        'resnet18': lambda: torchvision.models.resnet18(pretrained=pretrained),
        'resnet50': lambda: torchvision.models.resnet50(pretrained=pretrained),
    }
    try:
        factory = factories[name]
    except KeyError:
        raise KeyError(f'{name} is not a valid ResNet version') from None
    return factory()

class Identity(nn.Module):
    """Pass-through module that hands back its input untouched.

    Swapped in for a backbone's final ``fc`` layer so the backbone emits its
    pooled features instead of class logits.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # No computation at all: the very same object is returned.
        return x

class SimCLR(nn.Module):
    """ResNet encoder followed by an MLP projection head (SimCLR layout).

    The encoder is a ResNet (He et al., 2016) with its classifier removed, so
    it yields h_i = f(x̃_i) = ResNet(x̃_i), where h_i ∈ R^d is the output after
    the average-pooling layer. A one-hidden-layer MLP then gives
    z_i = g(h_i) = W^(2) σ(W^(1) h_i), with σ a ReLU.

    Args:
        encoder_name: ResNet variant passed to ``get_resnet``.
        projection_dim: Size of the projected vector z_i.
    """

    def __init__(self, encoder_name='resnet18', projection_dim=64):
        super().__init__()

        backbone = get_resnet(encoder_name, pretrained=False)
        # Feature width comes from the classifier that is about to be
        # discarded (e.g. 512 for resnet18, 2048 for resnet50).
        self.n_features = backbone.fc.in_features
        # Remove the classifier so the backbone outputs pooled features.
        backbone.fc = Identity()
        self.encoder = backbone

        width = self.n_features
        self.projector = nn.Sequential(
            nn.Linear(width, width, bias=False),
            nn.ReLU(),
            nn.Linear(width, projection_dim, bias=False),
        )

    def forward(self, x_i):
        """Encode ``x_i`` and return its projection z_i."""
        return self.projector(self.encoder(x_i))

class StateEncoder(nn.Module):
    """One-layer state encoder: Linear followed by ReLU.

    Args:
        num_inputs: Size of the last dimension of the input.
        hidden_size: Size of the encoded output.
    """

    def __init__(self, num_inputs, hidden_size):
        super().__init__()

        def init_(module):
            # The project's `init` helper receives an orthogonal weight
            # initialiser, a zero-constant bias initialiser and gain sqrt(2).
            return init(module, nn.init.orthogonal_,
                        lambda t: nn.init.constant_(t, 0), np.sqrt(2))

        linear = init_(nn.Linear(num_inputs, hidden_size))
        self.encoder = nn.Sequential(linear, nn.ReLU())

    def forward(self, x):
        """Return ReLU(Linear(x))."""
        return self.encoder(x)

class ConvStateEncoder(nn.Module):
    """LeNet-style convolutional encoder for small square images.

    Two stages of 5x5 valid convolution + ReLU + 2x2 max-pool are followed by
    three fully connected layers (120 -> 84 -> ``hidden_size``). The last
    layer has no activation.

    Args:
        hidden_size: Size of the output feature vector.
        in_channels: Number of input image channels (default 3, as before).
        input_size: Side length of the square input image. The default of 32
            reproduces the original 16 * 5 * 5 flattened feature size.

    Raises:
        ValueError: If ``input_size`` is too small to survive both
            conv/pool stages.
    """

    def __init__(self, hidden_size, in_channels=3, input_size=32):
        super().__init__()
        # Each 5x5 valid conv shrinks the side by 4; each 2x2 pool halves it.
        side = ((input_size - 4) // 2 - 4) // 2
        if side < 1:
            raise ValueError(
                f'input_size={input_size} is too small for ConvStateEncoder'
            )
        self.conv1 = nn.Conv2d(in_channels, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * side * side, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, hidden_size)

    def forward(self, x):
        """Encode images ``x`` of shape (batch, in_channels, H, W) with H = W = input_size."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

class DNDKeyGenerator(nn.Module):
    """Map input features to key vectors via Linear followed by ReLU.

    Args:
        input_size: Size of the last dimension of the input.
        key_size: Size of each generated key.
    """

    def __init__(self, input_size, key_size):
        super().__init__()

        def init_(module):
            # Orthogonal weight initialiser, zero-constant bias initialiser,
            # gain sqrt(2), all handed to the project's `init` helper.
            return init(module, nn.init.orthogonal_,
                        lambda t: nn.init.constant_(t, 0), np.sqrt(2))

        projection = init_(nn.Linear(input_size, key_size))
        self.key_generator = nn.Sequential(projection, nn.ReLU())

    def forward(self, x):
        """Return the (non-negative, post-ReLU) keys for ``x``."""
        return self.key_generator(x)

class LinMemReader(nn.Module):
    """Compute a per-slot read mask over a memory from an observation encoding.

    The observation encoding and the flattened memory contents are each
    projected to ``context_dim``. Their concatenation is mapped to one
    sigmoid gate per memory slot.

    Args:
        encoded_obs_dim: Size of the last dimension of ``x``.
        context_dim: Width of each of the two projected contexts.
        mem_len: Number of memory slots (width of the returned mask).
        msg_size: Number of elements stored per memory slot.
    """

    def __init__(self, encoded_obs_dim, context_dim, mem_len, msg_size):
        super().__init__()
        init_ = lambda m: init(
            m, nn.init.orthogonal_, lambda x: nn.init.constant_(x, 0), np.sqrt(2)
        )
        self.context_layer = init_(nn.Linear(encoded_obs_dim, context_dim))
        # Consumes the whole memory flattened per batch entry.
        self.mem_process_layer = init_(nn.Linear(mem_len * msg_size, context_dim))
        self.mem_read_layer = nn.Sequential(
            init_(nn.Linear(context_dim * 2, mem_len)),
            nn.Sigmoid(),
        )

    def forward(self, x, mem_content):
        """Return the read mask, shape (batch, mem_len), for ``x`` and ``mem_content``.

        Args:
            x: Encoded observations, (batch, encoded_obs_dim).
            mem_content: Memory with batch as dim 0 and mem_len * msg_size
                elements per batch entry (e.g. (batch, mem_len, msg_size)).
        """
        context = self.context_layer(x)
        # reshape (not view): view raises on non-contiguous tensors such as
        # those produced by transpose/permute; reshape copies only if needed.
        flat_mem = mem_content.reshape(mem_content.size(0), -1)
        mem_context = self.mem_process_layer(flat_mem)
        return self.mem_read_layer(torch.cat([context, mem_context], dim=-1))

class RLHeadNetworks(nn.Module):
    """Actor and critic MLP heads that read the same input features.

    Each head is Linear -> ReLU -> Linear -> ReLU. The critic's hidden output
    is reduced to one value by ``critic_linear``; the actor's hidden output is
    returned directly.

    Args:
        num_inputs: Size of the last dimension of the input.
        hidden_size: Width of both heads.
        use_spectral_norm: If True, wrap the second Linear of each head in
            ``torch.nn.utils.spectral_norm``.
    """

    def __init__(self, num_inputs, hidden_size, use_spectral_norm=False):
        super().__init__()

        def init_(module):
            # Orthogonal weight initialiser, zero-constant bias initialiser,
            # gain sqrt(2), all handed to the project's `init` helper.
            return init(module, nn.init.orthogonal_,
                        lambda t: nn.init.constant_(t, 0), np.sqrt(2))

        def build_head():
            # Layers are created (and spectral-normed, then initialised) in the
            # same order as before, so seeded initialisation is unchanged.
            first = init_(nn.Linear(num_inputs, hidden_size))
            second = nn.Linear(hidden_size, hidden_size)
            if use_spectral_norm:
                second = utils.spectral_norm(second)
            second = init_(second)
            return nn.Sequential(first, nn.ReLU(), second, nn.ReLU())

        self.actor = build_head()
        self.critic = build_head()
        self.critic_linear = init_(nn.Linear(hidden_size, 1))

    def forward(self, x):
        """Return ``(critic_value, actor_features)`` for input ``x``."""
        critic_value = self.critic_linear(self.critic(x))
        actor_features = self.actor(x)
        return critic_value, actor_features

class CommHeadNetworks(nn.Module):
    """MLP head that emits a sigmoid-squashed message vector.

    Layout: Linear -> ReLU -> Linear -> ReLU -> Linear -> Sigmoid, mapping
    ``num_inputs`` features to ``num_comm_outputs`` message units.

    Args:
        num_inputs: Size of the last dimension of the input.
        hidden_size: Width of the hidden layers.
        num_comm_outputs: Size of the emitted message.
        use_spectral_norm: If True, wrap the middle Linear in
            ``torch.nn.utils.spectral_norm``.
    """

    def __init__(self, num_inputs, hidden_size, num_comm_outputs, use_spectral_norm=False):
        super().__init__()

        def init_(module):
            # Orthogonal weight initialiser, zero-constant bias initialiser,
            # gain sqrt(2), all handed to the project's `init` helper.
            return init(module, nn.init.orthogonal_,
                        lambda t: nn.init.constant_(t, 0), np.sqrt(2))

        # Creation order matches the original so seeded init is unchanged.
        input_layer = init_(nn.Linear(num_inputs, hidden_size))
        middle_layer = nn.Linear(hidden_size, hidden_size)
        if use_spectral_norm:
            middle_layer = utils.spectral_norm(middle_layer)
        middle_layer = init_(middle_layer)
        output_layer = init_(nn.Linear(hidden_size, num_comm_outputs))

        self.comm_actor = nn.Sequential(
            input_layer,
            nn.ReLU(),
            middle_layer,
            nn.ReLU(),
            output_layer,
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return the message vector for ``x``."""
        return self.comm_actor(x)

class CommSplitHeadNetworks(nn.Module):
    """Communication head built from two parallel MLPs whose outputs are concatenated.

    Each MLP is Linear -> ReLU -> Linear -> ReLU -> Linear -> Sigmoid. The
    output widths of ``comm_actor_s`` and ``comm_actor_i`` always sum to
    ``num_comm_outputs``. For an odd total, ``comm_actor_i`` gets the extra
    unit; previously both used ``int(n / 2)``, silently dropping one unit.

    Args:
        num_inputs: Size of the last dimension of the input.
        hidden_size: Width of the hidden layers of both MLPs.
        num_comm_outputs: Total size of the concatenated message.
        use_spectral_norm: If True, wrap each MLP's middle Linear in
            ``torch.nn.utils.spectral_norm``.
    """

    @staticmethod
    def _half_sizes(num_comm_outputs):
        """Split ``num_comm_outputs`` into two widths that sum back to it."""
        first = num_comm_outputs // 2
        return first, num_comm_outputs - first

    def __init__(self, num_inputs, hidden_size, num_comm_outputs, use_spectral_norm = False):
        super().__init__()
        init_ = lambda m: init(
            m, nn.init.orthogonal_, lambda x: nn.init.constant_(x, 0), np.sqrt(2)
        )

        def build_head(out_size):
            # Layers are created, spectral-normed and initialised in the same
            # order as the original per-branch code, keeping seeded init stable.
            first = init_(nn.Linear(num_inputs, hidden_size))
            middle = nn.Linear(hidden_size, hidden_size)
            if use_spectral_norm:
                middle = utils.spectral_norm(middle)
            middle = init_(middle)
            last = init_(nn.Linear(hidden_size, out_size))
            return nn.Sequential(first, nn.ReLU(), middle, nn.ReLU(), last, nn.Sigmoid())

        s_size, i_size = self._half_sizes(num_comm_outputs)
        self.comm_actor_s = build_head(s_size)
        self.comm_actor_i = build_head(i_size)

    def forward(self, x):
        """Return the concatenation of both heads' outputs along the last dim."""
        comm_actor_m = torch.cat((self.comm_actor_s(x), self.comm_actor_i(x)), -1)
        return comm_actor_m


class CommHeadProjector(nn.Module):
    """Pass a message vector through a small MLP back to its own width.

    Layout: Linear -> ReLU -> spectral-normalised Linear -> ReLU -> Linear ->
    Sigmoid, mapping ``num_comm_outputs`` features to ``num_comm_outputs``.

    Args:
        num_comm_outputs: Size of the message vector (input and output).
        hidden_size: Width of the hidden layers. When None, it defaults to
            ``num_comm_outputs``.
    """

    def __init__(self, num_comm_outputs, hidden_size = None):
        super().__init__()
        init_ = lambda m: init(
            m, nn.init.orthogonal_, lambda x: nn.init.constant_(x, 0), np.sqrt(2)
        )
        # Identity comparison is the correct test for the None sentinel.
        if hidden_size is None:
            hidden_size = num_comm_outputs
        self.projector = nn.Sequential(
            init_(nn.Linear(num_comm_outputs, hidden_size)),
            nn.ReLU(),
            init_(utils.spectral_norm(nn.Linear(hidden_size, hidden_size))),
            nn.ReLU(),
            init_(nn.Linear(hidden_size, num_comm_outputs)),
            nn.Sigmoid(),
        )

    def forward(self, m):
        """Return the projected message for ``m``."""
        projected_input = self.projector(m)
        return projected_input

class CommHeadAligner(nn.Module):
    """Decode a message vector back into the input feature space.

    Layout: Linear -> ReLU -> Linear -> ReLU -> Linear, mapping
    ``num_comm_outputs`` message units to ``num_inputs`` features. The final
    layer has no activation.

    Args:
        num_inputs: Size of the decoded output.
        hidden_size: Width of the hidden layers.
        num_comm_outputs: Size of the incoming message.
    """

    def __init__(self, num_inputs, hidden_size, num_comm_outputs):
        super().__init__()

        def init_(module):
            # Orthogonal weight initialiser, zero-constant bias initialiser,
            # gain sqrt(2), all handed to the project's `init` helper.
            return init(module, nn.init.orthogonal_,
                        lambda t: nn.init.constant_(t, 0), np.sqrt(2))

        layers = [
            init_(nn.Linear(num_comm_outputs, hidden_size)),
            nn.ReLU(),
            init_(nn.Linear(hidden_size, hidden_size)),
            nn.ReLU(),
            init_(nn.Linear(hidden_size, num_inputs)),
        ]
        self.comm_decoder = nn.Sequential(*layers)

    def forward(self, m):
        """Return the decoded features for message ``m``."""
        return self.comm_decoder(m)

class CommActHeadAligner(nn.Module):
    """Map a message vector to ``hidden_size`` post-ReLU features.

    Layout: Linear -> ReLU -> Linear -> ReLU.

    Args:
        num_comm_outputs: Size of the incoming message.
        hidden_size: Width of both layers and of the output.
    """

    def __init__(self, num_comm_outputs, hidden_size):
        super().__init__()

        def init_(module):
            # Orthogonal weight initialiser, zero-constant bias initialiser,
            # gain sqrt(2), all handed to the project's `init` helper.
            return init(module, nn.init.orthogonal_,
                        lambda t: nn.init.constant_(t, 0), np.sqrt(2))

        layers = [
            init_(nn.Linear(num_comm_outputs, hidden_size)),
            nn.ReLU(),
            init_(nn.Linear(hidden_size, hidden_size)),
            nn.ReLU(),
        ]
        self.comm_decoder = nn.Sequential(*layers)

    def forward(self, m):
        """Return the decoded features for message ``m``."""
        return self.comm_decoder(m)

class CommMIHeadAligner(nn.Module):
    """Embed a joint message vector and L2-normalise the embedding.

    Layout: Linear -> ReLU -> Linear -> ReLU, from
    ``num_comm_outputs * num_agents`` inputs to ``hidden_size`` features. The
    output is then scaled to unit L2 norm along dim 1. The input is
    presumably every agent's message concatenated — confirm against callers.

    Args:
        num_comm_outputs: Message size per agent.
        num_agents: Number of agents whose messages make up the input width.
        hidden_size: Width of the hidden layers and of the embedding.
    """

    def __init__(self, num_comm_outputs, num_agents, hidden_size):
        super().__init__()

        def init_(module):
            # Orthogonal weight initialiser, zero-constant bias initialiser,
            # gain sqrt(2), all handed to the project's `init` helper.
            return init(module, nn.init.orthogonal_,
                        lambda t: nn.init.constant_(t, 0), np.sqrt(2))

        joint_width = num_comm_outputs * num_agents
        layers = [
            init_(nn.Linear(joint_width, hidden_size)),
            nn.ReLU(),
            init_(nn.Linear(hidden_size, hidden_size)),
            nn.ReLU(),
        ]
        self.comm_decoder = nn.Sequential(*layers)

    def forward(self, m):
        """Return the unit-L2-norm (along dim 1) embedding of ``m``."""
        embedding = self.comm_decoder(m)
        # Spelled-out form of F.normalize's defaults (p=2.0, dim=1).
        return F.normalize(embedding, p=2.0, dim=1)


class CommMMHeadAligner(nn.Module):
    """Map a joint message vector to a single sigmoid-squashed message.

    Layout: Linear -> ReLU -> Linear -> ReLU -> Linear -> Sigmoid, from
    ``num_comm_outputs * num_agents`` inputs to ``num_comm_outputs`` outputs.
    The input is presumably every agent's message concatenated — confirm
    against callers.

    Args:
        num_comm_outputs: Message size per agent, and size of the output.
        num_agents: Number of agents whose messages make up the input width.
        hidden_size: Width of the hidden layers.
    """

    def __init__(self, num_comm_outputs, num_agents, hidden_size):
        super().__init__()

        def init_(module):
            # Orthogonal weight initialiser, zero-constant bias initialiser,
            # gain sqrt(2), all handed to the project's `init` helper.
            return init(module, nn.init.orthogonal_,
                        lambda t: nn.init.constant_(t, 0), np.sqrt(2))

        joint_width = num_comm_outputs * num_agents
        layers = [
            init_(nn.Linear(joint_width, hidden_size)),
            nn.ReLU(),
            init_(nn.Linear(hidden_size, hidden_size)),
            nn.ReLU(),
            init_(nn.Linear(hidden_size, num_comm_outputs)),
            nn.Sigmoid(),
        ]
        self.comm_decoder = nn.Sequential(*layers)

    def forward(self, m):
        """Return the decoded message for joint input ``m``."""
        return self.comm_decoder(m)

class CommGatedeadNetworks(nn.Module):
    """Communication head whose message is stochastically switched on or off.

    A shared Linear + ReLU trunk feeds two heads:
      * ``comm_m_output``: Linear -> Sigmoid, producing the message;
      * ``comm_gate_output``: Linear -> LogSoftmax over two outcomes.
    For each row, an index in {0, 1} is sampled from the gate distribution and
    multiplied into the whole message. Each row is therefore either zeroed or
    passed through unchanged.

    Args:
        num_inputs: Size of the last dimension of the input.
        hidden_size: Width of the shared trunk.
        num_comm_outputs: Size of the message.
    """

    def __init__(self, num_inputs, hidden_size, num_comm_outputs):
        super().__init__()

        def init_(module):
            # Orthogonal weight initialiser, zero-constant bias initialiser,
            # gain sqrt(2), all handed to the project's `init` helper.
            return init(module, nn.init.orthogonal_,
                        lambda t: nn.init.constant_(t, 0), np.sqrt(2))

        self.comm_actor = nn.Sequential(
            init_(nn.Linear(num_inputs, hidden_size)),
            nn.ReLU(),
        )
        self.comm_m_output = nn.Sequential(
            init_(nn.Linear(hidden_size, num_comm_outputs)),
            nn.Sigmoid(),
        )
        self.comm_gate_output = nn.Sequential(
            init_(nn.Linear(hidden_size, 2)),
            nn.LogSoftmax(dim=-1),
        )
        self.num_comm_outputs = num_comm_outputs

    @jit.export
    def gate_forward(self, comm_h):
        """Sample a 0/1 gate per row and broadcast it across the message width.

        Args:
            comm_h: Trunk features, assumed 2-D (batch, hidden_size).

        Returns:
            Index tensor of shape (batch, num_comm_outputs). Every entry in a
            row equals that row's sampled gate index.
        """
        gate_probs = self.comm_gate_output(comm_h).exp()
        gate_idx = torch.multinomial(gate_probs, 1)  # one sample per row
        return gate_idx.expand(gate_idx.size(0), self.num_comm_outputs)

    def forward(self, x):
        """Return the gated message for ``x`` (rows zeroed where gate == 0)."""
        comm_h = self.comm_actor(x)
        # Message first, then gate sampling: same op order as before.
        message = self.comm_m_output(comm_h)
        gate_mask = self.gate_forward(comm_h)
        return message * gate_mask
