import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from utils import utils
import layers


class GeometricConvs:
    """Kernel-support cache and layer factory shared by the encoder and decoder.

    On construction, ``utils.determine_kernel_support`` is evaluated once per
    level transition in each direction.  Its first return value is stored as a
    tensor on ``device`` in ``f_idx``; its second return value is stored
    unchanged in ``supp``.  Both are dicts with keys ``'enc'`` and ``'dec'``,
    each holding a list of ``len(rad) - 1`` entries.
    """

    def __init__(self, sampled_pts, rad, dist, device, blocks=1):
        """Compute encoder and decoder kernel supports for every level.

        Args:
            sampled_pts: Per-level point sets, indexed by ``level // blocks``.
            rad: Per-level radius values; ``len(rad) - 1`` transitions are built.
            dist: Forwarded as-is to ``utils.determine_kernel_support``.
            device: Device on which the index tensors are created.
            blocks: Number of consecutive levels sharing one entry of
                ``sampled_pts``.
        """
        self.sampled_pts = sampled_pts
        self.rad = rad
        self.device = device

        n_levels = len(rad) - 1
        self.f_idx = {'enc': [None] * n_levels, 'dec': [None] * n_levels}
        self.supp = {'enc': [None] * n_levels, 'dec': [None] * n_levels}

        # Encoder direction: level -> level + 1, using the radius of the source level.
        for level in tqdm(range(n_levels), desc='Determining encoder kernel supports'):
            src_pts = sampled_pts[level // blocks]
            dst_pts = sampled_pts[(level + 1) // blocks]
            idx, supp = utils.determine_kernel_support(src_pts, dst_pts, rad[level], dist)
            self.supp['enc'][level] = supp
            self.f_idx['enc'][level] = torch.tensor(idx, device=device)

        # Decoder direction: level -> level - 1; the result is stored at slot level - 1.
        for level in tqdm(range(n_levels, 0, -1), desc='Determining decoder kernel supports'):
            src_pts = sampled_pts[level // blocks]
            dst_pts = sampled_pts[(level - 1) // blocks]
            idx, supp = utils.determine_kernel_support(src_pts, dst_pts, rad[level], dist)
            self.supp['dec'][level - 1] = supp
            self.f_idx['dec'][level - 1] = torch.tensor(idx, device=device)

        # Placeholders; nothing in this class assigns them afterwards.
        self.ptc_row = None
        self.ptc_col = None
        self.ptc_val = None
        self.tangent = None

    def create_layers(self, n_ker, conv_type, enc_or_dec, **kwargs):
        """Build a ``nn.ModuleList`` of ``len(n_ker) - 1`` convolution layers.

        Args:
            n_ker: Channel widths; layer ``i`` maps ``n_ker[i]`` to ``n_ker[i + 1]``.
            conv_type: Name of a ``GeometricConvs`` method (e.g. ``'gat'``) that
                returns ``(layer_class, layer_kwargs)`` for one layer.
            enc_or_dec: ``'enc'`` or ``'dec'``.  Any other value yields an empty
                ``ModuleList``.
            **kwargs: Extra options handed to the ``conv_type`` method.  The
                keys ``level``, ``enc_or_dec``, ``n_ker`` and ``idx`` are always
                overwritten.

        Returns:
            The assembled ``nn.ModuleList``.

        Raises:
            AttributeError: If ``conv_type`` does not name a method of this class.
        """
        stack = nn.ModuleList()
        # Resolved up front so that a bad conv_type fails even for an unknown mode.
        factory = getattr(GeometricConvs, conv_type)

        is_enc = enc_or_dec == 'enc'
        if not is_enc and not enc_or_dec == 'dec':
            return stack

        n_layers = len(n_ker) - 1
        for i in range(n_layers):
            # The encoder walks levels upwards, the decoder walks them back down.
            level = i if is_enc else n_layers - 1 - i
            spec = {**kwargs, 'level': level, 'enc_or_dec': enc_or_dec, 'n_ker': n_ker, 'idx': i}
            layer_cls, layer_kwargs = factory(self, **spec)
            support = self.f_idx['enc'][i] if is_enc else self.f_idx['dec'][-(i + 1)]
            stack.append(layer_cls(n_ker[i], n_ker[i + 1], support, **layer_kwargs))
        return stack

    def gat(self, n_ker, idx, level, enc_or_dec, n_heads, blocks=1, **kwargs):
        """Describe one ``layers.GraphAttention`` layer of a stack.

        The first layer (``idx == 0``) consumes ``n_ker[idx]`` features; every
        later layer consumes ``n_ker[idx] * n_heads``.  The last layer
        (``idx == len(n_ker) - 2``) emits ``n_ker[idx + 1]`` features with
        ``att_output=True``; every earlier layer emits
        ``n_ker[idx + 1] * n_heads`` with ``att_output=False``.

        Args:
            n_ker: Channel widths of the whole stack.
            idx: Position of this layer in the stack.
            level: Level passed to ``utils.get_sample_idx``.
            enc_or_dec: Direction passed to ``utils.get_sample_idx``.
            n_heads: Number of attention heads.
            blocks: Passed to ``utils.get_sample_idx``.
            **kwargs: Accepted and ignored.

        Returns:
            ``(layers.GraphAttention, layer_kwargs)``.
        """
        sample_idx = utils.get_sample_idx(self.sampled_pts, level, enc_or_dec, blocks=blocks)

        feats_in = n_ker[idx] if idx == 0 else n_ker[idx] * n_heads

        is_last = bool(idx == len(n_ker) - 2)
        feats_out = n_ker[idx + 1] if is_last else n_ker[idx + 1] * n_heads

        layer_kwargs = {
            'sample_idx': sample_idx,
            'n_heads': n_heads,
            'feats_in_eff': feats_in,
            'feats_out_eff': feats_out,
            'att_output': is_last,
        }
        return layers.GraphAttention, layer_kwargs


class MVAE(nn.Module):
    """Variational autoencoder with several encoders and one shared decoder.

    ``inference_net`` holds one ``Encoder`` per entry of ``n_ker[0]`` (paired
    with the matching entry of ``n_latent``).  ``generative_net`` is a single
    ``Decoder`` built from ``n_ker[1]``.  All of them share one
    ``GeometricConvs`` instance.
    """

    def __init__(self, sampled_pts, adj_mat, n_ker, rad, dist, n_latent, conv_type, device, ignore_int_fine=True,
                 **kwargs):
        """Build the encoders and the decoder.

        Args:
            sampled_pts: Per-level point sets; their lengths are handed to the
                encoders and the decoder.
            adj_mat: Passed through to every ``Encoder``.
            n_ker: ``n_ker[0]`` is a sequence of per-encoder channel lists and
                ``n_ker[1]`` is the decoder channel list.
            rad: Per-level radii, forwarded to ``GeometricConvs``.
            dist: Forwarded to ``GeometricConvs``.
            n_latent: Latent width of each encoder.
            conv_type: Name of the ``GeometricConvs`` layer factory to use.
            device: Device used for kernel supports and for sampled noise.
            ignore_int_fine: When true, the decoder input width omits
                ``n_latent[0]``.
            **kwargs: Forwarded to every ``Encoder`` and to the ``Decoder``.
        """
        super(MVAE, self).__init__()
        self.sampled_pts = sampled_pts
        n_pts_smp = [len(s) for s in sampled_pts]
        self.adj_mat = adj_mat

        self.n_ker = n_ker
        self.n_latent = n_latent
        self.device = device
        self.geo_convs = GeometricConvs(sampled_pts, rad, dist, device)

        self.enc_split = len(n_ker[0]) > 1
        self.inference_net = nn.ModuleList()

        start_idx = 1 if ignore_int_fine else 0

        # Only the second encoder is built with ca_flag=True.  zip() stops at the
        # shortest input, so at most three encoders are ever created.
        calpha_flag = [False, True, False]
        for n_k, n_z, ca_flag in zip(n_ker[0], n_latent, calpha_flag):
            self.inference_net.append(Encoder(n_pts_smp, adj_mat, n_k, n_z, conv_type, self.geo_convs, self.enc_split,
                                              device, ca_flag, **kwargs))
        self.generative_net = Decoder(n_pts_smp, n_ker[1], sum(n_latent[start_idx:]), conv_type, self.geo_convs, device,
                                      **kwargs)

    def sample(self, sample_sz, eps=None, seed=(1, 1), mean=None, std=1.0):
        """Decode latent samples, drawing them first when ``eps`` is not given.

        One noise chunk of width ``n_latent[i]`` is drawn per latent group and
        the chunks are concatenated along the last dimension.  When ``seed`` is
        given, the global torch RNG is reseeded with ``seed[i]`` before chunk
        ``i`` is drawn (``seed[0]`` is applied even when ``eps`` is supplied).

        NOTE(review): the noise has ``sum(n_latent)`` columns, whereas the
        decoder is built for ``sum(n_latent[start_idx:])`` inputs -- confirm
        that the configuration in use makes these agree.

        Args:
            sample_sz: Number of latent vectors to draw.
            eps: Optional ready-made latent batch; when given, ``mean`` and
                ``std`` are ignored.
            seed: Per-group RNG seeds, or ``None`` to leave the RNG untouched.
                Must provide one entry per latent group.
            mean: Optional offset added to the drawn noise.
            std: Standard deviation handed to ``torch.normal``.  When ``mean``
                is given, the noise is additionally scaled by ``min(std, 1)``.
                May be a Python number or a tensor.

        Returns:
            ``(decoded, eps)``: the decoder output and the latent batch used.
        """
        if seed is not None:
            torch.manual_seed(seed[0])

        if eps is None:
            chunks = [torch.normal(mean=0.0, std=std, size=(sample_sz, self.n_latent[0]))]
            for i in range(1, len(self.n_latent)):
                # Previously this reseeded unconditionally and crashed for seed=None.
                if seed is not None:
                    torch.manual_seed(seed[i])
                chunks.append(torch.normal(mean=0.0, std=std, size=(sample_sz, self.n_latent[i])))
            eps = torch.cat(chunks, dim=-1).to(self.device)

            if mean is not None:
                # torch.clamp only accepts tensors, so wrap a plain-number std first.
                scale = torch.clamp(torch.as_tensor(std, dtype=eps.dtype, device=eps.device), max=1.0)
                eps *= scale
                eps += mean
        return self.generative_net(eps), eps

    def infer(self, bond_lens, dihedrals, var_flag=False):
        """Encode both inputs and concatenate their latent statistics.

        Only the first two encoders are used: ``inference_net[0]`` receives
        ``bond_lens`` and ``inference_net[1]`` receives ``dihedrals``.

        Args:
            bond_lens: Input for the first encoder.
            dihedrals: Input for the second encoder.
            var_flag: Forwarded to both encoders; selects sampled versus mean
                latent codes.

        Returns:
            ``(z, mu, logvar, sigma)``, each concatenated along the last
            dimension in encoder order.
        """
        first = self.inference_net[0](bond_lens, var_flag)
        second = self.inference_net[1](dihedrals, var_flag)
        z, mu, logvar, sigma = (torch.cat([a, b], dim=-1) for a, b in zip(first, second))
        return z, mu, logvar, sigma


class Encoder(nn.Module):
    """Maps an input to the parameters of a diagonal Gaussian latent code.

    Two architectures are selected by ``ca_flag``:

    * ``ca_flag`` true: a chain of ``layers.CalphaBlock`` modules followed by a
      mean over dimensions 2 and 3.
    * ``ca_flag`` false: a ``layers.GraphAttentionEdge`` input layer, a stack
      built by ``geo_convs.create_layers`` with batch norm and ReLU after every
      layer, then adaptive average pooling to length 10.

    In both cases two linear heads produce ``mu`` (squashed by ``tanh``) and
    ``logvar``.
    """

    def __init__(self, n_pts_smp, adj_mat, n_ker, n_latent, conv_type, geo_convs, split, device, ca_flag, blocks=1,
                 **kwargs):
        """Create the layers for the selected architecture.

        Args:
            n_pts_smp: Stored on the instance; not otherwise used here.
            adj_mat: Passed to ``layers.GraphAttentionEdge`` (graph branch).
            n_ker: Channel widths; must support ``.copy()`` in the graph branch.
            n_latent: Width of the latent code.
            conv_type: Layer factory name for ``geo_convs.create_layers``.
            geo_convs: ``GeometricConvs`` instance (graph branch only).
            split: Stored on the instance; not otherwise used here.
            device: Stored on the instance.
            ca_flag: Selects the ``CalphaBlock`` branch when true.
            blocks: Accepted but unused in this class.
            **kwargs: Must contain ``n_heads`` for the graph branch; forwarded
                to ``geo_convs.create_layers``.
        """
        super(Encoder, self).__init__()
        self.n_pts_smp = n_pts_smp
        self.adj_mat = adj_mat
        self.n_ker = n_ker
        self.n_latent = n_latent
        self.conv_type = conv_type
        self.split = split
        self.device = device
        self.ca_flag = ca_flag

        # Squashing applied to mu; despite the attribute name this is a tanh.
        self.sig = nn.Tanh()

        if ca_flag:
            self.enc_layers = nn.ModuleList([
                layers.CalphaBlock(f_in, f_out, 3, stride=2, bias_flag=True, padding=1)
                for f_in, f_out in zip(n_ker[:-1], n_ker[1:])
            ])
            self.dense_mu = nn.Linear(n_ker[-1], n_latent)
            self.dense_logvar = nn.Linear(n_ker[-1], n_latent, bias=True)
        else:
            n_heads = kwargs['n_heads']
            # The input layer widens the features by 4 per head.
            lifted = n_ker[0] * 4 * n_heads
            self.enc0 = layers.GraphAttentionEdge(n_ker[0], n_ker[0] * 4, 1, adj_mat, n_heads=n_heads,
                                                  bias_flag=True, feats_in_eff=n_ker[0],
                                                  feats_out_eff=lifted)
            self.bn0 = nn.BatchNorm1d(lifted)

            # The conv stack starts from the widened features; n_ker itself stays untouched.
            n_ker_copy = n_ker.copy()
            n_ker_copy[0] *= 4 * n_heads
            self.enc_layers = geo_convs.create_layers(n_ker_copy, conv_type, 'enc', **kwargs)

            # Every 'gat' layer except the last emits one feature block per head.
            final = len(n_ker) - 2
            self.bn_list = nn.ModuleList([
                nn.BatchNorm1d(n_ker[i + 1] * (n_heads if (i < final and conv_type == 'gat') else 1))
                for i in range(len(n_ker) - 1)
            ])

            self.pool = nn.AdaptiveAvgPool1d(10)
            self.pool_ca = nn.AdaptiveAvgPool1d(10)
            self.dense_mu = nn.Linear(n_ker[-1] * 10, n_latent)
            self.dense_logvar = nn.Linear(n_ker[-1] * 10, n_latent, bias=True)

    def forward(self, x, var_flag=False, bn_agree=True):
        """Encode ``x`` and return ``(z, mu, logvar, sigma)``.

        Args:
            x: Encoder input.
            var_flag: When true, ``z`` is a reparameterised sample; otherwise
                ``z`` equals ``mu``.
            bn_agree: Forwarded to ``encoder``.
        """
        features = self.encoder(x, bn_agree=bn_agree)
        return self.create_latent_var(features, var_flag=var_flag)

    def encoder(self, x, bn_agree=True):
        """Run the feature extractor.

        In the graph branch the batch-norm layers are switched to training mode
        only when the first conv layer is in training mode and ``bn_agree`` is
        true; otherwise they are switched to eval mode.  The chosen mode
        persists after the call.

        Args:
            x: Encoder input.
            bn_agree: See above; ignored in the ``CalphaBlock`` branch.

        Returns:
            The extracted feature tensor.
        """
        if self.ca_flag:
            x = self.enc_layers[0](x)
            for block in self.enc_layers[1:]:
                x = block(x)
            return x

        bn_training = bool(self.enc_layers[0].training and bn_agree)
        for norm in self.bn_list:
            norm.train(bn_training)
        self.bn0.train(bn_training)

        x = F.relu(self.bn0(self.enc0(x)))
        for conv, norm in zip(self.enc_layers, self.bn_list):
            x = F.relu(norm(conv(x)))
        return x

    def create_latent_var(self, x, cutoff=6.0, var_flag=True):
        """Turn extracted features into latent statistics and a latent code.

        Args:
            x: Features from ``encoder``.  The graph branch pools them with
                ``AdaptiveAvgPool1d`` and flattens; the ``CalphaBlock`` branch
                averages over dimensions 2 and 3.
            cutoff: Standard-normal noise is truncated to ``[-cutoff, cutoff]``.
            var_flag: When true, ``z = mu + eps * sigma``; otherwise ``z = mu``.

        Returns:
            ``(z, mu, logvar, sigma)`` where ``logvar`` is clamped to at most
            20 and ``sigma = exp(0.5 * logvar)``.
        """
        if self.ca_flag:
            x = torch.mean(x, [2, 3])
        else:
            pooled = self.pool(x)
            x = pooled.reshape([pooled.shape[0], -1])

        mu = self.sig(self.dense_mu(x))
        # Upper bound keeps the exponential below from overflowing.
        logvar = torch.clamp(self.dense_logvar(x), max=2E1)
        sigma = torch.exp(0.5 * logvar)

        # Noise is drawn even when var_flag is false, so the RNG stream does not
        # depend on the flag.
        eps = torch.normal(mean=0.0, std=1.0, size=mu.shape).to(mu.device)
        eps = torch.clamp(eps, min=-cutoff, max=cutoff)

        z = mu + eps * sigma if var_flag else mu
        return z, mu, logvar, sigma


class Decoder(nn.Module):
    """Maps a latent code back through a stack of convolution layers.

    The latent vector is expanded by a linear layer, reshaped to
    ``(batch, n_ker[0], n_pts_smp[-1])``, normalised and passed through the
    layers built by ``geo_convs.create_layers``.  Every layer except the last
    is followed by batch norm and ReLU.
    """

    def __init__(self, n_pts_smp, n_ker, n_latent, conv_type, geo_convs, device, blocks=1, **kwargs):
        """Create the dense expansion, the conv stack and its batch norms.

        Args:
            n_pts_smp: Per-level point counts; only the last entry is used here.
            n_ker: Channel widths of the decoder stack.
            n_latent: Width of the latent input.
            conv_type: Layer factory name for ``geo_convs.create_layers``.
            geo_convs: ``GeometricConvs`` instance providing the layers.
            device: Stored on the instance.
            blocks: Accepted but unused in this class.
            **kwargs: Must contain ``n_heads``; forwarded to
                ``geo_convs.create_layers``.
        """
        super(Decoder, self).__init__()
        self.n_pts_smp = n_pts_smp
        self.n_ker = n_ker
        self.n_latent = n_latent
        self.conv_type = conv_type
        self.device = device

        self.dec_layers = geo_convs.create_layers(n_ker, conv_type, 'dec', **kwargs)

        # One batch norm per hidden layer, each n_heads times the nominal width.
        # NOTE(review): n_heads is applied regardless of conv_type, unlike in
        # Encoder -- confirm this is intended for non-'gat' layer types.
        n_heads = kwargs['n_heads']
        self.bn_list = nn.ModuleList([nn.BatchNorm1d(width * n_heads) for width in n_ker[1:-1]])

        self.dense_surf = nn.Linear(n_latent, n_pts_smp[-1] * n_ker[0])
        self.dense_bn = nn.BatchNorm1d(n_ker[0])

    def forward(self, z, bn_agree=True):
        """Decode a batch of latent codes.

        Args:
            z: Latent batch whose last dimension matches ``dense_surf``'s input
                width.
            bn_agree: Accepted but currently unused.

        Returns:
            The output of the last decoder layer, with no normalisation or
            activation applied to it.
        """
        expanded = self.dense_surf(z)
        x = expanded.reshape([expanded.shape[0], self.n_ker[0], self.n_pts_smp[-1]])
        x = F.relu(self.dense_bn(x))

        final = len(self.dec_layers) - 1
        for i, conv in enumerate(self.dec_layers):
            x = conv(x)
            if i < final:
                x = F.relu(self.bn_list[i](x))
        return x
