import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal, Normal, kl_divergence, constraints
import numpy as np
import mdtraj as md
from unet_model import UNet
from datetime import datetime
import warnings
import os
from utils import (
    prepare_for_pnerf,
    samples_to_structures,
    get_dihedral_derivatives,
    get_bond_angle_derivatives,
    get_coord_with_O,
)


class VAE(nn.Module):
    def __init__(self, latent_features, encoder_sizes, decoder_sizes, length, bond_lengths,
                 prior, predict_prior=False, a_start=100, fix_a=False, scale_prior=1.0,
                 ll=['kappa'], aux_loss=['none', 'mae'], aux_weight_start=[50.0, 50.0], fix_aux_weight=[True, True],
                 allow_negative_lambda=False, steps=1, ll_every_layer=False, superpose=False,
                 coords_ref=None, top_ref=None, sin_cos_in=False, sin_cos_out=False):
        """
        Build the encoder, decoder, U-Net and all loss-related settings.

        Parameters
        ----------
        latent_features : size of the latent vector z (the encoder emits 2x this:
            mean and pre-softplus variance).
        encoder_sizes, decoder_sizes : sequences whose first three entries are the
            hidden-layer widths of the encoder / decoder MLPs.
        length : chain length; determines num_dihedrals (length - 3),
            num_bond_angles (length - 2) and num_kappa (2 * length - 5).
        bond_lengths : tensor, moved to the model device and later handed to
            samples_to_structures.
        prior : if predict_prior is False, a tensor that is moved to the device and
            added (times a * scale_prior) to the precision matrix; if predict_prior
            is True it is only stored as self.data_prior.
        predict_prior : predict the prior from mu_k with a 1-D convolution; in that
            case a_start, fix_a and scale_prior are overridden (a warning is issued).
        a_start, fix_a : initial value of the prior weight `a` and whether it is a
            fixed tensor (True) or a learnable nn.Parameter (False).
        scale_prior : extra constant factor on the prior.
        ll : likelihood terms to use; the code checks for 'kappa' and 'x'.
        aux_loss : [fluctuation loss, lambda loss]; the code handles 'none',
            'cm_prior', 'cm_low' for the first and 'none', 'mse', 'mae' for the second.
        aux_weight_start, fix_aux_weight : initial values of the two auxiliary
            weights and whether each is fixed or learnable.
        allow_negative_lambda : forwarded to UNet as `allow_negative`.
        steps : stored as self.ancestral_steps.
        ll_every_layer : enables per-step likelihood weights (self.ll_weights).
        superpose, coords_ref, top_ref : settings for the mdtraj superposition
            in get_structures.
        sin_cos_in, sin_cos_out : use (cos, sin) pairs as encoder input / decoder
            output, doubling the respective feature count.
        """
        super().__init__()
        self.latent_features = latent_features
        self.encoder_sizes = encoder_sizes
        self.decoder_sizes = decoder_sizes
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.eps = torch.finfo(torch.float).eps
        self.length = length
        self.upper_ind = torch.triu_indices(length, length, offset=1)
        self.num_dihedrals = length - 3
        self.num_bond_angles = length - 2
        self.num_kappa = 2 * length - 5
        self.bond_lengths = bond_lengths.to(self.device)
        self.mean_only = True
        # Store copies: the defaults are shared list objects, so keeping a
        # reference would let one instance's in-place edits leak into the
        # defaults of every later instance.
        self.ll = list(ll)
        self.ll_every_layer = ll_every_layer
        self.superpose = superpose
        self.coords_ref = coords_ref
        self.top_ref = top_ref
        self.aux_loss = list(aux_loss)
        self.aux_weight_start = list(aux_weight_start)
        self.fix_aux_weight = list(fix_aux_weight)
        self.allow_negative_lambda = allow_negative_lambda
        self.predict_prior = predict_prior
        self.ancestral_steps = steps
        # forward() recomputes this on every call; initialise it here with the
        # same rule so sample()/sample_Cm() also work before the first forward().
        self.steps = 1 if self.mean_only else self.ancestral_steps
        self.sin_cos_in = sin_cos_in
        self.sin_cos_out = sin_cos_out
        if self.ll_every_layer and self.ancestral_steps > 1:
            # NOTE(review): the loop variable is unused, so every weight equals
            # 0.25 + 0.75 / (steps - 2) -- confirm whether a ramp was intended.
            n_intermediate = self.ancestral_steps - 2
            self.ll_weights = [0.25 + 0.75 / n_intermediate for _ in range(n_intermediate)]

        # Activations
        self.softplus = nn.Softplus()

        # Encoder
        in_features = self.num_kappa * 2 if self.sin_cos_in else self.num_kappa
        self.encoder = nn.Sequential(
                    nn.Linear(in_features=in_features, out_features=self.encoder_sizes[0]),
                    nn.BatchNorm1d(self.encoder_sizes[0]),
                    nn.LeakyReLU(),
                    nn.Linear(in_features=self.encoder_sizes[0], out_features=self.encoder_sizes[1]),
                    nn.BatchNorm1d(self.encoder_sizes[1]),
                    nn.LeakyReLU(),
                    nn.Linear(in_features=self.encoder_sizes[1], out_features=self.encoder_sizes[2]),
                    nn.BatchNorm1d(self.encoder_sizes[2]),
                    nn.LeakyReLU(),
                    nn.Linear(in_features=self.encoder_sizes[2], out_features=self.latent_features * 2)
                    )

        # Decoder
        out_features = self.num_kappa * 2 if self.sin_cos_out else self.num_kappa
        self.decoder = nn.Sequential(
                    nn.Linear(in_features=self.latent_features, out_features=self.decoder_sizes[0]),
                    nn.BatchNorm1d(self.decoder_sizes[0]),
                    nn.LeakyReLU(),
                    nn.Linear(in_features=self.decoder_sizes[0], out_features=self.decoder_sizes[1]),
                    nn.BatchNorm1d(self.decoder_sizes[1]),
                    nn.LeakyReLU(),
                    nn.Linear(in_features=self.decoder_sizes[1], out_features=self.decoder_sizes[2]),
                    nn.BatchNorm1d(self.decoder_sizes[2]),
                    nn.LeakyReLU(),
                    nn.Linear(in_features=self.decoder_sizes[2], out_features=out_features),
                    )
        if self.sin_cos_out:
            self.tanh = nn.Tanh()

        # U-Net
        self.unet = UNet(n_channels=1, n_classes=1, extra_step=False, allow_negative=self.allow_negative_lambda)

        # Data prior prediction (if applicable)
        if self.predict_prior:
            warnings.warn("Predicting prior over kappa! Arguments prior, a_start, fix_a, and scale_prior will be ignored.")
            self.prior = None
            self.data_prior = prior
            self.scale_prior = 1.
            self.a_start = 1.
            self.fix_a = True
            self.conv = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=11, padding=5)
        else:
            self.prior = prior.to(self.device)
            self.a_start = a_start
            self.fix_a = fix_a
            self.scale_prior = scale_prior

        # Data prior weight: plain tensor when fixed, learnable parameter otherwise
        if self.fix_a:
            self.a = torch.tensor([float(self.a_start)], device=self.device)
        else:
            self.a = nn.Parameter(torch.tensor([float(self.a_start)]))

        # Weight for auxiliary losses (index 0: fluctuation, index 1: lambda)
        if self.fix_aux_weight[0]:
            self.fluct_aux_weight = torch.tensor([float(self.aux_weight_start[0])], device=self.device)
        else:
            self.fluct_aux_weight = nn.Parameter(torch.tensor([float(self.aux_weight_start[0])]))
        if self.fix_aux_weight[1]:
            self.lamb_aux_weight = torch.tensor([float(self.aux_weight_start[1])], device=self.device)
        else:
            self.lamb_aux_weight = nn.Parameter(torch.tensor([float(self.aux_weight_start[1])]))

    
    def get_kl_loss(self, mu_z, var_z):
        """
        KL divergence from the approximate posterior N(mu_z, var_z) to a
        standard normal, summed over the last dimension and averaged over
        everything that remains.
        """
        # eps keeps the standard deviation strictly positive
        posterior_std = torch.sqrt(var_z + self.eps)
        posterior = Normal(mu_z, posterior_std)
        standard_normal = Normal(torch.zeros_like(mu_z), torch.ones_like(mu_z))

        kl_per_sample = kl_divergence(posterior, standard_normal).sum(dim=-1)
        return kl_per_sample.mean()


    def get_lambda_unet(self, structs):
        """
        Predict lambda (Lagrange multipliers) for each position of each structure.

        The pairwise-distance matrix of every structure is passed through the
        U-Net; entry i of the result is the average of the U-Net output over
        row i and column i together (the diagonal element counted once).
        """
        displacement = structs[:, None, :, :] - structs[:, :, None, :]
        pairwise_dist = torch.norm(displacement, dim=-1)

        # U-Net works on a single-channel image; add and remove the channel dim
        unet_out = self.unet(pairwise_dist.unsqueeze(dim=1)).squeeze(dim=1)

        # Row sum + column sum counts the diagonal twice, hence the subtraction
        n_positions = unet_out.shape[1]
        cross_total = (unet_out.sum(dim=-2) + unet_out.sum(dim=-1)
                       - unet_out.diagonal(dim1=-2, dim2=-1))
        pooled = cross_total / (2 * n_positions - 1)

        assert pooled.shape == (structs.shape[0], structs.shape[1]), 'Shape mismatch'
        return pooled
        

    def get_prec_matrix(self, x, lamb, index=None, return_Cm=False):
        """
        Build the precision matrix over kappa for a single structure.

        The matrix is sum_m lamb[m] * G_m plus the (scaled) prior, where G_m is
        the Gram matrix of the dihedral / bond-angle derivatives returned by the
        utils helpers for constraint m.

        Parameters
        ----------
        x : coordinates of one structure (one element of the batch of structures).
        lamb : 1-D tensor with one weight per constraint m.
        index : row of self.prior to use when self.predict_prior is set.
        return_Cm : also return C_m = trace(precision^-1 @ G_m) for every m.

        Returns
        -------
        The precision matrix, or (precision matrix, Cm) when return_Cm is True.
        The first element is the precision in both modes: every caller passes
        it on as `precision_matrix=`, so the inverse is only used internally.
        """
        dih_derivatives = get_dihedral_derivatives(x)
        ba_derivatives = get_bond_angle_derivatives(x)
        derivatives = torch.cat((dih_derivatives, ba_derivatives), dim=1)
        prec_constr = torch.sum(derivatives[:, :, None, :] * derivatives[:, None, :, :], dim=-1)  # "G_m"

        prec = torch.einsum('m, mij->ij', lamb, prec_constr)
        assert prec.shape == (self.num_kappa, self.num_kappa), 'Shape mismatch'
        if self.predict_prior:
            prior = torch.diag_embed(self.prior[index])
            assert prior.shape == (self.num_kappa, self.num_kappa), f"Shape mismatch, shape: {prior.shape}"
            prec += self.a * self.scale_prior * prior
        else:
            prec += self.a * self.scale_prior * self.prior

        assert prec.isnan().sum() == 0, f"Precision contains {prec.isnan().sum()} nans"
        assert prec.isinf().sum() == 0, f"Precision contains {prec.isinf().sum()} infs"

        if not return_Cm:
            return prec

        # The covariance is needed only for C_m; keep `prec` itself untouched
        cov = torch.linalg.inv(prec)
        Cm = (cov.unsqueeze(dim=0) @ prec_constr).diagonal(offset=0, dim1=-1, dim2=-2).sum(-1)

        return prec, Cm


    def get_structures(self, mu_k):
        """
        Convert mean kappa values into structures.

        The first self.num_dihedrals columns of mu_k are treated as dihedrals and
        the remaining columns as bond angles; both are prepared with
        prepare_for_pnerf and turned into coordinates by samples_to_structures
        together with self.bond_lengths.

        If self.superpose is set, the structures are superposed onto
        self.coords_ref with mdtraj. That path goes through NumPy and is
        therefore NOT differentiable; the result is returned on the device the
        structures were on before the round trip.
        """
        dih, ba = mu_k[:, :self.num_dihedrals], mu_k[:, self.num_dihedrals:]
        assert ba.shape[1] == self.num_bond_angles, 'Shape mismatch'

        di_pNeRF = prepare_for_pnerf(dih, kappa_type="di")
        ba_pNeRF = prepare_for_pnerf(ba, kappa_type="ba")
        structs = samples_to_structures(di_pNeRF,
                      self.bond_lengths.repeat(1, len(mu_k), 1), ba_pNeRF)

        # TODO: superpose! Preliminary code below, NOT BACKPROPABLE because of mdtraj
        if self.superpose:
            if structs.requires_grad:
                warnings.warn("Superposition detaches the structures from the autograd graph.")
            source_device = structs.device
            # .numpy() only works on CPU tensors that do not track gradients
            coords = structs.detach().cpu().numpy()
            frames = np.concatenate((np.expand_dims(self.coords_ref, axis=0), coords), axis=0)  # Add reference frame
            # presumably Angstrom -> nm for mdtraj and back again -- TODO confirm units
            traj = md.Trajectory(frames / 10, topology=self.top_ref)
            aligned = traj.superpose(traj, frame=0).xyz[1:] * 10
            structs = torch.from_numpy(aligned).float().to(source_device)

        return structs

    
    def get_NLL_AUX(self, mu_k, k_gt, c_gt):
        """
        Negative log-likelihood (full precision matrix) plus auxiliary losses.
        Note: k_gt is already centered at zero!

        Returns (NLL / self.batch_size, AUX, lambda statistics), where AUX holds
        the fluctuation loss at index 0 and the lambda loss at index 1, and the
        statistics are a dict with the mean, min and max of lambda.
        """
        structs = self.get_structures(mu_k)
        lamb = self.get_lambda_unet(structs)

        use_kappa = 'kappa' in self.ll
        use_x = 'x' in self.ll
        # C_m is required by the fluctuation loss and by the x-space likelihood
        need_Cm = self.aux_loss[0] != 'none' or use_x

        all_Cm = torch.empty_like(lamb) if need_Cm else None
        NLL = torch.tensor([0.], device=self.device)

        for i, (x, l) in enumerate(zip(structs, lamb)):
            if need_Cm or not use_kappa:
                prec, all_Cm[i] = self.get_prec_matrix(x, l, i, return_Cm=True)
            else:
                prec = self.get_prec_matrix(x, l, i)

            if use_kappa:
                NLL += self.get_NLL_kappa(prec, k_gt, i)

        if use_x:
            NLL += self.get_NLL_x(all_Cm, structs, c_gt)

        AUX = torch.zeros(2, device=self.device)
        if self.fluct_aux_weight > 0 and self.aux_loss[0] != 'none':
            AUX[0] = self.get_cm_loss(all_Cm)
        if self.lamb_aux_weight > 0 and self.aux_loss[1] != 'none':
            AUX[1] = self.get_lambda_mse_mae(lamb)

        lamb_stats = {"mean":lamb.mean().item(), "min":lamb.min().item(), "max":lamb.max().item()}
        return NLL / self.batch_size, AUX, lamb_stats


    def get_NLL_kappa(self, prec, k_gt, i):
        """
        Negative log-likelihood of k_gt[i] under a zero-mean multivariate normal
        with precision matrix `prec`.

        If the distribution rejects its arguments (ValueError), the prior and
        the precision matrix are printed for debugging and the error re-raised.
        """
        try:
            zero_mean = torch.zeros(self.num_kappa, device=self.device)
            kappa_dist = MultivariateNormal(zero_mean, precision_matrix=prec)
            return -kappa_dist.log_prob(k_gt[i])
        except ValueError:
            shown_prior = self.prior[i] if self.predict_prior else self.prior
            print(f"Prior: {shown_prior}")
            print(f"Precision matrix diagonal: {prec.diag()}")
            print(f"Precision matrix: {prec}")
            raise


    def get_NLL_x(self, Cm, pred_mean, c_gt):
        """
        Get negative log-likelihood in Euclidean space.

        Each position gets an isotropic normal centred on pred_mean with
        variance Cm (+ eps). The log-probabilities are summed over dim 2,
        averaged over dim 1 and summed over dim 0.

        Raises AssertionError if the shapes of pred_mean and c_gt differ or if
        any entry of Cm + eps is not strictly positive.
        """
        assert pred_mean.shape == c_gt.shape, "Shape mismatch"
        # Vectorised check; iterating the tensor with Python's sum() is slow
        all_positive = bool(constraints.positive.check(Cm.flatten() + self.eps).all())
        assert all_positive, f"Cm not positive, Cm_min:{Cm.flatten().min()}, Cm_max:{Cm.flatten().max()}, Cm:{Cm}"

        # One standard deviation per position, broadcast over the coordinate axis
        std = (Cm + self.eps).sqrt().unsqueeze(dim=-1)
        NLL = -Normal(pred_mean, std).log_prob(c_gt)
        NLL = NLL.sum(dim=2).mean(dim=1)

        return NLL.sum(dim=0)


    def get_cm_loss(self, Cm):
        """
        Auxiliary squared-error loss on the calculated constraints (C_m).

        'cm_low' penalises the squared C_m values themselves; 'cm_prior'
        penalises the squared distance to self.fluct_prior. The per-sample sums
        over the last dimension are averaged over dim 0.
        """
        mode = self.aux_loss[0]

        if mode == 'cm_low':
            per_sample = (Cm**2).sum(dim=-1)
        elif mode == 'cm_prior':
            if not hasattr(self, 'fluct_prior'):
                raise AttributeError("Missing fluctuation prior")
            deviation = Cm - self.fluct_prior.unsqueeze(dim=0)
            per_sample = (deviation**2).sum(dim=-1)
        else:
            raise NotImplementedError

        return per_sample.mean(dim=0)


    def get_lambda_mse_mae(self, lamb):
        """
        Auxiliary loss on inverse lambdas to reduce fluctuations.

        'mse' averages 1 / (lamb^2 + eps); 'mae' averages 1 / (lamb + eps), or
        sqrt(1 / (lamb^2 + eps)) when negative lambdas are allowed. The
        per-sample means are then averaged over dim 0.
        """
        kind = self.aux_loss[1]
        if kind not in ('mse', 'mae'):
            raise NotImplementedError

        if kind == 'mse':
            inverse = 1/(lamb**2+self.eps)
            per_sample = inverse.mean(dim=-1)
        elif self.allow_negative_lambda:
            # sqrt of the inverse square acts as 1/|lamb| for negative values
            inverse = (1/((lamb**2)+self.eps)).sqrt()
            per_sample = inverse.mean(dim=1)
        else:
            inverse = 1/(lamb+self.eps)
            per_sample = inverse.mean(dim=1)

        return per_sample.mean(dim=0)
    
    
    def get_recon_loss(self, mu_k, k_gt, c_gt):
        """
        Reconstruction loss for predicted means mu_k against targets k_gt.

        The targets are first re-expressed as the angular difference to mu_k,
        wrapped through atan2. In mean-only mode that difference is scored under
        a unit normal and no auxiliary losses or lambda statistics are produced;
        otherwise everything is delegated to get_NLL_AUX.

        Returns (NLL, AUX, lamb).
        """
        delta = k_gt - mu_k
        k_gt = torch.atan2(torch.sin(delta), torch.cos(delta))

        if not self.mean_only:
            NLL, AUX, lamb = self.get_NLL_AUX(mu_k, k_gt, c_gt)
            return NLL, AUX, lamb

        unit_normal = Normal(torch.zeros_like(mu_k[0]), torch.ones_like(mu_k[0]))
        NLL = -unit_normal.log_prob(k_gt).sum(dim=1).mean(dim=0)
        AUX = torch.tensor([0.0, 0.0], device =self.device)
        return NLL, AUX, None


    def sample_batches(self, num_samples_current, batch_size):
        """
        Number of batches needed to cover num_samples_current samples.

        With batch_size None everything goes into a single batch; otherwise the
        count is rounded up so a final partial batch is included.
        """
        if batch_size is None:
            return 1

        full_batches, leftover = divmod(num_samples_current, batch_size)
        return full_batches if leftover == 0 else full_batches + 1


    def lamb_in_batches(self, structs, batches, batch_size):
        """
        Get predicted lambda in batches.

        Parameters
        ----------
        structs : batch of structures, sliced along dim 0.
        batches : number of slices to process.
        batch_size : slice length; None means "everything in one slice" (this
            matches sample_batches, which returns 1 batch in that case).

        Returns the per-slice results of get_lambda_unet, moved to the CPU and
        stacked back into shape (structs.shape[0], structs.shape[1]).
        """
        total = structs.shape[0]
        if batch_size is None:
            # Arithmetic on None would raise TypeError below
            batch_size = total

        lamb = []
        for i in range(batches):
            start = i * batch_size
            # The last slice takes whatever remains
            end = total if (i == batches - 1) else start + batch_size
            chunk = self.get_lambda_unet(structs[start:end]).cpu()
            assert chunk.shape[0] <= batch_size, 'Bug in batch size'
            lamb.append(chunk)
        lamb = torch.vstack(lamb)
        assert lamb.shape == (structs.shape[0], structs.shape[1]), f'Shape mismatch, shape: {lamb.shape}'

        return lamb


    def fluctuation_step(self, structs, lamb, mu_k, num_samples_k):
        """
        Take one fluctuation step (without calculating loss).

        For every structure a zero-mean multivariate normal is built from its
        precision matrix, num_samples_k perturbations are drawn and added to the
        corresponding mean, and the result is wrapped through atan2.

        Returns (all perturbed kappa stacked along dim 0, the precision matrices
        stacked on the CPU).
        """
        draw_shape = torch.Size([num_samples_k])
        perturbed_kappa = []
        precisions = []

        for idx, (coords, weights, mean_k) in enumerate(zip(structs, lamb, mu_k)):
            precision = self.get_prec_matrix(coords.to(self.device), weights.to(self.device), idx)
            noise_dist = MultivariateNormal(torch.zeros(self.num_kappa, device=self.device),
                                            precision_matrix=precision)
            precisions.append(precision.cpu())

            shifted = mean_k.unsqueeze(0) + noise_dist.sample(draw_shape)
            perturbed_kappa.append(torch.atan2(torch.sin(shifted), torch.cos(shifted)))

        return torch.vstack(perturbed_kappa), torch.stack(precisions)


    def sample(self, num_samples_z, num_samples_k, topology = None, model_name = None, batch_size=None, eval=True):
        """
        Sample from prior
        num_samples_z is for top layer
        num_samples_k is for each step

        Step 0 decodes z samples into mu_k; every later step predicts lambda for
        the current structures and draws num_samples_k perturbed kappa per
        structure, so the sample count grows by a factor num_samples_k per step.

        Parameters
        ----------
        topology : if given, structures of every step except step 0 are written
            to ./pdb_files as PDB (expects a "topology" key and is passed to
            get_coord_with_O).
        model_name : optional prefix used in the PDB file name.
        batch_size : slice size for the lambda prediction; None processes all
            structures at once.
        eval : put the module into eval mode first.

        Returns a dict with 'dihedrals', 'bond_angles', 'structures',
        'precision_matrices' (from the last step) and 'z_samples'.
        """
        if eval:
            self.eval()

        # forward() normally sets self.steps; use the same rule if it has not run yet
        steps = getattr(self, 'steps', 1 if self.mean_only else self.ancestral_steps)

        with torch.no_grad():
            z_prior = Normal(torch.zeros(self.latent_features, device = self.device),
                            torch.ones(self.latent_features, device = self.device))
            z_samples = z_prior.sample(torch.Size([num_samples_z]))

            for i in range(steps + 1):
                if i == 0:
                    mu_k = self.decode(z_samples) # num_samples_z x num_kappa
                else:
                    batches = self.sample_batches(structs.shape[0], batch_size)
                    # A concrete slice length is needed for the slicing arithmetic
                    slice_len = structs.shape[0] if batch_size is None else batch_size
                    lamb = self.lamb_in_batches(structs, batches, slice_len)
                    structs = structs.cpu()
                    mu_k, step_prec = self.fluctuation_step(structs, lamb, mu_k, num_samples_k)
                    if i == steps:
                        prec = step_prec.cpu()
                        # One precision matrix per structure that entered this step
                        expected_prec = (num_samples_z * (num_samples_k**(i-1)), self.num_kappa, self.num_kappa)
                        assert prec.shape == expected_prec, \
                            f"Shape mismatch, shape: {prec.shape} instead of {expected_prec}"

                structs = self.get_structures(mu_k)
                expected_len = num_samples_z * (num_samples_k**i)
                assert structs.shape[0] == expected_len, \
                    f'Shape mismatch, step {i}, shape: {structs.shape} (expected len {expected_len})'

                if self.predict_prior and i != steps:
                    self.prior = self.get_prior(mu_k)

                # Save sampled structures at all steps except z
                if topology is not None and i != 0:
                    os.makedirs("./pdb_files", exist_ok=True)
                    traj = md.Trajectory(get_coord_with_O(structs, topology)/10, topology=topology["topology"])
                    model_name_save = "" if model_name is None else model_name + "_"
                    traj.save_pdb(f"pdb_files/samples_withO_step{i}of{steps}_{model_name_save}{datetime.now().strftime('%d-%m-%Y_%H:%M:%S')}.pdb")

        return {'dihedrals':mu_k[:, :self.num_dihedrals].cpu(), 'bond_angles':mu_k[:, self.num_dihedrals:].cpu(),
                'structures':structs.cpu(), 'precision_matrices':prec, 'z_samples':z_samples.cpu()}


    def sample_Cm(self, num_samples, eval=True):
        """
        Get constraints Cm for structures sampled from the prior.

        Runs the same step loop as sampling with one kappa draw per step; at the
        last step no new kappa is drawn and C_m is computed for every structure
        instead.

        Returns (all_Cm, lamb) from the last step, both on the CPU.
        """
        if eval:
            self.eval()

        # forward() normally sets self.steps; use the same rule if it has not run yet
        steps = getattr(self, 'steps', 1 if self.mean_only else self.ancestral_steps)

        with torch.no_grad():
            z_prior = Normal(torch.zeros(self.latent_features, device = self.device),
                            torch.ones(self.latent_features, device = self.device))
            z_samples = z_prior.sample(torch.Size([num_samples]))

            for i in range(steps + 1):
                if i == 0:
                    mu_k = self.decode(z_samples) # num_samples_z x num_kappa
                else:
                    lamb = self.get_lambda_unet(structs).cpu()
                    assert lamb.shape == (num_samples, structs.shape[1]), f'Shape mismatch, shape: {lamb.shape}'
                    structs = structs.cpu()
                    if i < steps:
                        mu_k, _ = self.fluctuation_step(structs, lamb, mu_k, 1)
                    else:
                        all_Cm = torch.empty_like(lamb)
                        for j, (x, l) in enumerate(zip(structs, lamb)):
                            _, all_Cm[j] = self.get_prec_matrix(x.to(self.device), l.to(self.device), j, return_Cm=True)

                structs = self.get_structures(mu_k)
                assert structs.shape[0] == num_samples, 'Shape mismatch'

                if self.predict_prior and i != steps:
                    self.prior = self.get_prior(mu_k)

        return all_Cm, lamb

    
    def encode(self, k):
        """
        Encode kappa into the latent space.

        With sin_cos_in the input is replaced by interleaved (cos, sin) pairs,
        doubling the feature count. The encoder output is split in half along
        the last dimension into the mean and the pre-softplus variance.

        Returns (mu_z, var_z, z) where z is a reparameterised sample.
        """
        input_shape = k.shape
        if self.sin_cos_in:
            cos_sin = torch.stack((torch.cos(k), torch.sin(k)), dim=2)
            k = torch.flatten(cos_sin, start_dim=1, end_dim=2) # interleave cos and sin
            assert k.shape == (input_shape[0], input_shape[1]*2), f"Shape mismatch: k.shape = {k.shape} instad of {(input_shape[0], input_shape[1]*2)}"

        encoded = self.encoder(k)
        mu_z, raw_var = torch.chunk(encoded, 2, dim=-1)
        var_z = self.softplus(raw_var)

        # eps keeps the standard deviation strictly positive
        posterior = Normal(mu_z, torch.sqrt(var_z+self.eps))
        z = posterior.rsample()

        return mu_z, var_z, z


    def decode(self, z):
        """
        Decode latent samples into mean kappa values in [-pi, pi].

        Without sin_cos_out the raw decoder output is mapped into range with a
        modulo. With sin_cos_out the output is squashed by tanh and read as
        interleaved (cos, sin) pairs that are converted to angles via atan2.
        """
        raw = self.decoder(z)

        if not self.sin_cos_out:
            # modulo always gives positive -> result: [-pi, pi]
            return raw % (2 * np.pi) - np.pi

        squashed = self.tanh(raw) # Between -1 and 1
        cos_part = squashed[:, 0::2]
        sin_part = squashed[:, 1::2]
        mu_k = torch.atan2(sin_part, cos_part) # [-pi, pi]
        assert (mu_k.min() >= -np.pi) and (mu_k.max() <= np.pi), "invalid dihedral (outside [-pi, pi])"

        return mu_k

    def get_prior(self, mu_k):
        """
        Predict the prior from mu_k.

        mu_k gets a channel dimension, is passed through self.conv, and the
        result is made >= 1 by softplus + 1. The output has the same shape as
        mu_k.
        """
        # Remove only the channel dim; a bare squeeze() would also drop a batch
        # dimension of size 1 and trip the shape check below
        prior = self.conv(mu_k.unsqueeze(dim=1)).squeeze(dim=1)
        assert mu_k.shape == prior.shape, f"Shape mismatch, shapes: {mu_k.shape} and {prior.shape}"
        prior = self.softplus(prior) + 1.0 # Add 1 to get at least 1?

        return prior

    
    def forward(self, k, c, only_outputs=False): # TODO: clean up
        """
        Forward step of the model

        Encodes k, decodes the latent sample into mu_k, optionally refines mu_k
        through intermediate fluctuation steps, and computes the losses.

        Parameters
        ----------
        k : input kappa values; cloned before encoding and also used as the
            reconstruction target.
        c : target coordinates for the x-space likelihood, or None.
        only_outputs : if True, NLL, KL, AUX and lamb are returned as None.

        Returns
        -------
        ({"z": z, "kappa": mu_k}, NLL, KL, AUX, lamb), with z moved to the CPU.

        Side effects: sets self.steps and self.batch_size, and overwrites
        self.prior when self.predict_prior is set.
        """
        # Mean-only mode uses a single step, so the intermediate branch below never runs
        self.steps = 1 if self.mean_only else self.ancestral_steps
        k_in = k.clone()
        self.batch_size = len(k_in)

        mu_z, var_z, z = self.encode(k_in)

        NLL_steps = torch.tensor([0.], device=self.device) # TODO: check NLL every layer
        AUX_steps = torch.zeros(2, device=self.device)
        for s in range(self.steps):
            if s == 0:
                # First step: decode the latent sample and compute the KL term
                mu_k = self.decode(z)
                z = z.cpu()
                KL = self.get_kl_loss(mu_z, var_z)

            elif s < self.steps-1:
                # Intermediate steps: perturb mu_k using the predicted precision matrices
                structs = self.get_structures(mu_k)
                lamb = self.get_lambda_unet(structs)

                # C_m is only allocated when the fluctuation loss or the per-layer x likelihood needs it
                all_Cm = torch.empty_like(lamb) if (self.aux_loss[0] != 'none' or ('x' in self.ll and self.ll_every_layer)) else None
                for i, (x, l) in enumerate(zip(structs, lamb)):
                    if self.aux_loss[0] != 'none' or ('x' in self.ll and self.ll_every_layer):
                        prec, all_Cm[i] = self.get_prec_matrix(x, l, i, return_Cm=True)
                    else:
                        prec = self.get_prec_matrix(x, l, i)
                    # NOTE(review): assuming mu_k is (batch, num_kappa), the single draw for
                    # structure i is broadcast onto every row of mu_k, not only row i
                    # (fluctuation_step perturbs per row) -- confirm this is intended.
                    mu_k += MultivariateNormal(torch.zeros(self.num_kappa, device=self.device), precision_matrix=prec).rsample()
                    mu_k = torch.atan2(torch.sin(mu_k), torch.cos(mu_k))

                    if ('kappa' in self.ll) and (self.ll_every_layer) and (not self.mean_only):
                        # NOTE(review): k is passed as-is here, whereas get_recon_loss
                        # re-centres its target around mu_k first -- confirm.
                        NLL_steps += self.ll_weights[s-1] * self.get_NLL_kappa(prec, k, i)

                prec = None

                if not self.mean_only:
                    if ('x' in self.ll) and (self.ll_every_layer) and (c is not None):
                        NLL_steps += self.ll_weights[s-1] * self.get_NLL_x(all_Cm, structs, c)
                    if self.fluct_aux_weight > 0 and self.aux_loss[0] != 'none':
                        AUX_steps[0] += self.get_cm_loss(all_Cm)
                    if self.lamb_aux_weight > 0 and self.aux_loss[1] != 'none':
                        AUX_steps[1] += self.get_lambda_mse_mae(lamb)

                    # Raised after the step has already been computed, whenever
                    # ll_every_layer is set without 'x' in self.ll
                    if (self.ll_every_layer) and ('x' not in self.ll):
                        raise NotImplementedError("LL at every layer not implemented for kappa")

                structs, lamb = None, None

            if self.predict_prior and s != self.steps-1:
                self.prior = self.get_prior(mu_k)

            # Final step: reconstruction loss. Skipped when the x likelihood is requested
            # but no coordinates were given; NLL_final / AUX_final then stay unbound and
            # the non-only_outputs path below would fail.
            if ((s == self.steps-1) or (self.steps == 1)) and ((c is not None) or ('x' not in self.ll)):
                NLL_final, AUX_final, lamb = self.get_recon_loss(mu_k, k, c)

        if only_outputs:
            NLL, KL, AUX, lamb = None, None, None, None
        else:
            # Per-step auxiliary losses are ignored in mean-only mode and for 'cm_prior'
            AUX = AUX_final if (self.mean_only or self.aux_loss[0]=='cm_prior') else \
                    (AUX_steps + AUX_final) #/ torch.tensor([self.steps, self.steps], device=self.device)

            NLL = NLL_final if self.mean_only else \
                    (NLL_final + NLL_steps / self.batch_size) #/ self.steps

        return {"z":z, "kappa":mu_k}, NLL, KL, AUX, lamb # TODO: multiple layers of latent space? i.e. intermediate kappa/structures. Now: saving structures but not making plots.
    
    

        