import logging

import torch
from tqdm import tqdm
from utils.utils import compute_potential_energy , kabsch_alignment
from openmm import unit 
import numpy as np 

# Module-wide logging: configure the root logger at INFO at import time and use a
# dedicated "Training" logger for the warnings/info messages emitted below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name='Training')

def loss_function(internal_coords_pred, internal_coords_target, mu, logvar,propagator_pred,propagator_target, bond_width_weight=1.0, bond_angles_weight=1.0, torsion_angles_weight=1.0, kl_weight=1.0, propagator_weight=1.0, rescale_transform=None):
    """
    Weighted VAE (+ optional latent propagator) loss on internal coordinates.

    Args:
        internal_coords_pred: predicted (bond_width, bond_angles, torsion_angles).
        internal_coords_target: sequence whose first three entries are the target
            (bond_width, bond_angles, torsion_angles); any further entries are ignored here.
        mu, logvar: latent Gaussian parameters forwarded to ``kl_loss``; pass ``mu=None``
            when no Gaussian latent space is enforced (--no_vae).
            NOTE(review): ``kl_loss`` interprets its second argument as a log *standard
            deviation*, not a log variance -- confirm what the encoder actually emits.
        propagator_pred, propagator_target: latent propagator prediction and target; pass
            ``propagator_pred=None`` when no propagator is trained.
        bond_width_weight, bond_angles_weight, torsion_angles_weight, kl_weight,
        propagator_weight: weights of the individual loss terms.
        rescale_transform: callable applied to (bond_width, bond_angles, torsion_angles) of
            both prediction and target before comparing them; ``None`` means identity.

    Returns:
        ``(loss, (bond_width_loss, bond_angles_err, torsion_angles_err, kl_div, propagator_loss))``.
        The metrics are detached. The angle errors are ``1 - mean(cos(pred - target))``,
        which is 0 for a perfect match. Disabled terms are reported as 0-d zeros.
    """
    if rescale_transform is not None:
        rescale = rescale_transform
    else:
        # No feature scaling configured: compare the raw features directly.
        def rescale(*features):
            return features

    bond_width, bond_angles, torsion_angles = rescale(*internal_coords_target[0:3])
    bond_width_pred, bond_angles_pred, torsion_angles_pred = rescale(*internal_coords_pred)

    # bond width loss
    bond_width_loss = mse_loss(bond_width_pred, bond_width)

    # cosine_loss returns a similarity (mean cos of the angle error, 1 = perfect),
    # which is why the angle terms are subtracted in the total below.
    bond_angles_loss = cosine_loss(bond_angles_pred, bond_angles)
    torsion_angles_loss = cosine_loss(torsion_angles_pred, torsion_angles)

    # 0-d placeholder so disabled terms have the same shape as the 0-d reductions
    # returned by kl_loss / latent_loss (a shape-(1,) zero would make `loss` 1-D).
    zero = bond_width_loss.new_zeros(())

    if mu is not None:
        # KL divergence for latent space
        kl_div = kl_loss(mu, logvar)
    else:
        # --no_vae flag was passed, so no gaussian latent space is enforced
        kl_div = zero

    if propagator_pred is not None:
        # Propagator loss
        propagator_loss = latent_loss(propagator_pred, propagator_target)
    else:
        propagator_loss = zero

    # +2.0 offsets the two negated cosine similarities (each <= 1), so the total is
    # non-negative as long as all weights are non-negative and both angle weights are <= 1.
    loss = (bond_width_weight * bond_width_loss
            - bond_angles_weight * bond_angles_loss
            - torsion_angles_weight * torsion_angles_loss
            + kl_weight * kl_div
            + 2.0
            + propagator_weight * propagator_loss)

    return loss, (bond_width_loss.detach(), 1 - bond_angles_loss.detach(), 1 - torsion_angles_loss.detach(), kl_div.detach(), propagator_loss.detach())

def kl_loss(mu=None, logstd=None):
    """
    Closed-form KL divergence KL(N(mu, exp(logstd)^2) || N(0, 1)), averaged over the batch.

    The second argument is treated as a log standard deviation (``2 * logstd`` and
    ``exp(logstd) ** 2`` appear in the formula). It is clamped to at most 10, and the
    result is clamped to at most 1000 to limit numeric blow-ups.
    Assumes 2-D inputs with the latent dimension on dim 1 -- TODO confirm for other ranks.
    """
    log_sigma = torch.clamp(logstd, max=10)
    sigma_sq = log_sigma.exp() ** 2
    per_sample = torch.sum(1 + 2 * log_sigma - mu ** 2 - sigma_sq, dim=1)
    return torch.clamp(-0.5 * per_sample.mean(), max=1000)


def mse_loss(pred, target):
    """
    Mean squared error between ``pred`` and ``target``, averaged over all elements.

    Serves as the reconstruction loss for the bond-length features.
    """
    squared_error = torch.nn.functional.mse_loss(pred, target, reduction='mean')
    return squared_error

def von_Mises_loss(a, b, a_sin=None, b_sin=None):
    """
    Mean cosine of the angular difference between ``a`` and ``b``.

    Evaluates cos(a - b) = cos(a)cos(b) + sin(a)sin(b) element-wise and averages.
    The result is a similarity in [-1, 1], where 1 means identical angles.

    :param a: first angle tensor (radians)
    :param b: second angle tensor (radians)
    :param a_sin: optional precomputed sin(a); computed from ``a`` when not a tensor
    :param b_sin: optional precomputed sin(b); computed from ``b`` when not a tensor
    :return: 0-d tensor holding mean(cos(a - b))
    """
    # Use signed sines. The former sqrt(1 - cos^2) fallback equals |sin| and so gave
    # the wrong result whenever the two angles had sines of opposite sign.
    sin_a = a_sin if torch.is_tensor(a_sin) else torch.sin(a)
    sin_b = b_sin if torch.is_tensor(b_sin) else torch.sin(b)
    out = torch.cos(a) * torch.cos(b) + sin_a * sin_b
    return out.mean()

def latent_loss(pred, target):
    """
    Mean squared error between predicted and target latent states of the propagator.

    Expected shapes (per the original contract): (batch_size, pred_seq_len, latent_dim)
    for both arguments; the error is averaged over every element.
    """
    mse = torch.nn.functional.mse_loss
    return mse(pred, target, reduction='mean')

def periodic_angle_loss(angles):
    """
    Soft penalty for angles outside [-pi, pi], which would break the invertibility of
    internal-coordinate transforms.

    Each angle contributes (a - pi)^2 if a > pi, (a + pi)^2 if a < -pi, and 0 otherwise.
    Penalties are summed over the last dimension and averaged over the remaining ones.
    The result is then clamped to at most 1, so larger violations produce no gradient.
    """
    pi = torch.pi
    no_penalty = torch.zeros(1, 1, dtype=angles.dtype, device=angles.device)
    above = torch.where(angles > pi, angles - pi, no_penalty)
    below = torch.where(angles < -pi, angles + pi, no_penalty)
    per_sample = (above ** 2).sum(dim=-1) + (below ** 2).sum(dim=-1)
    return per_sample.mean().clamp(max=1)

def periodic_angles_loss_0_1(scaled_angles):
    """
    Soft penalty for rescaled angles outside [0, 1], which would break the invertibility
    of internal-coordinate transforms.

    Each value contributes (a - 1)^2 if a > 1, a^2 if a < 0, and 0 otherwise.
    Penalties are summed over the last dimension and then averaged over the remaining
    ones (no clamping).
    """
    fill = torch.zeros(1, 1, dtype=scaled_angles.dtype, device=scaled_angles.device)
    overshoot = torch.where(scaled_angles > 1, scaled_angles - 1, fill)
    undershoot = torch.where(scaled_angles < 0, scaled_angles, fill)
    return ((overshoot ** 2).sum(dim=-1) + (undershoot ** 2).sum(dim=-1)).mean()




def cosine_loss(pred, target):
    """
    Mean cosine of the angular error ``pred - target`` (used for bond and torsion angles).

    Despite the name, this is a similarity in [-1, 1] where 1 means a perfect match.
    ``loss_function`` in this module therefore subtracts it from the total loss.
    """
    angle_error = pred - target
    return angle_error.cos().mean()


class AverageMeter():
    """
    Accumulate named metrics over batches and report their running means.

    With ``intervals == 1``, every metric keeps one running sum and all metrics share
    one sample counter. With ``intervals > 1``, each metric keeps a sum and a count per
    interval bucket. These are filled via ``index_add_`` using the bucket indices
    passed to :meth:`add`.
    """

    def __init__(self, types, unpooled_metrics=False, intervals=1):
        """
        :param types: metric names; :meth:`add` expects values in this order.
        :param unpooled_metrics: in single-interval mode, sum each added value before
            accumulating it.
        :param intervals: number of interval buckets.
        """
        self.types = types
        self.intervals = intervals
        # Shared scalar counter, or a (len(types), intervals) grid of per-bucket counts.
        self.count = 0 if intervals == 1 else torch.zeros(len(types), intervals)
        self.acc = {t: torch.zeros(intervals) for t in types}
        self.unpooled_metrics = unpooled_metrics

    def add(self, vals, interval_idx=None):
        """
        Accumulate one batch of metric values.

        :param vals: tensors ordered like ``self.types``.
        :param interval_idx: (``intervals > 1`` only) one index tensor per metric, giving
            the bucket of each element of the corresponding value.
        """
        if self.intervals == 1:
            # A 0-d first value counts as one sample; otherwise its length is the count.
            self.count += 1 if vals[0].dim() == 0 else len(vals[0])
            for type_idx, v in enumerate(vals):
                self.acc[self.types[type_idx]] += v.sum() if self.unpooled_metrics else v
        else:
            for type_idx, v in enumerate(vals):
                self.count[type_idx].index_add_(0, interval_idx[type_idx], torch.ones(len(v)))
                # Values that are all (approximately) zero are not scattered into the sums.
                if not torch.allclose(v, torch.tensor(0.0)):
                    self.acc[self.types[type_idx]].index_add_(0, interval_idx[type_idx], v)

    def summary(self):
        """
        Return the mean of every metric.

        Single-interval mode returns ``{name: mean}``. If nothing was added (e.g. every
        batch was skipped), every metric is reported as ``float('nan')``.
        Interval mode returns ``{'int<i>_<name>': mean}``. Empty buckets yield NaN from
        the 0/0 tensor division.
        """
        if self.intervals == 1:
            if self.count == 0:
                # Avoid ZeroDivisionError; mirror the NaN that interval mode produces.
                return {k: float('nan') for k in self.acc}
            return {k: v.item() / self.count for k, v in self.acc.items()}

        out = {}
        for i in range(self.intervals):
            for type_idx, k in enumerate(self.types):
                out['int' + str(i) + '_' + k] = (self.acc[k][i] / self.count[type_idx][i]).item()
        return out
        

# Metric names, in the exact order in which _record_batch_metrics passes values to AverageMeter.add.
_BASE_METRICS = ["loss", "bond_width_loss", "bond_angles_loss", "torsion_angles_loss", "kl_div", "propagator_loss"]

# (substring of the RuntimeError message, warning to log) for per-batch failures that are skipped.
_SKIPPABLE_BATCH_ERRORS = (
    ('out of memory', 'Ran out of memory, skipping batch'),
    ('Input mismatch', 'Weird torch_cluster error, skipping batch'),
)


def _make_meter(args):
    """Build the per-epoch AverageMeter; adds an "rmsd" column when args.log_rmsd is set."""
    types = list(_BASE_METRICS)
    if args.log_rmsd:
        types.append("rmsd")
    return AverageMeter(types)


def _forward(model, data, args):
    """Run the model and return (internal_coords_pred, mu, logvar, propagator_pred, latent_target).

    With args.no_propagator the model returns only the VAE outputs, so the propagator
    entries are None.
    """
    if args.no_propagator:
        internal_coords_pred, mu, logvar = model(data)
        return internal_coords_pred, mu, logvar, None, None
    latent_target, internal_coords_pred, propagator_pred, mu, logvar = model(data)
    return internal_coords_pred, mu, logvar, propagator_pred, latent_target


def _compute_rmsds(data, internal_coords_pred, int_coords_target, coordinate_transform, inverse_feature_scaling, args):
    """
    Return the per-frame RMSD between predicted and reference Cartesian frames.

    The predicted internal coordinates are decoded to Cartesian frames, and each frame is
    Kabsch-aligned onto its reference (int_coords_target[3]) before the RMSD is taken.
    This is best effort: if decoding or alignment raises, a warning is logged and
    tensor([0.0]) is returned, which biases the epoch-average RMSD towards 0.
    When args.log_energy is set, the mean potential energy of the *predicted* frames is
    also logged; an energy failure does not discard the RMSD.
    Does not modify int_coords_target.
    """
    try:
        frames = int_coords_target[3]
        if not args.no_propagator:
            # Flatten (batch_size, seq_len+1, ...) to (batch_size*(seq_len+1), ...) so every
            # frame is decoded and aligned on its own; otherwise the alignment fails.
            # Frames are documented as (batch_size, seq_len+1, n_atoms, 3).
            internal_coords_pred = [p.reshape(-1, p.shape[-1]) for p in internal_coords_pred]
            frames = frames.reshape(-1, *frames.shape[-2:])

        # Undo the feature scaling to get the representation the coordinate transform expects.
        bond_pred, angle_pred, torsion_pred = inverse_feature_scaling(
            internal_coords_pred[0], internal_coords_pred[1], internal_coords_pred[2])

        # Predicted Cartesian coordinates (not yet aligned); reference frames are passed in.
        pred_pos = coordinate_transform.get_extrinsic_representation(
            frames, (bond_pred, angle_pred, torsion_pred), data.batch)

        # Kabsch-align every predicted frame onto its reference frame.
        aligned = []
        for pred_frame, ref_frame in zip(pred_pos, frames):
            rotation, translation = kabsch_alignment(pred_frame, ref_frame)
            aligned.append(rotation.mm(pred_frame.T).T + translation)
        pred_pos = torch.stack(aligned)

        rmsds = torch.sqrt(((pred_pos - frames) ** 2).sum(dim=2).mean(dim=1))
    except Exception as err:  # best effort, e.g. linear-algebra errors in the alignment
        logger.warning("RMSD alignment failed: %s", err)
        return torch.tensor([0.0], device=data.x.device)

    if args.log_energy:
        try:
            pred_pot_energies = compute_potential_energy(pred_pos, args._compute_single_energy).cpu()
            logger.info("Mean predicted energy: {:.2f} kcal/mol".format(pred_pot_energies.mean().item()))
        except Exception as err:  # energy logging is diagnostic only
            logger.warning("Potential energy computation failed: %s", err)
    return rmsds


def _record_batch_metrics(meter, loss, metrics, rmsds=None):
    """Add one batch to meter, ordered exactly like _BASE_METRICS (plus "rmsd" if given)."""
    bond_width_loss, bond_angles_loss, torsion_angles_loss, kl_div, propagator_loss = metrics
    vals = [loss, bond_width_loss, bond_angles_loss, torsion_angles_loss, kl_div, propagator_loss]
    if rmsds is not None:
        vals.append(rmsds.mean())
    meter.add([v.cpu().detach() for v in vals])


def _skip_batch_after_error(err, model):
    """Return True if err is a known per-batch failure, after freeing grads and cached CUDA memory."""
    message = str(err)
    for needle, warning in _SKIPPABLE_BATCH_ERRORS:
        if needle in message:
            logger.warning(warning)
            for p in model.parameters():
                p.grad = None  # free some memory
            torch.cuda.empty_cache()
            return True
    return False


def train_epoch(model, loader, optimizer, loss_fn,device , coordinate_transform, inverse_feature_scaling, args):
    """
    Run one optimisation epoch over loader and return the epoch mean of every metric.

    Each loader item is (data, *target), where target holds bond widths, bond angles,
    torsion angles, reference frames and potential energies (per the original comments).
    Batches with a NaN loss are neither stepped nor logged. RuntimeErrors caused by CUDA
    OOM or the torch_cluster "Input mismatch" error skip the batch; any other
    RuntimeError propagates.

    :return: dict of metric name -> epoch mean (see AverageMeter.summary).
    """
    model.train()
    meter = _make_meter(args)

    for data, *target in tqdm(loader, total=len(loader)):
        data = data.to(device)
        int_coords_target = [t.to(device=device) for t in target]

        optimizer.zero_grad()
        try:
            internal_coords_pred, mu, logvar, propagator_pred, latent_target = _forward(model, data, args)
            loss, metrics = loss_fn(internal_coords_pred, int_coords_target, mu, logvar, propagator_pred, latent_target)

            if loss.isnan():
                logger.warning("Loss is nan")
                continue

            loss.backward()
            optimizer.step()

            rmsds = None
            if args.log_rmsd:
                with torch.no_grad():
                    rmsds = _compute_rmsds(data, internal_coords_pred, int_coords_target,
                                           coordinate_transform, inverse_feature_scaling, args)
            _record_batch_metrics(meter, loss, metrics, rmsds)
        except RuntimeError as err:
            if not _skip_batch_after_error(err, model):
                raise

    return meter.summary()


def test_epoch(model, loader, loss_fn , device , coordinate_transform, inverse_feature_scaling, args):
    """
    Evaluate model on loader without gradients and return the epoch mean of every metric.

    Unlike train_epoch, NaN losses are recorded rather than skipped. The same known
    RuntimeErrors (CUDA OOM, torch_cluster "Input mismatch") skip the batch; any other
    RuntimeError propagates.

    :return: dict of metric name -> epoch mean (see AverageMeter.summary).
    """
    model.eval()
    meter = _make_meter(args)

    for data, *target in tqdm(loader, total=len(loader)):
        data = data.to(device)
        int_coords_target = [t.to(device=device) for t in target]

        try:
            with torch.no_grad():
                internal_coords_pred, mu, logvar, propagator_pred, latent_target = _forward(model, data, args)
                loss, metrics = loss_fn(internal_coords_pred, int_coords_target, mu, logvar, propagator_pred, latent_target)

                rmsds = None
                if args.log_rmsd:
                    rmsds = _compute_rmsds(data, internal_coords_pred, int_coords_target,
                                           coordinate_transform, inverse_feature_scaling, args)
            _record_batch_metrics(meter, loss, metrics, rmsds)
        except RuntimeError as err:
            if not _skip_batch_after_error(err, model):
                raise

    return meter.summary()
    
