from collections import defaultdict
import torch
from data import so3_utils
from data import utils as du
from scipy.spatial.transform import Rotation
from data import all_atom
import copy
from scipy.optimize import linear_sum_assignment
from torch import autograd
import numpy as np
from experiments import potentials
import functools as fn
from motif_scaffolding import twisting


def _centered_gaussian(num_batch, num_res, device):
    noise = torch.randn(num_batch, num_res, 3, device=device)
    return noise - torch.mean(noise, dim=-2, keepdims=True)

def _uniform_so3(num_batch, num_res, device):
    return torch.tensor(
        Rotation.random(num_batch*num_res).as_matrix(),
        device=device,
        dtype=torch.float32,
    ).reshape(num_batch, num_res, 3, 3)

def _trans_diffuse_mask(trans_t, trans_1, diffuse_mask):
    return trans_t * diffuse_mask[..., None] + trans_1 * (1 - diffuse_mask[..., None])

def _rots_diffuse_mask(rotmats_t, rotmats_1, diffuse_mask):
    return (
        rotmats_t * diffuse_mask[..., None, None]
        + rotmats_1 * (1 - diffuse_mask[..., None, None])
    )


class Interpolant:

    def __init__(self, cfg):
        """Store the config sections and build the optional translation potential.

        Args:
            cfg: Config object exposing `rots`, `trans` and `sampling`
                sections. If `cfg.trans` contains a `potential` entry it
                must be `'rog'` or `None`.

        Raises:
            ValueError: If `cfg.trans.potential` names an unknown potential.
        """
        self._cfg = cfg
        self._rots_cfg = cfg.rots
        self._trans_cfg = cfg.trans
        self._sample_cfg = cfg.sampling
        # Built lazily by the `igso3` property.
        self._igso3 = None
        # Overwritten by `set_device`. Defining it here means methods used
        # before `set_device` fall back to torch's default device instead of
        # failing with AttributeError.
        self._device = None

        self._trans_potential = None
        if 'potential' not in self._trans_cfg:
            return
        potential_name = self._trans_cfg.potential
        if potential_name == 'rog':
            self._trans_potential = fn.partial(
                potentials.rog,
                weight=self._trans_cfg.rog.weight,
                cutoff=self._trans_cfg.rog.cutoff,
            )
        elif potential_name is not None:
            raise ValueError(
                f'Unknown trans potential {potential_name}')

    @property
    def igso3(self):
        """Lazily constructed `so3_utils.SampleIGSO3` sampler.

        Built on first access from 1000 sigma values evenly spaced in
        [0.1, 1.5] (with `cache_dir='.cache'`) and cached on the instance.
        """
        sampler = self._igso3
        if sampler is None:
            sigmas = torch.linspace(0.1, 1.5, 1000)
            sampler = so3_utils.SampleIGSO3(1000, sigmas, cache_dir='.cache')
            self._igso3 = sampler
        return sampler

    def set_device(self, device):
        """Record the device on which this interpolant creates its tensors."""
        self._device = device

    def sample_t(self, num_batch):
        t = torch.rand(num_batch, device=self._device)
        return t * (1 - 2*self._cfg.min_t) + self._cfg.min_t

    def _corrupt_trans(self, trans_1, t, res_mask, diffuse_mask):
        """Interpolate ground-truth translations towards a noise sample at time `t`.

        A centered Gaussian sample scaled by `du.NM_TO_ANG_SCALE` serves as
        the t=0 endpoint (optionally re-paired with the batch through
        `_batch_ot`). The two endpoints are then mixed with either the
        `'linear'` or the `'vpsde'` training schedule.

        Args:
            trans_1: Ground-truth translations; the callers' comments give
                the shape as [B, N, 3].
            t: Per-sample time; a trailing axis is appended for broadcasting
                (the callers pass a [B, 1] tensor).
            res_mask: [B, N] mask; masked-out residues are zeroed in the output.
            diffuse_mask: [B, N] mask; residues with mask 0 keep `trans_1`.

        Raises:
            ValueError: If `cfg.trans.train_schedule` is not recognised.
        """
        noise_nm = _centered_gaussian(*res_mask.shape, self._device)
        trans_0 = noise_nm * du.NM_TO_ANG_SCALE
        if self._trans_cfg.batch_ot:
            trans_0 = self._batch_ot(trans_0, trans_1, diffuse_mask)

        schedule = self._trans_cfg.train_schedule
        if schedule == 'linear':
            t_b = t[..., None]
            trans_t = (1 - t_b) * trans_0 + t_b * trans_1
        elif schedule == 'vpsde':
            bmin = self._trans_cfg.vpsde_bmin
            bmax = self._trans_cfg.vpsde_bmax
            # Noise time runs backwards relative to t: s = 1 - t.
            s = 1 - t
            alpha_t = torch.exp(-bmin * s - 0.5 * s**2 * (bmax - bmin))
            alpha_b = alpha_t[..., None]
            trans_t = (
                torch.sqrt(alpha_b) * trans_1
                + torch.sqrt(1 - alpha_b) * trans_0
            )
        else:
            raise ValueError(
                f'Unknown trans schedule {schedule}')

        trans_t = _trans_diffuse_mask(trans_t, trans_1, diffuse_mask)
        return trans_t * res_mask[..., None]
    
    def _batch_ot(self, trans_0, trans_1, res_mask):
        """Re-pair noise samples with ground-truth structures within a batch.

        Every noise sample is aligned against every ground-truth structure
        with `du.batch_align_structures`. A cost matrix of mean per-residue
        distances is built and solved as a linear assignment problem, and
        the aligned noise selected by that assignment is returned.

        Args:
            trans_0: Noise translations, indexed as [B, N, 3].
            trans_1: Ground-truth translations, indexed as [B, N, 3].
            res_mask: [B, N] mask forwarded to the alignment and used to
                normalise the cost (the caller in this file passes its
                `diffuse_mask`).

        Returns:
            Tensor of aligned noise translations, one entry per batch element.
        """
        num_batch, num_res = trans_0.shape[:2]
        # All (noise, ground-truth) index pairs in row-major order.
        noise_idx, gt_idx = torch.nonzero(
            torch.ones(num_batch, num_batch), as_tuple=True)
        pair_mask = res_mask[gt_idx]
        aligned_0, aligned_1, _ = du.batch_align_structures(
            trans_0[noise_idx], trans_1[gt_idx], mask=pair_mask
        )
        grid = (num_batch, num_batch, num_res)
        aligned_0 = aligned_0.reshape(*grid, 3)
        aligned_1 = aligned_1.reshape(*grid, 3)
        pair_mask = pair_mask.reshape(grid)

        # cost_matrix[i, j]: mean distance between noise i and ground truth j.
        per_res_dist = torch.linalg.norm(aligned_0 - aligned_1, dim=-1)
        cost_matrix = per_res_dist.sum(dim=-1) / pair_mask.sum(dim=-1)
        noise_perm, gt_perm = linear_sum_assignment(du.to_numpy(cost_matrix))
        # NOTE(review): the grid's first axis is the noise index, yet it is
        # indexed with gt_perm here (and the second with noise_perm). Confirm
        # this ordering is intended.
        return aligned_0[(tuple(gt_perm), tuple(noise_perm))]
    
    def _corrupt_rotmats(self, rotmats_1, t, res_mask, diffuse_mask):
        """Interpolate ground-truth rotations towards a noisy rotation at time `t`.

        The t=0 endpoint is `rotmats_1` composed with one IGSO3 draw per
        residue. The interpolation parameter follows the `'exp'` or
        `'linear'` rotation training schedule and is applied through
        `so3_utils.geodesic_t`. Residues outside `res_mask` get the identity
        rotation; residues with `diffuse_mask == 0` keep `rotmats_1`.

        Raises:
            ValueError: If `cfg.rots.train_schedule` is not recognised.
        """
        num_batch, num_res = res_mask.shape
        # The 1.5 is presumably the IGSO3 sigma (top of the sampler's grid).
        perturbations = self.igso3.sample(
            torch.tensor([1.5]),
            num_batch*num_res
        ).to(self._device)
        perturbations = perturbations.reshape(num_batch, num_res, 3, 3)
        rotmats_0 = torch.einsum(
            "...ij,...jk->...ik", rotmats_1, perturbations)

        schedule = self._rots_cfg.train_schedule
        if schedule == 'linear':
            so3_t = t
        elif schedule == 'exp':
            so3_t = 1 - torch.exp(-t*self._rots_cfg.exp_rate)
        else:
            raise ValueError(f'Invalid schedule: {schedule}')
        rotmats_t = so3_utils.geodesic_t(so3_t[..., None], rotmats_1, rotmats_0)

        # Masked-out residues are replaced by the identity rotation.
        keep = res_mask[..., None, None]
        identity = torch.eye(3, device=self._device)
        rotmats_t = rotmats_t * keep + identity[None, None] * (1 - keep)
        return _rots_diffuse_mask(rotmats_t, rotmats_1, diffuse_mask)

    def corrupt_batch(self, batch):
        noisy_batch = copy.deepcopy(batch)

        # [B, N, 3]
        trans_1 = batch['trans_1']  # Angstrom

        # [B, N, 3, 3]
        rotmats_1 = batch['rotmats_1']

        # [B, N]
        res_mask = batch['res_mask']
        diffuse_mask = batch['diffuse_mask']
        num_batch, _ = diffuse_mask.shape

        # [B, 1]
        if self._cfg.separate_t:
            if self._cfg.hierarchical_t:
                max_t = torch.rand(num_batch, device=self._device) * (1 - self._cfg.min_t)
                so3_t = self._cfg.min_t + torch.rand(num_batch, device=self._device) * (max_t - self._cfg.min_t)
                r3_t = self._cfg.min_t + torch.rand(num_batch, device=self._device) * (max_t - self._cfg.min_t)
                so3_t = so3_t[:, None]
                r3_t = r3_t[:, None]
            else:
                so3_t = self.sample_t(num_batch)[:, None]
                r3_t = self.sample_t(num_batch)[:, None]
        else:
            t = self.sample_t(num_batch)[:, None]
            so3_t = t
            r3_t = t
        noisy_batch['so3_t'] = so3_t
        noisy_batch['r3_t'] = r3_t

        # Apply corruptions
        if self._trans_cfg.corrupt:
            trans_t = self._corrupt_trans(
                trans_1, r3_t, res_mask, diffuse_mask)
        else:
            trans_t = trans_1
        if torch.any(torch.isnan(trans_t)):
            raise ValueError('NaN in trans_t during corruption')
        noisy_batch['trans_t'] = trans_t

        if self._rots_cfg.corrupt:
            rotmats_t = self._corrupt_rotmats(rotmats_1, so3_t, res_mask, diffuse_mask)
        else:
            rotmats_t = rotmats_1
        if torch.any(torch.isnan(rotmats_t)):
            raise ValueError('NaN in rotmats_t during corruption')
        noisy_batch['rotmats_t'] = rotmats_t
        return noisy_batch
    
    def rot_sample_kappa(self, t):
        if self._rots_cfg.sample_schedule == 'exp':
            return 1 - torch.exp(-t*self._rots_cfg.exp_rate)
        elif self._rots_cfg.sample_schedule == 'linear':
            return t
        else:
            raise ValueError(
                f'Invalid schedule: {self._rots_cfg.sample_schedule}')

    def _trans_vector_field(self, t, trans_1, trans_t):
        if self._trans_cfg.sample_schedule == 'linear':
            trans_vf = (trans_1 - trans_t) / (1 - t)
        elif self._trans_cfg.sample_schedule == 'vpsde':
            bmin = self._trans_cfg.vpsde_bmin
            bmax = self._trans_cfg.vpsde_bmax
            bt = bmin + (bmax - bmin) * (1-t) # scalar
            alpha_t = torch.exp(- bmin * (1-t) - 0.5 * (1-t)**2 * (bmax - bmin)) # scalar
            trans_vf = 0.5 * bt * trans_t + \
                0.5 * bt * (torch.sqrt(alpha_t) * trans_1 - trans_t) / (1 - alpha_t)
        else:
            raise ValueError(
                f'Invalid sample schedule: {self._trans_cfg.sample_schedule}'
            )
        return trans_vf

    def _trans_euler_step(self, d_t, t, trans_1, trans_t):
        # TODO: Add in temperature
        # TODO: Add in SDE
        assert d_t > 0

        # TODO implement the ability to switch between schedules
        #TODO: fix this
        # assert self._trans_cfg.sample_schedule == self._trans_cfg.train_schedule
        trans_vf = self._trans_vector_field(t, trans_1, trans_t)

        return trans_t + trans_vf * d_t

    def _rots_euler_step(self, d_t, t, rotmats_1, rotmats_t):
        if self._rots_cfg.sample_schedule == 'linear':
            scaling = 1 / (1 - t)
        elif self._rots_cfg.sample_schedule == 'exp':
            scaling = self._rots_cfg.exp_rate
        else:
            raise ValueError(
                f'Unknown sample schedule {self._rots_cfg.sample_schedule}')
        # TODO: Add in SDE.
        return so3_utils.geodesic_t(
            scaling * d_t, rotmats_1, rotmats_t)

    def _set_model_inputs(
            self, batch, num_batch, t, trans_t, rotmats_t,
            trans_1, rotmats_1, t_nn):
        """Write the current frames and time features for time `t` into `batch`."""
        if self._trans_cfg.corrupt:
            batch['trans_t'] = trans_t
        else:
            if trans_1 is None:
                raise ValueError('Must provide trans_1 if not corrupting.')
            batch['trans_t'] = trans_1
        if self._rots_cfg.corrupt:
            batch['rotmats_t'] = rotmats_t
        else:
            if rotmats_1 is None:
                raise ValueError('Must provide rotmats_1 if not corrupting.')
            batch['rotmats_t'] = rotmats_1
        batch['t'] = torch.ones((num_batch, 1), device=self._device) * t

        if t_nn is not None:
            # Split along the last axis into two width-1 pieces. The previous
            # call passed -1 as the split *size*, which torch.split rejects.
            # NOTE(review): assumes t_nn returns a (..., 2) tensor ordered
            # (r3_t, so3_t) - confirm against the t_nn implementation.
            batch['r3_t'], batch['so3_t'] = torch.split(
                t_nn(batch['t']), 1, dim=-1)
        else:
            if self._cfg.provide_kappa:
                batch['so3_t'] = self.rot_sample_kappa(batch['t'])
            else:
                batch['so3_t'] = batch['t']
            batch['r3_t'] = batch['t']

    def sample(
            self,
            num_batch,
            num_res,
            model,
            num_timesteps=None,
            trans_potential=None,
            trans_0=None,
            rotmats_0=None,
            trans_1=None,
            rotmats_1=None,
            diffuse_mask=None,
            chain_idx=None,
            res_idx=None,
            verbose=False,
            t_nn=None,
        ):
        """Generate backbones by Euler-integrating the model from the prior to t=1.

        Args:
            num_batch: Number of samples generated in parallel.
            num_res: Number of residues per sample.
            model: Callable taking the feature dict and returning a dict with
                `'pred_trans'` and `'pred_rotmats'`.
            num_timesteps: Number of points on the time grid; defaults to
                `cfg.sampling.num_timesteps`.
            trans_potential: Optional callable whose gradient with respect to
                the predicted translations is subtracted at every step.
                Defaults to the potential configured in `__init__`.
            trans_0: Optional initial translations; sampled when omitted.
            rotmats_0: Optional initial rotations; sampled when omitted.
            trans_1: Motif translations (also used as the fixed input when
                translations are not corrupted).
            rotmats_1: Motif rotations (also used as the fixed input when
                rotations are not corrupted).
            diffuse_mask: Mask that is 1 for residues to generate and 0 for
                motif residues. Motif scaffolding is active only when
                `diffuse_mask`, `trans_1` and `rotmats_1` are all given.
            chain_idx: Optional chain index feature; `res_mask` is used when
                omitted.
            res_idx: Optional residue index feature; `arange(num_res)` is
                used when omitted.
            verbose: Unused.
            t_nn: Optional callable mapping `batch['t']` to the `r3_t` and
                `so3_t` features.

        Returns:
            Tuple `(atom37_traj, clean_atom37_traj, clean_traj,
            motif_mask_traj)`: the atom37 conversion of the state trajectory,
            the atom37 conversion of the model predictions, the raw (CPU)
            model predictions, and the motif mask recorded at every step.

        Raises:
            ValueError: If NaNs appear in the masked initial translations, or
                if `trans_1` / `rotmats_1` is missing while the corresponding
                `corrupt` flag is off.
        """
        res_mask = torch.ones(num_batch, num_res, device=self._device)

        if trans_potential is None:
            trans_potential = self._trans_potential

        # Set-up initial prior samples
        if trans_0 is None:
            trans_0 = _centered_gaussian(
                num_batch, num_res, self._device) * du.NM_TO_ANG_SCALE
        if rotmats_0 is None:
            rotmats_0 = _uniform_so3(num_batch, num_res, self._device)
        if res_idx is None:
            res_idx = torch.arange(
                num_res,
                device=self._device,
                dtype=torch.float32)[None].repeat(num_batch, 1)
        batch = {
            'res_mask': res_mask,
            'diffuse_mask': res_mask,
            'chain_idx': res_mask if chain_idx is None else chain_idx,
            'res_idx': res_idx
        }

        motif_scaffolding = False
        if diffuse_mask is not None and trans_1 is not None and rotmats_1 is not None:
            motif_scaffolding = True
            motif_mask = ~diffuse_mask.bool().squeeze(0)
        else:
            motif_mask = None
        if motif_scaffolding and not self._cfg.twisting.use:  # amortisation
            diffuse_mask = diffuse_mask.expand(num_batch, -1)  # shape = (B, num_residue)
            batch['diffuse_mask'] = diffuse_mask
            rotmats_0 = _rots_diffuse_mask(rotmats_0, rotmats_1, diffuse_mask)
            trans_0 = _trans_diffuse_mask(trans_0, trans_1, diffuse_mask)
            if torch.isnan(trans_0).any():
                raise ValueError('NaN detected in trans_0')

        logs_traj = defaultdict(list)
        if motif_scaffolding and self._cfg.twisting.use:  # sampling / guidance
            assert trans_1.shape[0] == 1  # assume only one motif
            motif_locations = torch.nonzero(motif_mask).squeeze().tolist()
            true_motif_locations, motif_segments_length = twisting.find_ranges_and_lengths(motif_locations)

            # Marginalise both rotation and motif location
            assert len(motif_mask.shape) == 1
            trans_motif = trans_1[:, motif_mask]  # [1, motif_res, 3]
            R_motif = rotmats_1[:, motif_mask]  # [1, motif_res, 3, 3]
            num_res = trans_1.shape[-2]
            with torch.inference_mode(False):
                motif_locations = true_motif_locations if self._cfg.twisting.motif_loc else None
                F, motif_locations = twisting.motif_offsets_and_rots_vec_F(
                    num_res,
                    motif_segments_length,
                    motif_locations=motif_locations,
                    num_rots=self._cfg.twisting.num_rots,
                    align=self._cfg.twisting.align,
                    scale=self._cfg.twisting.scale_rots,
                    trans_motif=trans_motif,
                    R_motif=R_motif,
                    max_offsets=self._cfg.twisting.max_offsets,
                    device=self._device,
                    dtype=torch.float64,
                    return_rots=False,
                )

        if motif_mask is not None and len(motif_mask.shape) == 1:
            motif_mask = motif_mask[None].expand((num_batch, -1))

        # Set-up time
        if num_timesteps is None:
            num_timesteps = self._sample_cfg.num_timesteps
        ts = torch.linspace(self._cfg.min_t, 1.0, num_timesteps)
        t_1 = ts[0]

        prot_traj = [(trans_0, rotmats_0)]
        clean_traj = []
        motif_mask_traj = [motif_mask]
        for t_2 in ts[1:]:
            # Run model.
            trans_t_1, rotmats_t_1 = prot_traj[-1]
            self._set_model_inputs(
                batch, num_batch, t_1, trans_t_1, rotmats_t_1,
                trans_1, rotmats_1, t_nn)
            d_t = t_2 - t_1

            use_twisting = motif_scaffolding and self._cfg.twisting.use and t_1 >= self._cfg.twisting.t_min

            if use_twisting:  # Reconstruction guidance
                with torch.inference_mode(False):
                    batch, Log_delta_R, delta_x = twisting.perturbations_for_grad(batch)
                    model_out = model(batch)
                    t = batch['r3_t']
                    trans_t_1, rotmats_t_1, motif_mask, logs_traj = self.guidance(
                        F, motif_locations, motif_mask, motif_segments_length,
                        true_motif_locations, trans_t_1, rotmats_t_1,
                        model_out, rotmats_1, trans_1, Log_delta_R, delta_x,
                        t, d_t, logs_traj)
            else:
                with torch.no_grad():
                    model_out = model(batch)

            # Process model output.
            pred_trans_1 = model_out['pred_trans']
            pred_rotmats_1 = model_out['pred_rotmats']
            clean_traj.append(
                (pred_trans_1.detach().cpu(), pred_rotmats_1.detach().cpu())
            )
            if self._cfg.self_condition:
                if motif_scaffolding and not self._cfg.twisting.use:
                    batch['trans_sc'] = (
                        pred_trans_1 * diffuse_mask[..., None]
                        + trans_1 * (1 - diffuse_mask[..., None])
                    )
                else:
                    batch['trans_sc'] = pred_trans_1

            # Take reverse step
            trans_t_2 = self._trans_euler_step(
                d_t, t_1, pred_trans_1, trans_t_1)
            if trans_potential is not None:
                # Gradients are needed here even if the caller runs under
                # inference mode.
                with torch.inference_mode(False):
                    grad_pred_trans_1 = pred_trans_1.clone().detach().requires_grad_(True)
                    pred_trans_potential = autograd.grad(
                        outputs=trans_potential(grad_pred_trans_1),
                        inputs=grad_pred_trans_1)[0]
                if self._trans_cfg.potential_t_scaling:
                    trans_t_2 -= t_1 / (1 - t_1) * pred_trans_potential * d_t
                else:
                    trans_t_2 -= pred_trans_potential * d_t
            rotmats_t_2 = self._rots_euler_step(
                d_t, t_1, pred_rotmats_1, rotmats_t_1)
            if motif_scaffolding and not self._cfg.twisting.use:
                trans_t_2 = _trans_diffuse_mask(trans_t_2, trans_1, diffuse_mask)
                rotmats_t_2 = _rots_diffuse_mask(rotmats_t_2, rotmats_1, diffuse_mask)

            prot_traj.append((trans_t_2, rotmats_t_2))
            motif_mask_traj.append(motif_mask)
            t_1 = t_2

        # The loop advanced the state to t = ts[-1]. Evaluate the model once
        # more there and take its prediction as the final sample. The time
        # features (so3_t / r3_t) are refreshed together with 't' so the model
        # is not conditioned on the previous step's time.
        t_1 = ts[-1]
        trans_t_1, rotmats_t_1 = prot_traj[-1]
        self._set_model_inputs(
            batch, num_batch, t_1, trans_t_1, rotmats_t_1,
            trans_1, rotmats_1, t_nn)
        with torch.no_grad():
            model_out = model(batch)
        pred_trans_1 = model_out['pred_trans']
        pred_rotmats_1 = model_out['pred_rotmats']
        clean_traj.append(
            (pred_trans_1.detach().cpu(), pred_rotmats_1.detach().cpu())
        )
        prot_traj.append((pred_trans_1, pred_rotmats_1))

        # Convert trajectories to atom37.
        atom37_traj = all_atom.transrot_to_atom37(prot_traj, res_mask)
        clean_atom37_traj = all_atom.transrot_to_atom37(clean_traj, res_mask)
        return atom37_traj, clean_atom37_traj, clean_traj, motif_mask_traj

    def guidance(self, F, motif_locations, true_motif_mask, motif_segments_length, true_motif_locations, trans_t, rotmats_t, model_out, rotmats_1, trans_1, Log_delta_R, delta_x, t, d_t, logs_traj):
        """Apply one motif-guidance update to the current frames.

        The gradient of an approximate motif log-likelihood (from
        `twisting.grad_log_lik_approx`) is scaled by a time-dependent factor
        and applied to `trans_t` / `rotmats_t` through `twisting.step`.

        Args:
            F: Passed through to `twisting.grad_log_lik_approx`.
            motif_locations: Candidate motif placements; each is an iterable
                of inclusive `(start, end)` residue ranges.
            true_motif_mask: Boolean motif mask; only row 0 is used to pick
                the motif residues out of `trans_1` / `rotmats_1`.
            motif_segments_length: Unused here.
            true_motif_locations: Unused here.
            trans_t, rotmats_t: Current translations and rotations.
            model_out: Model output dict with `'pred_trans'` and
                `'pred_rotmats'`; its tensors are detached in place before
                returning.
            rotmats_1, trans_1: Motif rotations and translations.
            Log_delta_R, delta_x: Perturbation variables the gradients are
                taken with respect to.
            t: Time tensor, used for both the translation and the rotation
                scaling.
            d_t: Step size forwarded to `twisting.step`.
            logs_traj: Returned unchanged.

        Returns:
            Tuple `(trans_t, rotmats_t, motif_mask, logs_traj)`, where
            `motif_mask` marks the highest-likelihood motif placement.

        Raises:
            NotImplementedError: If `cfg.twisting.r_t` is neither 1 nor 3.
        """
        with torch.no_grad():
            t_trans = t_so3 = t
            scale_t_trans = ((1 - t_trans) / t_trans)
            scale_t_rot = ((1 - t_so3) / t_so3)
            r_t_mode = int(self._cfg.twisting.r_t)
            if r_t_mode == 1:
                # NOTE: rt^2 = \sigma_t^2 / (\alpha_t^2 + \sigma_t^2)
                rt_sq_trans = (1 - t_trans) ** 2 / (t_trans ** 2 + (1 - t_trans) ** 2)
                rt_sq_rot = (1 - t_so3) ** 2 / (t_so3 ** 2 + (1 - t_so3) ** 2)
            elif r_t_mode == 3:
                # NOTE: rt^2 = \sigma_t^2
                rt_sq_trans = (1 - t_trans) ** 2
                rt_sq_rot = (1 - t_so3) ** 2
            else:
                # r_t == 2 (rt^2 = \sigma_t^2 / (1 + \sigma_t^2)) is not
                # implemented.
                raise NotImplementedError(self._cfg.twisting.r_t)

            rt_sq_trans += self._cfg.twisting.obs_noise ** 2
            rt_sq_rot += self._cfg.twisting.obs_noise ** 2

            trans_scale_t = self._cfg.twisting.scale * (scale_t_trans / rt_sq_trans)[:, None]
            rot_scale_t = self._cfg.twisting.scale * (scale_t_rot / rt_sq_rot)[:, None, None]
            # The variances are already folded into the scales above, so the
            # likelihood approximation is given unit variances.
            rt_sq_trans, rt_sq_rot = torch.ones_like(rt_sq_trans), torch.ones_like(rt_sq_rot)

        # Select motif
        trans_pred = model_out['pred_trans']  # [B, num_res, 3]
        R_pred = model_out['pred_rotmats']  # [B, num_res, 3, 3]

        # Estimate p(motif|predicted_motif)
        trans_motif, R_motif = trans_1[:, true_motif_mask[0]], rotmats_1[:, true_motif_mask[0]]
        grad_Log_delta_R, grad_x_log_p_motif, logs, max_log_p_idx = twisting.grad_log_lik_approx(
            R_pred, trans_pred, R_motif, trans_motif, Log_delta_R, delta_x,
            None, rt_sq_trans, rt_sq_rot, F,
            twist_potential_rot=self._cfg.twisting.potential_rot,
            twist_potential_trans=self._cfg.twisting.potential_trans)

        # Get best motif location and make motif_mask
        motif_locations_b = [motif_locations[idx] for idx in max_log_p_idx.tolist()]
        # NOTE(review): the mask batch size follows trans_1 (asserted to be 1
        # by `sample`), not trans_pred - confirm this is intended.
        B = trans_1.shape[0]
        motif_mask = torch.zeros((B, trans_pred.shape[1])).bool().to(self._device)
        for b in range(B):
            for (start, end) in motif_locations_b[b]:
                motif_mask[b, start:end+1] = True

        # Compute update. This acts on the whole batch, so it runs exactly
        # once - outside the per-sample mask loop above.
        trans_t, rotmats_t = twisting.step(
            trans_t, rotmats_t, grad_x_log_p_motif, grad_Log_delta_R, d_t,
            trans_scale_t, rot_scale_t,
            self._cfg.twisting.update_trans, self._cfg.twisting.update_rot,
            self._cfg.twisting.max_rot_grad_norm)

        # Detach the model outputs so no autograd graph is kept alive.
        for key, value in model_out.items():
            model_out[key] = value.detach().requires_grad_(False)

        return trans_t, rotmats_t, motif_mask, logs_traj
