import os

import torch
import yaml

from models.Decoder import MLPDecoder, UnifiedMLPDecoder
from models.GNNEncoder import GNNE, TensorProductEncoder , SimpleInternalEncoder
from models.VAE import VAE, AE
from models.LaMD import LaMD
from models.Propagator import LinearPropagator, RNNPropagator
import openmm as mm
import openmm.unit as unit
from simtk.openmm.app import ForceField, Topology
from simtk.openmm import LangevinIntegrator, Context
import concurrent.futures
import mdtraj as md

def get_model(args, coordinate_transform=None):
    """Assemble encoder, decoder and propagator into the model selected by ``args``.

    Args:
        args: Namespace-like configuration object. ``graph_representation``
            picks the encoder ("internal", "simple_internal" or "extrinsic"),
            ``propagator_type`` picks the propagator ("linear" or "lstm"), and
            ``no_propagator`` / ``no_vae`` pick the wrapping model.
        coordinate_transform: Passed to ``SimpleInternalEncoder``; must not be
            ``None`` when ``graph_representation == "simple_internal"``.

    Returns:
        ``AE`` or ``VAE`` when ``args.no_propagator`` is truthy, otherwise
        ``LaMD``.

    Raises:
        ValueError: "simple_internal" was requested without a
            ``coordinate_transform``.
        NotImplementedError: Unknown graph representation or propagator type.
    """
    # --- encoder ---------------------------------------------------------
    representation = args.graph_representation
    if representation == "internal":
        encoder = GNNE(
            args.feature_size,
            args.encoder_embedding_size,
            args.edge_embedding_size,
            args.latent_embedding_size,
        )
    elif representation == "simple_internal":
        if coordinate_transform is None:
            raise ValueError("coordinate_transform is None, but is required for simple_internal encoder")
        # NOTE(review): the literal 60 is forwarded as-is; its meaning is
        # defined by SimpleInternalEncoder, which is not visible here.
        encoder = SimpleInternalEncoder(coordinate_transform, 60, args.latent_embedding_size)
    elif representation == "extrinsic":
        encoder = TensorProductEncoder(
            args.ns,
            args.nv,
            args.sh_lmax,
            args.num_conv_layers,
            args.batch_norm,
            args.dropout,
            args.latent_embedding_size,
            args.in_edge_features,
            args.use_set2set_pooling,
        )
    else:
        raise NotImplementedError("Graph representation not implemented")

    # --- decoder ---------------------------------------------------------
    decoder = MLPDecoder(args.latent_embedding_size)

    # --- propagator ------------------------------------------------------
    # The propagator is always built (and its type validated), even when
    # args.no_propagator later selects a model that does not use it.
    if args.propagator_type == "linear":
        propagator = LinearPropagator(args.latent_embedding_size, args.sequence_length)
    elif args.propagator_type == "lstm":
        propagator = RNNPropagator(
            args.latent_embedding_size,
            args.sequence_length,
            args.propagator_hidden_size,
            args.propagator_num_layers,
            args.propagator_dropout,
        )
    else:
        raise NotImplementedError(f"Propagator {args.propagator_type} not implemented")

    # --- final model -----------------------------------------------------
    if not args.no_propagator:
        # (variational) autoencoder combined with a propagator
        return LaMD(encoder, decoder, propagator, no_vae=args.no_vae)

    # autoencoder only, no propagator
    autoencoder_cls = AE if args.no_vae else VAE
    return autoencoder_cls(encoder, decoder)


class ExponentialMovingAverage:
    """Keep an exponential moving average (EMA) of a set of parameters.

    Adapted from
    https://github.com/yang-song/score_sde_pytorch/blob/main/models/ema.py

    Only parameters with ``requires_grad`` set are averaged. ``store`` and
    ``restore`` snapshot every parameter they are handed.
    """

    def __init__(self, parameters, decay, use_num_updates=True):
        """
        Args:
          parameters: Iterable of `torch.nn.Parameter`, typically
            `model.parameters()`.
          decay: Exponential decay factor; must lie in [0, 1].
          use_num_updates: If true, count updates and use the count to cap
            the effective decay during the first updates.

        Raises:
          ValueError: If `decay` is below 0 or above 1.
        """
        if decay > 1.0 or decay < 0.0:
            raise ValueError('Decay must be between 0 and 1')
        self.decay = decay
        if use_num_updates:
            self.num_updates = 0
        else:
            self.num_updates = None
        # Detached copies of the trainable parameters hold the running average.
        shadows = []
        for param in parameters:
            if param.requires_grad:
                shadows.append(param.clone().detach())
        self.shadow_params = shadows
        self.collected_params = []

    def update(self, parameters):
        """
        Fold the current parameter values into the moving average.

        Intended to be called after every parameter update, e.g. right after
        `optimizer.step()`.

        Args:
          parameters: Iterable of `torch.nn.Parameter`; normally the same
            parameters this object was constructed with.
        """
        effective_decay = self.decay
        if self.num_updates is not None:
            self.num_updates += 1
            # Warm-up: cap the decay by (1 + n) / (10 + n) while n is small.
            warmup_cap = (1 + self.num_updates) / (10 + self.num_updates)
            effective_decay = min(effective_decay, warmup_cap)
        blend = 1.0 - effective_decay
        with torch.no_grad():
            trainable = [candidate for candidate in parameters if candidate.requires_grad]
            for shadow, live in zip(self.shadow_params, trainable):
                # shadow <- shadow - (1 - decay) * (shadow - live)
                shadow.sub_(blend * (shadow - live))

    def copy_to(self, parameters):
        """
        Overwrite the given parameters with the stored moving averages.

        Args:
          parameters: Iterable of `torch.nn.Parameter` that receive the
            averaged values; entries without `requires_grad` are skipped.
        """
        trainable = [candidate for candidate in parameters if candidate.requires_grad]
        for shadow, target in zip(self.shadow_params, trainable):
            target.data.copy_(shadow.data)

    def store(self, parameters):
        """
        Snapshot the given parameters so they can be put back with `restore`.

        Args:
          parameters: Iterable of `torch.nn.Parameter` to be stored
            temporarily.
        """
        snapshot = []
        for param in parameters:
            snapshot.append(param.clone())
        self.collected_params = snapshot

    def restore(self, parameters):
        """
        Write the values saved by `store` back into the given parameters.

        Typical use: call `store`, then `copy_to`, validate or save the model
        with the averaged weights, and finally call `restore` so optimisation
        continues from the original weights.

        Args:
          parameters: Iterable of `torch.nn.Parameter` to be overwritten with
            the stored values.
        """
        for saved, target in zip(self.collected_params, parameters):
            target.data.copy_(saved.data)

    def state_dict(self):
        """Return decay, update counter and shadow parameters as a dict."""
        return {
            'decay': self.decay,
            'num_updates': self.num_updates,
            'shadow_params': self.shadow_params,
        }

    def load_state_dict(self, state_dict, device):
        """Load a dict produced by `state_dict`, moving the shadows to `device`."""
        self.decay = state_dict['decay']
        self.num_updates = state_dict['num_updates']
        moved = []
        for tensor in state_dict['shadow_params']:
            moved.append(tensor.to(device))
        self.shadow_params = moved


def save_yaml_file(path, content):
    """Serialize ``content`` to YAML and write it to ``path``.

    Missing parent directories of ``path`` are created first.

    Args:
        path: Destination file path; must be a ``str``.
        content: Object handed to ``yaml.dump``.

    Raises:
        AssertionError: If ``path`` is not a string.
    """
    assert isinstance(path, str), f'path must be a string, got {path} which is a {type(path)}'
    # Serialize before touching the file system so that a failing dump does
    # not leave an empty or truncated file behind.
    serialized = yaml.dump(data=content)
    parent_dir = os.path.dirname(path)
    if parent_dir:
        # exist_ok avoids the check-then-create race of a separate
        # os.path.exists() test and does not depend on the separator being '/'.
        os.makedirs(parent_dir, exist_ok=True)
    with open(path, 'w') as f:
        f.write(serialized)

def get_optimizer_and_scheduler(args, model, scheduler_mode='min'):
    """Create the optimizer and, if requested, a learning-rate scheduler.

    Args:
        args: Configuration object providing ``adamw``, ``lr``, ``w_decay``,
            ``scheduler`` and (for the plateau scheduler) ``scheduler_patience``.
        model: Module whose parameters with ``requires_grad`` are optimized.
        scheduler_mode: ``mode`` forwarded to ``ReduceLROnPlateau``.

    Returns:
        Tuple ``(optimizer, scheduler)``; ``scheduler`` is ``None`` unless
        ``args.scheduler == 'plateau'``.
    """
    # AdamW only when the flag holds the literal string 'adamw', else Adam.
    if args.adamw == 'adamw':
        optimizer_cls = torch.optim.AdamW
    else:
        optimizer_cls = torch.optim.Adam

    trainable = (param for param in model.parameters() if param.requires_grad)
    optimizer = optimizer_cls(trainable, lr=args.lr, weight_decay=args.w_decay)

    if args.scheduler != 'plateau':
        print('No scheduler')
        return optimizer, None

    # Multiply the LR by 0.7 on a plateau, never going below 1% of the initial LR.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode=scheduler_mode,
        factor=0.7,
        patience=args.scheduler_patience,
        min_lr=args.lr / 100,
    )
    return optimizer, scheduler


def compute_potential_energy(pos, _compute_potential_energy_single):
    """Evaluate the potential energy of every conformation in a batch.

    Each entry along the first axis of ``pos`` is handed to
    ``_compute_potential_energy_single`` on a thread pool; results are
    collected in input order.

    Args:
        pos: Batch of conformations, indexable along its first axis.
        _compute_potential_energy_single: Callable mapping one conformation
            to a scalar energy.

    Returns:
        ``torch.Tensor`` holding one energy per conformation.
    """
    with concurrent.futures.ThreadPoolExecutor() as pool:
        # Executor.map yields results in submission order.
        energies = list(pool.map(_compute_potential_energy_single, pos))

    return torch.tensor(energies)



def get_single_conformation_energy_fn(system, temperature):
    """Build a function that returns the potential energy of one conformation.

    A single OpenMM ``Context`` is created for ``system`` and captured by the
    returned closure. The closure is meant to be mapped over a batch of
    conformations from several threads, so access to the shared context is
    serialized with a lock: without it one thread could overwrite the
    positions between another thread's ``setPositions`` and ``getState``
    calls, and the energy of the wrong conformation would be returned.

    Args:
        system: OpenMM ``System`` to evaluate energies for.
        temperature: Temperature in kelvin (plain number) for the integrator
            that the context requires.

    Returns:
        Callable ``f(pos) -> float``. ``pos`` is a tensor of atom positions
        (assumed to be in angstrom, since it is scaled by 0.1 to nanometres);
        the result is the potential energy in kJ/mol.
    """
    import threading  # function-scope import keeps this block self-contained

    temperature = temperature * unit.kelvin
    friction = 1.0 / unit.picoseconds
    integrator = LangevinIntegrator(temperature, friction, 1.0 * unit.femtosecond)
    context = Context(system, integrator)
    context_lock = threading.Lock()

    def _compute_potential_energy_single(pos):
        # transform angstrom to nm; detach/cpu so CUDA or grad-tracking
        # tensors are accepted as well
        positions_nm = pos.detach().cpu().numpy() * 0.1

        # setPositions + getState must be atomic with respect to other
        # threads sharing this context.
        with context_lock:
            context.setPositions(positions_nm)
            energy = context.getState(getEnergy=True).getPotentialEnergy()

        return energy.value_in_unit(unit.kilojoule_per_mole)

    # Evaluate once up front with dummy coordinates sized to the system
    # (previously hard-coded to 22 atoms); presumably a warm-up so that setup
    # errors surface here rather than inside a worker thread.
    _compute_potential_energy_single(torch.zeros(system.getNumParticles(), 3))
    return _compute_potential_energy_single


def kabsch_alignment(A, B, allow_reflection=True):
    """Rigidly align point cloud ``A`` onto ``B`` (known correspondences).

    Adapted from https://gist.github.com/bougui505/e392a371f5bab095a3673ea6f4976cc8
    See: https://en.wikipedia.org/wiki/Kabsch_algorithm

    2-D or 3-D registration: both clouds are centred, the optimal orthogonal
    transform is obtained from the SVD of their covariance matrix, and the
    translation carries the result back to the frame of ``B``. The aligned
    cloud is ``(R.mm(A.T)).T + t``.

    Args:
        A: Torch tensor of shape (N, D) -- point cloud to align (source).
        B: Torch tensor of shape (N, D) -- reference point cloud (target).
        allow_reflection: If True (default, the historical behaviour) the
            returned ``R`` is the best orthogonal matrix and may be an
            improper rotation (determinant -1) when ``B`` is a mirrored copy
            of ``A``. If False, the sign of the singular vector belonging to
            the smallest singular value is flipped when necessary so that
            ``R`` is always a proper rotation (determinant +1).

    Returns:
        R: (D, D) optimal rotation (or reflection, see ``allow_reflection``).
        t: (D,) optimal translation.

    Example (rotation + translation, then an additional reflection):
        >>> import math
        >>> A = torch.tensor([[1., 1.], [2., 2.], [1.5, 3.]])
        >>> c, s = math.cos(1.0), math.sin(1.0)
        >>> R0 = torch.tensor([[c, -s], [s, c]])
        >>> B = (R0.mm(A.T)).T + torch.tensor([3., 3.])
        >>> R, t = kabsch_alignment(A, B)
        >>> torch.allclose((R.mm(A.T)).T + t, B, atol=1e-4)
        True
        >>> B = B * torch.tensor([-1., 1.])
        >>> R, t = kabsch_alignment(A, B)
        >>> torch.allclose((R.mm(A.T)).T + t, B, atol=1e-4)
        True
    """
    a_mean = A.mean(dim=0)
    b_mean = B.mean(dim=0)
    A_c = A - a_mean
    B_c = B - b_mean
    # Covariance matrix
    H = A_c.T.mm(B_c)
    # torch.linalg.svd returns V^H; the singular values are not needed.
    U, _, Vh = torch.linalg.svd(H, full_matrices=False)
    V = Vh.T
    if not allow_reflection and torch.det(V.mm(U.T)) < 0:
        # Flip the last column of V (smallest singular value) to turn the
        # improper rotation into the closest proper one.
        flip = torch.ones(V.shape[1], dtype=V.dtype, device=V.device)
        flip[-1] = -1.0
        V = V * flip
    # Rotation matrix
    R = V.mm(U.T)
    # Translation vector: maps the rotated source centroid onto the target centroid
    t = b_mean - R.mv(a_mean)
    return R, t


def save_prediction_pdb(model, data, coordinate_transform, args, run_dir, reverse_scaling):
    """Write ground-truth and predicted conformations of one batch as PDB files.

    The model is run on ``data``, its first output is un-scaled with
    ``reverse_scaling`` and converted to Cartesian positions through
    ``coordinate_transform.get_extrinsic_representation``. Both trajectories
    are superposed on the ground-truth trajectory and saved to ``run_dir`` as
    ``ground_truth_traj.pdb`` and ``pred_traj.pdb``.

    Args:
        model: Callable model; only the first element of its output is used.
        data: Batch object providing ``to``, ``pos`` and ``batch``.
        coordinate_transform: Object providing ``get_extrinsic_representation``.
        args: Configuration providing ``device`` and ``testsystem.topology``.
        run_dir: Directory the PDB files are written to.
        reverse_scaling: Callable applied to the unpacked first model output.
    """
    def _reference_frames():
        # One frame per graph in the batch: (num_graphs, atoms, 3).
        return data.pos.view(data.batch[-1] + 1, -1, 3)

    data = data.to(args.device)
    prediction, *_ = model(data)
    prediction = reverse_scaling(*prediction)
    pred_pos = coordinate_transform.get_extrinsic_representation(
        _reference_frames(), prediction, data.batch
    )

    # NOTE(review): mdtraj interprets coordinates as nanometres -- confirm the
    # units of data.pos / pred_pos before relying on the written PDB scale.
    topology = md.Topology.from_openmm(args.testsystem.topology)
    true_traj = md.Trajectory(_reference_frames().detach().cpu().numpy(), topology)
    pred_traj = md.Trajectory(pred_pos.detach().cpu().numpy(), topology)

    # rmsd align to first frame of true traj
    true_traj.superpose(true_traj)
    pred_traj.superpose(true_traj)

    # save pdb
    # NOTE(review): the superposition is repeated here exactly as in the
    # original code; it looks redundant after the alignment above.
    aligned_true = true_traj.superpose(true_traj)
    aligned_true.save_pdb(os.path.join(run_dir, 'ground_truth_traj.pdb'))
    aligned_pred = pred_traj.superpose(true_traj)
    aligned_pred.save_pdb(os.path.join(run_dir, 'pred_traj.pdb'))



