import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

from tcn import TemporalConvNet


class VAE(nn.Module):
    """Variational autoencoder over trajectories, conditioned on an image.

    The encoder embeds a trajectory (three parallel 1-D convolutions followed
    by two dense layers) and an RGB image (five stride-2 convolutions followed
    by two dense layers), then maps the concatenated embeddings to the mean
    and log-variance of a ``latent_n``-dimensional Gaussian.  The decoder
    feeds the sampled latent, the image embedding and a normalised time index
    into a ``TemporalConvNet``.  One linear classifier per label group reads a
    single latent coordinate: the classifier of group ``i`` reads ``z[:, i]``,
    so ``latent_n`` must be at least the number of groups for
    ``predict_labels`` / ``forward`` to work.

    Args:
        alpha: Weight of the reconstruction term of the loss.
        beta: Weight of the KL-divergence term of the loss.
        gamma: Weight of every per-group classification term of the loss.
        latent_n: Dimensionality of the latent space.
        groups: Mapping from a group key to a sized collection of that
            group's labels.  Only the number of groups and the length of each
            value are used here.  ``None`` (the default) means no groups.
        traj_len: Number of time steps in a trajectory.
        traj_ch: Number of channels of a trajectory.
        device: Stored on the instance as ``self.device`` for backward
            compatibility; ``decode`` places its tensors on the device of its
            input instead.
    """

    def __init__(self, alpha=1, beta=1, gamma=1, latent_n=1, groups=None, traj_len=0, traj_ch=0, device="cpu"):
        super().__init__()
        # A `{}` default would be one dict shared by every instance built
        # without `groups`; create a fresh one per instance instead.
        groups = {} if groups is None else groups

        self.latent_n = latent_n
        self.traj_len = traj_len
        self.traj_ch = traj_ch
        self.hidden_conv_ch = 20
        self.groups = groups
        self.groups_n = len(groups)
        self.device = device

        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma

        # IMAGE ENCODER
        # Five stride-2 convolutions; the first dense layer expects
        # 1024 = 64 channels * 4 * 4 features, which is what e.g. a 128 x 128
        # input produces.  The trailing comments give the spatial size after
        # each layer for such an input.
        self.conv_channels = [32, 32, 64, 64, 64]
        self.dense_channels = [1024, 256, 4]

        self.img_encoder_conv_0 = nn.Conv2d(3, self.conv_channels[0], 7, padding=3, stride=2)  # (64, 64)
        self.img_encoder_conv_1 = nn.Conv2d(self.conv_channels[0], self.conv_channels[1], 5, padding=2, stride=2)  # (32, 32)
        self.img_encoder_conv_2 = nn.Conv2d(self.conv_channels[1], self.conv_channels[2], 3, padding=1, stride=2)  # (16, 16)
        self.img_encoder_conv_3 = nn.Conv2d(self.conv_channels[2], self.conv_channels[3], 3, padding=1, stride=2)  # (8, 8)
        self.img_encoder_conv_4 = nn.Conv2d(self.conv_channels[3], self.conv_channels[4], 3, padding=1, stride=2)  # (4, 4)

        self.img_encoder_dense_0 = nn.Linear(self.dense_channels[0], self.dense_channels[1])
        self.img_encoder_dense_1 = nn.Linear(self.dense_channels[1], self.dense_channels[2])

        # TRAJ ENCODER
        # Three length-preserving convolutions with receptive fields 3, 5, 7.
        self.traj_dense = [256, 32]
        self.traj_encoder_1Gram = nn.Conv1d(self.traj_ch, self.hidden_conv_ch, 3, padding=1)
        self.traj_encoder_2Gram = nn.Conv1d(self.traj_ch, self.hidden_conv_ch, 5, padding=2)
        self.traj_encoder_3Gram = nn.Conv1d(self.traj_ch, self.hidden_conv_ch, 7, padding=3)

        self.traj_encoder_dense_0 = nn.Linear(self.hidden_conv_ch * 3 * self.traj_len, self.traj_dense[0])
        self.traj_encoder_dense_1 = nn.Linear(self.traj_dense[0], self.traj_dense[1])

        # TRAJ DECODER
        # Input channels per time step: latent vector + image embedding + 1 time index.
        # NOTE(review): TemporalConvNet comes from the project's `tcn` module;
        # presumably it maps (batch, channels, length) to
        # (batch, traj_ch, length) -- confirm against that module.
        self.traj_decoder_tcn = TemporalConvNet(
            self.latent_n + self.dense_channels[-1] + 1,
            [self.hidden_conv_ch, self.hidden_conv_ch, self.traj_ch],
            kernel_size=5,
        )

        # FC ENCODER
        # Input: trajectory embedding (traj_dense[-1] = 32) concatenated with
        # the image embedding (dense_channels[-1] = 4).
        self.fc_encoder_dense_0 = nn.Linear(self.traj_dense[-1] + self.dense_channels[-1], 32)
        self.fc_encoder_dense_1 = nn.Linear(32, 32)
        self.fc_encoder_mu = nn.Linear(32, self.latent_n)
        self.fc_encoder_ln_var = nn.Linear(32, self.latent_n)

        # Plain Python lists used only to group layers for `init_weights`.
        # The layers are already registered through the attributes above, so
        # these lists add no entries to the state dict.
        self.img_encoder = [self.img_encoder_conv_0,
                            self.img_encoder_conv_1,
                            self.img_encoder_conv_2,
                            self.img_encoder_conv_3,
                            self.img_encoder_conv_4,
                            self.img_encoder_dense_0,
                            self.img_encoder_dense_1]

        # CLASSIFIERS
        # One linear head per group, each fed a single latent coordinate.
        self.classifiers = nn.ModuleList([nn.Linear(1, len(items)) for items in self.groups.values()])

        self.traj_encoder = [self.traj_encoder_1Gram,
                             self.traj_encoder_2Gram,
                             self.traj_encoder_3Gram,
                             self.traj_encoder_dense_0,
                             self.traj_encoder_dense_1]

        self.fc_encoder = [self.fc_encoder_dense_0,
                           self.fc_encoder_dense_1,
                           self.fc_encoder_mu,
                           self.fc_encoder_ln_var]

        self.init_weights()

    def init_weights(self):
        """Re-draw encoder and classifier weights from N(0, 0.01**2).

        Biases keep their default initialisation, and the trajectory decoder
        is not touched here.
        """
        # Order matters for reproducibility under a fixed random seed.
        for layer in (*self.img_encoder, *self.traj_encoder, *self.fc_encoder, *self.classifiers):
            nn.init.normal_(layer.weight, mean=0.0, std=0.01)

    @staticmethod
    def _reparameterize(mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def traj_encode(self, x):
        """Embed a trajectory batch.

        Args:
            x: Tensor of shape ``(batch, traj_ch, traj_len)``.

        Returns:
            Tensor of shape ``(batch, traj_dense[-1])``.
        """
        grams = (
            self.traj_encoder_1Gram(x),
            self.traj_encoder_2Gram(x),
            self.traj_encoder_3Gram(x),
        )
        # Concatenated along the time axis (dim=2), not the channel axis; the
        # flattened size is the same either way, but the feature order is
        # what the dense layer's weights are tied to.
        flat = torch.flatten(torch.cat(grams, dim=2), start_dim=1)

        hidden = F.leaky_relu(self.traj_encoder_dense_0(flat))
        return F.leaky_relu(self.traj_encoder_dense_1(hidden))

    def img_encode(self, x):
        """Embed an image batch.

        Args:
            x: Tensor of shape ``(batch, 3, H, W)`` whose conv stack output
                flattens to ``dense_channels[0]`` features.

        Returns:
            Tensor of shape ``(batch, dense_channels[-1])``.
        """
        conv_stack = (
            self.img_encoder_conv_0,
            self.img_encoder_conv_1,
            self.img_encoder_conv_2,
            self.img_encoder_conv_3,
            self.img_encoder_conv_4,
        )
        features = x
        for conv in conv_stack:
            features = F.leaky_relu(conv(features))

        flat = torch.flatten(features, start_dim=1)
        hidden = F.leaky_relu(self.img_encoder_dense_0(flat))
        return F.leaky_relu(self.img_encoder_dense_1(hidden))

    def fc_encode(self, traj_embed, img_embed):
        """Map the two embeddings to the posterior parameters.

        Returns:
            Tuple ``(mu, ln_var)``, each of shape ``(batch, latent_n)``.
        """
        joint = torch.cat((traj_embed, img_embed), dim=-1)
        hidden = F.leaky_relu(self.fc_encoder_dense_0(joint))
        hidden = F.leaky_relu(self.fc_encoder_dense_1(hidden))
        return self.fc_encoder_mu(hidden), self.fc_encoder_ln_var(hidden)

    def encode(self, traj_in, img_in):
        """Encode a trajectory/image pair.

        Returns:
            Tuple ``(mu, ln_var, img_embed)``; the image embedding is returned
            as well because the decoder is conditioned on it.
        """
        traj_embed = self.traj_encode(traj_in)
        img_embed = self.img_encode(img_in)
        mu, ln_var = self.fc_encode(traj_embed, img_embed)
        return mu, ln_var, img_embed

    def decode(self, z, img_embed):
        """Decode latents into trajectories.

        Every time step receives the same latent vector and image embedding
        plus a scalar time index running linearly from 0 to 1.

        Args:
            z: Tensor of shape ``(batch, latent_n)``.
            img_embed: Tensor of shape ``(batch, dense_channels[-1])``.

        Returns:
            Output of the temporal conv net for an input of shape
            ``(batch, latent_n + dense_channels[-1] + 1, traj_len)``.
        """
        batch_n = z.shape[0]

        # Built with numpy in float32, then placed on z's device so the
        # concatenation works wherever the model and its inputs live.
        steps = np.linspace(0, 1, self.traj_len).astype(np.float32)
        t_idxs = torch.from_numpy(steps).to(z.device).view(1, self.traj_len, 1).expand(batch_n, -1, -1)

        # Broadcast views over the time axis; `torch.cat` materialises them.
        img_tiled = img_embed.reshape(batch_n, 1, -1).expand(-1, self.traj_len, -1)
        z_tiled = z.reshape(batch_n, 1, self.latent_n).expand(-1, self.traj_len, -1)

        # (batch, time, channels) -> (batch, channels, time) for the conv net.
        decoder_in = torch.cat((z_tiled, img_tiled, t_idxs), dim=2).permute(0, 2, 1)
        return self.traj_decoder_tcn(decoder_in)

    def predict_labels(self, z, softmax=False):
        """Run every group classifier on its own latent coordinate.

        Args:
            z: Tensor of shape ``(batch, latent_n)``.
            softmax: If true return class probabilities, otherwise raw logits
                (cross-entropy losses apply the softmax themselves).

        Returns:
            List with one ``(batch, n_classes_i)`` tensor per group.
        """
        result = []
        for i, classifier in enumerate(self.classifiers):
            logits = classifier(z[:, i, None])
            # Explicit class axis: the implicit-dim form is deprecated.
            result.append(F.softmax(logits, dim=-1) if softmax else logits)
        return result

    def get_latent(self, traj_in, img_in):
        """Return one latent sample ``z`` for the given inputs."""
        mu, logvar, _ = self.encode(traj_in, img_in)
        return self._reparameterize(mu, logvar)

    def forward(self, traj_in, img_in):
        """Encode, sample, classify and reconstruct.

        Returns:
            Tuple ``(reconstr, labels, mu, logvar)`` where ``labels`` is the
            list of per-group logits from ``predict_labels``.
        """
        mu, logvar, img_embed = self.encode(traj_in, img_in)
        z = self._reparameterize(mu, logvar)

        labels = self.predict_labels(z)
        reconstr = self.decode(z, img_embed)

        return reconstr, labels, mu, logvar

    def get_loss(self):
        """Return the training loss function bound to this model's weights.

        The returned callable has the signature
        ``loss(traj_in, traj_out, labels_in, labels_out, mu, logvar)`` and
        returns ``(total, rec, label, kld)``:

        * ``rec``: ``alpha`` times the squared error, summed per sample and
          averaged over the batch.
        * ``label``: ``gamma`` times the cross-entropy of every group, summed
          over groups; targets equal to 100 are ignored.  It is the integer
          ``0`` when the model has no groups.
        * ``kld``: ``beta`` times the KL divergence to a standard normal,
          summed (not averaged) over batch and latent dimensions.
        """

        def loss(traj_in, traj_out, labels_in, labels_out, mu, logvar):
            rec = self.alpha * F.mse_loss(traj_out, traj_in, reduction="none")
            rec = torch.mean(torch.sum(rec.reshape(rec.shape[0], -1), dim=-1))

            label = sum(
                self.gamma * F.cross_entropy(labels_out[i], labels_in[:, i], ignore_index=100)
                for i in range(self.groups_n)
            )

            kld = self.beta * (-0.5) * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

            return rec + label + kld, rec, label, kld

        return loss