import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

from tcn import TemporalConvNet

# Layout of the categorical part of the latent space (see fc_encode: logits
# are viewed as (batch, N, K) and normalised over the last axis).
N, K = 5, 2  # N categorical distributions, each over K categories

def reparameterize(mu, logvar):
    """Sample from N(mu, exp(logvar)) using the reparameterization trick.

    Returns ``mu + eps * sigma`` where ``sigma = exp(logvar / 2)`` and ``eps``
    is standard-normal noise shaped like ``sigma``; gradients therefore flow
    into both ``mu`` and ``logvar``.
    """
    sigma = (0.5 * logvar).exp()
    noise = torch.randn_like(sigma)
    return mu + noise * sigma

def sample_gumbel(shape, eps=1e-20, device="cpu"):
    """Draw i.i.d. standard Gumbel(0, 1) noise.

    Uses the inverse-CDF transform ``-log(-log(U))`` with ``U ~ Uniform[0, 1)``.

    Args:
        shape: output size (tuple or ``torch.Size``).
        eps: small constant added inside both logarithms to avoid ``log(0)``.
        device: any device spec accepted by ``torch.rand`` -- a string such as
            ``"cpu"`` / ``"cuda"`` / ``"cuda:1"`` or a ``torch.device``. The
            noise is created directly on that device. (Previously any value
            other than the exact string ``"cpu"`` was sent to the default CUDA
            device, which ignored device indices and broke for
            ``torch.device("cpu")`` or non-CUDA backends.)

    Returns:
        A float tensor of the requested shape on ``device``.
    """
    U = torch.rand(shape, device=device)
    return -torch.log(-torch.log(U + eps) + eps)

def gumbel_softmax_sample(logits, temperature, device="cpu"):
    """Draw a relaxed (soft) Gumbel-Softmax sample.

    Adds Gumbel noise of the same size as ``logits`` and applies a softmax,
    scaled by ``1 / temperature``, over the last dimension.
    """
    noise = sample_gumbel(logits.size(), device=device)
    perturbed = logits + noise
    return F.softmax(perturbed / temperature, dim=-1)

def gumbel_softmax(logits, temperature, hard=False, device="cpu"):
    """Gumbel-Softmax sample, optionally straight-through (ST) one-hot.

    Args:
        logits: unnormalised log-probabilities whose last dimension indexes
            the categories (in this file ``fc_encode`` passes ``(batch, N, K)``).
        temperature: softmax temperature.
        hard: if True, the forward value is the one-hot argmax of the soft
            sample while gradients are those of the soft sample.
        device: forwarded to the Gumbel noise sampler.

    Returns:
        The (soft or straight-through hard) sample reshaped to ``(-1, N * K)``.
    """
    soft = gumbel_softmax_sample(logits, temperature, device=device)

    if hard:
        n_class = soft.size(-1)
        winners = soft.max(dim=-1).indices
        one_hot = F.one_hot(winners, n_class).to(soft.dtype)
        # Straight-through: forward uses one_hot, backward uses d(soft).
        soft = (one_hot - soft).detach() + soft

    return soft.view(-1, N * K)

class VAE_GS(nn.Module):
    """VAE with a mixed Gaussian / Gumbel-Softmax latent space.

    A trajectory and an image are encoded jointly into a latent vector ``z``
    of size ``latent_n``: the first ``latent_n - N * K`` entries are a
    reparameterized Gaussian sample and the last ``N * K`` entries are ``N``
    Gumbel-Softmax samples over ``K`` categories each. A temporal conv net
    decodes ``z`` (together with the image embedding and a time index) back
    into a trajectory, and one linear classifier per entry of ``groups``
    predicts a label from the matching ``K``-wide categorical block of ``z``.

    Args:
        alpha: weight of the reconstruction term in the loss.
        beta: weight of the (Gaussian + categorical) KL term in the loss.
        gamma: weight of each per-group classification term in the loss.
        latent_n: total latent size; must be at least ``N * K``.
        groups: mapping ``name -> items``; classifier ``i`` has ``len(items)``
            outputs for the ``i``-th entry. Classifier ``i`` reads the ``i``-th
            categorical block of ``z``, so at most ``N`` groups are usable.
            ``None`` (default) means no groups.
        traj_len: number of time steps per trajectory.
        traj_ch: number of channels per trajectory time step.
        device: stored as ``self.device`` and forwarded to ``gumbel_softmax``
            for the Gumbel noise.
        tau: Gumbel-Softmax temperature.
        hard: if True, use straight-through one-hot categorical samples.

    Raises:
        ValueError: if ``latent_n < N * K``.
    """

    def __init__(self, alpha=1, beta=1, gamma=1, latent_n=1, groups=None,
                 traj_len=0, traj_ch=0, device="cpu", tau=1, hard=False):
        super().__init__()
        # ``None`` instead of a shared mutable ``{}`` default.
        if groups is None:
            groups = {}
        n_cat = N * K  # size of the categorical part of z
        if latent_n < n_cat:
            # Otherwise fc_encoder_mu / fc_encoder_ln_var would be built with
            # a negative output size and fail with an obscure error.
            raise ValueError(
                f"latent_n must be at least N * K = {n_cat}, got {latent_n}"
            )

        self.latent_n = latent_n
        self.traj_len = traj_len
        self.traj_ch = traj_ch
        self.hidden_conv_ch = 20
        self.groups = groups
        self.groups_n = len(groups)
        self.device = device

        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma

        self.tau = tau
        self.hard = hard

        # NOTE: layers are created in the same order as before so that default
        # initialisation consumes the RNG identically for a given seed.

        # IMAGE ENCODER
        # Five stride-2 convs. The spatial sizes in the comments assume a
        # 128x128 input, which makes the flattened conv output
        # 64 * 4 * 4 = 1024 = dense_channels[0].
        self.conv_channels = [32, 32, 64, 64, 64]
        self.dense_channels = [1024, 256, 4]
        img_embed_n = self.dense_channels[-1]  # image embedding size (4)

        self.img_encoder_conv_0 = nn.Conv2d(3, self.conv_channels[0], 7, padding=3, stride=2)  # (64, 64)
        self.img_encoder_conv_1 = nn.Conv2d(self.conv_channels[0], self.conv_channels[1], 5, padding=2, stride=2)  # (32, 32)
        self.img_encoder_conv_2 = nn.Conv2d(self.conv_channels[1], self.conv_channels[2], 3, padding=1, stride=2)  # (16, 16)
        self.img_encoder_conv_3 = nn.Conv2d(self.conv_channels[2], self.conv_channels[3], 3, padding=1, stride=2)  # (8, 8)
        self.img_encoder_conv_4 = nn.Conv2d(self.conv_channels[3], self.conv_channels[4], 3, padding=1, stride=2)  # (4, 4)

        self.img_encoder_dense_0 = nn.Linear(self.dense_channels[0], self.dense_channels[1])
        self.img_encoder_dense_1 = nn.Linear(self.dense_channels[1], self.dense_channels[2])

        # TRAJ ENCODER
        # Three parallel 1-D convs (kernel 3/5/7) with "same" padding, so each
        # keeps traj_len steps; their outputs are concatenated and flattened.
        self.traj_dense = [256, 32]
        self.traj_encoder_1Gram = nn.Conv1d(self.traj_ch, self.hidden_conv_ch, 3, padding=1)
        self.traj_encoder_2Gram = nn.Conv1d(self.traj_ch, self.hidden_conv_ch, 5, padding=2)
        self.traj_encoder_3Gram = nn.Conv1d(self.traj_ch, self.hidden_conv_ch, 7, padding=3)

        self.traj_encoder_dense_0 = nn.Linear(self.hidden_conv_ch * 3 * self.traj_len, self.traj_dense[0])
        self.traj_encoder_dense_1 = nn.Linear(self.traj_dense[0], self.traj_dense[1])

        # TRAJ DECODER: per-step input channels are [z, image embedding, time index].
        self.traj_decoder_tcn = TemporalConvNet(
            self.latent_n + img_embed_n + 1,
            [self.hidden_conv_ch, self.hidden_conv_ch, self.traj_ch],
            kernel_size=5,
        )

        # FC ENCODER: input is the 32-d trajectory embedding + the 4-d image embedding.
        self.fc_encoder_dense_0 = nn.Linear(self.traj_dense[-1] + img_embed_n, 32)
        self.fc_encoder_dense_1 = nn.Linear(32, 32)
        self.fc_encoder_mu = nn.Linear(32, self.latent_n - n_cat)
        self.fc_encoder_ln_var = nn.Linear(32, self.latent_n - n_cat)
        self.fc_encoder_logits = nn.Linear(32, n_cat)

        # Plain lists (not ModuleLists): the layers are already registered as
        # attributes above; these only group them for init_weights.
        self.img_encoder = [self.img_encoder_conv_0,
                            self.img_encoder_conv_1,
                            self.img_encoder_conv_2,
                            self.img_encoder_conv_3,
                            self.img_encoder_conv_4,
                            self.img_encoder_dense_0,
                            self.img_encoder_dense_1]

        # CLASSIFIERS: one per group, mapping a K-wide categorical block to
        # len(items) logits.
        self.classifiers = nn.ModuleList(
            [nn.Linear(K, len(items)) for items in self.groups.values()]
        )

        self.traj_encoder = [self.traj_encoder_1Gram,
                             self.traj_encoder_2Gram,
                             self.traj_encoder_3Gram,
                             self.traj_encoder_dense_0,
                             self.traj_encoder_dense_1]

        self.fc_encoder = [self.fc_encoder_dense_0,
                           self.fc_encoder_dense_1,
                           self.fc_encoder_mu,
                           self.fc_encoder_ln_var,
                           self.fc_encoder_logits]

        self.init_weights()

    def init_weights(self):
        """Redraw encoder and classifier weights from a normal(0, std=0.01).

        Biases and the TCN decoder keep PyTorch's default initialisation.
        """
        layers = (*self.img_encoder, *self.traj_encoder,
                  *self.fc_encoder, *self.classifiers)
        for layer in layers:
            nn.init.normal_(layer.weight, mean=0.0, std=0.01)

    def traj_encode(self, x):
        """Embed a trajectory batch into a ``traj_dense[-1]``-d vector.

        ``x`` goes into Conv1d layers with ``traj_ch`` input channels, so it
        is presumably ``(batch, traj_ch, traj_len)`` -- TODO confirm with caller.
        """
        grams = [conv(x) for conv in (self.traj_encoder_1Gram,
                                      self.traj_encoder_2Gram,
                                      self.traj_encoder_3Gram)]
        # Concatenated along dim=2 (time), not channels. The flattened size is
        # the same either way, but changing the axis would permute the inputs
        # of traj_encoder_dense_0 and invalidate existing weights.
        flat = torch.flatten(torch.cat(grams, dim=2), start_dim=1)
        hidden = F.leaky_relu(self.traj_encoder_dense_0(flat))
        return F.leaky_relu(self.traj_encoder_dense_1(hidden))

    def img_encode(self, x):
        """Embed an RGB image batch into a ``dense_channels[-1]``-d vector.

        ``x`` must have 3 channels, and its spatial size must flatten to
        ``dense_channels[0]`` after the conv stack (e.g. 128x128, per the
        layer comments) -- assumed, confirm with caller.
        """
        convs = (self.img_encoder_conv_0, self.img_encoder_conv_1,
                 self.img_encoder_conv_2, self.img_encoder_conv_3,
                 self.img_encoder_conv_4)
        h = x
        for conv in convs:
            h = F.leaky_relu(conv(h))
        h = torch.flatten(h, start_dim=1)
        h = F.leaky_relu(self.img_encoder_dense_0(h))
        return F.leaky_relu(self.img_encoder_dense_1(h))

    def fc_encode(self, traj_embed, img_embed):
        """Map the joint embedding to posterior parameters.

        Returns:
            ``(mu, ln_var, q_y)``: Gaussian mean and log-variance, each of size
            ``latent_n - N * K``, and categorical logits viewed as
            ``(batch, N, K)``.
        """
        concat = torch.cat((traj_embed, img_embed), dim=-1)
        hidden = F.leaky_relu(self.fc_encoder_dense_0(concat))
        hidden = F.leaky_relu(self.fc_encoder_dense_1(hidden))
        mu = self.fc_encoder_mu(hidden)
        ln_var = self.fc_encoder_ln_var(hidden)
        q_y = self.fc_encoder_logits(hidden).view(hidden.size(0), N, K)
        return mu, ln_var, q_y

    def encode(self, traj_in, img_in):
        """Encode both inputs; returns ``(mu, ln_var, q_y, img_embed)``."""
        traj_embed = self.traj_encode(traj_in)
        img_embed = self.img_encode(img_in)
        mu, ln_var, q_y = self.fc_encode(traj_embed, img_embed)
        return mu, ln_var, q_y, img_embed

    def decode(self, z, img_embed):
        """Decode a latent batch into a trajectory with the TCN.

        ``z`` (batch, latent_n) and ``img_embed`` (batch, embed) are repeated
        for each of the ``traj_len`` steps and concatenated with a time index
        running linearly from 0 to 1. The result is fed to the TCN as
        ``(batch, channels, traj_len)``.
        """
        batch = z.shape[0]
        # Built directly on z's device/dtype (no numpy round trip per call).
        t_idxs = torch.linspace(0, 1, self.traj_len, device=z.device, dtype=z.dtype)
        t_idxs = t_idxs.view(1, self.traj_len, 1).expand(batch, -1, -1)

        img_tiled = img_embed.unsqueeze(1).expand(-1, self.traj_len, -1)
        z_tiled = z.unsqueeze(1).expand(-1, self.traj_len, -1)

        concat = torch.cat((z_tiled, img_tiled, t_idxs), dim=2)
        return self.traj_decoder_tcn(concat.permute(0, 2, 1))

    def predict_labels(self, z, softmax=False):
        """Predict one label distribution per group from z's categorical part.

        Classifier ``i`` reads ``z[:, offset + i*K : offset + (i+1)*K]``, the
        ``i``-th Gumbel-Softmax block, where ``offset = latent_n - N*K``.

        Args:
            z: latent batch of shape (batch, latent_n).
            softmax: if True return probabilities; by default return raw
                logits, because the cross-entropy in the loss applies its own
                softmax.

        Returns:
            A list with one (batch, len(items)) tensor per group.
        """
        offset = self.latent_n - N * K  # start of the categorical part of z
        result = []
        for i, classifier in enumerate(self.classifiers):
            logits = classifier(z[:, offset + i * K: offset + (i + 1) * K])
            # Explicit dim: the implicit-dim F.softmax call is deprecated.
            result.append(F.softmax(logits, dim=-1) if softmax else logits)
        return result

    def _sample_latent(self, mu, logvar, q_y):
        """Draw ``z = [Gaussian sample, flattened Gumbel-Softmax sample]``."""
        z_norm = reparameterize(mu, logvar)
        z_cat = gumbel_softmax(q_y, self.tau, self.hard, device=self.device)
        return torch.cat((z_norm, z_cat), dim=1)

    def get_latent(self, traj_in, img_in):
        """Encode the inputs and return a sampled latent ``z`` (batch, latent_n)."""
        mu, logvar, q_y, _ = self.encode(traj_in, img_in)
        return self._sample_latent(mu, logvar, q_y)

    def forward(self, traj_in, img_in):
        """Full pass: encode, sample z, classify, and reconstruct.

        Returns:
            ``(reconstr, labels, mu, logvar, q_y)``, where ``labels`` is the
            list of raw per-group logits from ``predict_labels``.
        """
        mu, logvar, q_y, img_embed = self.encode(traj_in, img_in)
        z = self._sample_latent(mu, logvar, q_y)
        labels = self.predict_labels(z)
        reconstr = self.decode(z, img_embed)
        return reconstr, labels, mu, logvar, q_y

    def get_loss(self):
        """Return a closure computing the weighted training loss.

        The closure has signature
        ``loss(traj_in, traj_out, labels_in, labels_out, mu, logvar, q_y)``
        and returns ``(total, rec, label, kld)``:

        * ``rec``: ``alpha`` * squared error summed per sample, averaged over
          the batch.
        * ``label``: ``gamma`` * cross-entropy summed over groups;
          ``labels_in[:, i]`` holds class indices for group ``i``, and the
          value 100 is ignored. If there are no groups this is the int 0.
        * ``kld``: ``beta`` * (Gaussian KL to N(0, I) + categorical KL to the
          uniform distribution over K categories).
        """
        def loss(traj_in, traj_out, labels_in, labels_out, mu, logvar, q_y):
            rec = self.alpha * F.mse_loss(traj_out, traj_in, reduction="none")
            rec = torch.mean(torch.sum(rec.reshape(rec.shape[0], -1), dim=-1))

            # NOTE(review): if every target of a group equals 100, the mean
            # over zero elements may yield NaN.
            label = sum(
                self.gamma * F.cross_entropy(labels_out[i], labels_in[:, i], ignore_index=100)
                for i in range(self.groups_n)
            )

            # NOTE(review): both KL terms are summed over the batch while rec
            # is averaged over it, so the effective KL weight grows with batch
            # size. Kept as-is so existing beta settings remain valid.
            kld_norm = (-0.5) * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

            # KL(q || Uniform(K)) = sum q * log(q * K); 1e-20 guards log(0).
            qy = F.softmax(q_y, dim=-1)
            kld_cat = torch.sum(qy * torch.log(qy * K + 1e-20))

            kld = self.beta * (kld_norm + kld_cat)

            return rec + label + kld, rec, label, kld

        return loss