import numpy as np
import torch
from openfold.utils.rigid_utils import Rigid
from entity import entity_constants as ec
from . import utils
from utils.funcs import (
    generate_Cbeta
)

#https://github.com/scipy/scipy/blob/main/scipy/spatial/transform/_rotation.pyx
def rmsdalign(a, b, *, b_full=None, weights=None, eps=1e-6):  # aligns B to A  # [*, N, 3]
    """Rigidly superimpose point set ``b`` onto ``a`` (weighted Kabsch alignment).

    Fits the proper rotation (det = +1) and translation that map ``b`` onto ``a``
    in the weighted least-squares sense, then applies that transform to ``b``,
    or to ``b_full`` when it is given.

    Args:
        a: Reference coordinates, shape ``[*, N, 3]``.
        b: Coordinates to move, same shape as ``a``.
        b_full: Optional coordinates ``[*, M, 3]`` that receive the transform fitted
            on ``b`` (e.g. a larger point set that ``b`` was selected from).
        weights: Optional per-point weights ``[*, N]``. ``None`` means uniform weights.
        eps: Standard deviation of random jitter added to the covariance matrix
            before the SVD (presumably to keep the SVD well-behaved for degenerate
            inputs). The jitter uses the global torch RNG, so results are not
            bit-for-bit deterministic.

    Returns:
        The transformed ``b`` (or ``b_full``), cast back to ``a``'s dtype.
    """
    # Force float32 math: CUDA autocast is disabled and the inputs are upcast.
    with torch.autocast("cuda", enabled=False):
        res_dtype = a.dtype
        a = a.to(torch.float32)
        b = b.to(torch.float32)
        if weights is None:
            weights = a.new_ones(a.shape[:-1])
        else:
            weights = weights.to(torch.float32)
        weights = weights.unsqueeze(-1)  # [*, N, 1]

        # Remove the weighted centroids.
        w_total = weights.sum(-2, keepdim=True)
        a_mean = (a * weights).sum(-2, keepdim=True) / w_total
        b_mean = (b * weights).sum(-2, keepdim=True) / w_total
        a = a - a_mean
        b = b - b_mean

        # Weighted cross-covariance H = sum_i w_i a_i b_i^T, shape [*, 3, 3].
        cov = torch.einsum('...ji,...jk->...ik', weights * a, b)
        cov = cov + torch.randn_like(cov) * eps
        u, _, vh = torch.linalg.svd(cov)

        # Where U @ Vh is a reflection (det = -1), flip the last left-singular
        # vector so that the result is a proper rotation.
        sgn = torch.sign(torch.linalg.det(u @ vh))
        u = u.clone()
        u[..., :, -1] *= sgn.unsqueeze(-1)
        rot = u @ vh  # centred a_i ≈ rot @ centred b_i

        target = b if b_full is None else b_full.to(torch.float32) - b_mean
        # Row-vector convention: x @ rot^T applies rot to every point.
        return (target @ rot.mT + a_mean).to(res_dtype)

class SamplePrior:
    """Prior that draws one random 3-D position per token.

    Only ``g_scale`` affects :meth:`sample`. The other constructor arguments
    (``a_prot``, ``a_lig``, ``gaussion``, ``traditional``) and ``eps`` are stored as
    attributes but are not read anywhere in this class.
    """

    def __init__(self, a_prot=3/(3.8**2), a_lig=1, gaussion=True, traditional=False, g_scale=None):
        self.a_prot = a_prot
        self.a_lig = a_lig
        self.eps = 1e-4
        self.traditional = traditional
        self.gaussion = gaussion
        self.g_scale = g_scale

    def sample(self, edges, device):
        """Draw i.i.d. Gaussian positions, one 3-vector per token.

        Args:
            edges: Tensor of rank >= 3. Only its shape is used: the leading
                ``edges.shape[:-3]`` dims are batch dims, and ``edges.shape[-3]``
                is the number of tokens ``N``.
            device: Device on which the sample is allocated.

        Returns:
            Float tensor of shape ``[*edges.shape[:-3], N, 3]`` with entries drawn
            from ``N(0, g_scale**2)``, or from a unit normal when ``g_scale`` is ``None``.
        """
        scale = 1.0 if self.g_scale is None else self.g_scale
        # A single vectorised draw covers every batch item; this also works
        # for an empty batch.
        return scale * torch.randn(*edges.shape[:-3], edges.shape[-3], 3, device=device)

class FlowMatching(torch.nn.Module):
    """Flow-matching wrapper around a structure model.

    Positions follow the straight path ``x_t = (1 - t) * pseudo_beta + t * noise``,
    where ``t = 1`` is pure prior noise and ``t = 0`` is data.
    ``data_postprocess`` builds model inputs on that path, ``sample`` integrates
    from ``t = 1`` to ``t = 0``, and ``calculate_loss`` scores a prediction for
    geometric violations.
    """

    def __init__(self, *, model, cfg):
        super().__init__()
        self.model = model
        self.cfg = cfg
        # Gaussian prior whose scale is cfg.sigma_data.
        self.harmonic_prior = SamplePrior(gaussion=True, g_scale=cfg.sigma_data)

    def data_postprocess(self, batch, idx, mode, t=None):
        """Build the noisy input at time ``t`` plus the diffusion and hint masks.

        Prior noise is first rigidly aligned onto ``batch['pseudo_beta']``
        (weighted by ``pseudo_beta_mask``). It is then linearly interpolated as
        ``(1 - t) * pseudo_beta + t * noise``.

        Args:
            batch: Feature dict. It is mutated in place and returned.
            idx, mode: Not used here; kept for interface compatibility.
            t: Interpolation time. ``None`` means ``t = 1``. A Python number or a
                0-d / 1-d tensor is accepted.

        Returns:
            ``batch`` with ``input_pos``, ``sc_cb_pos``, ``t``, ``hint_pos``,
            ``diffused_mask``, ``fixed_mask`` and ``hint_mask`` set.
        """
        device = batch['token_type'].device
        batch_dims = batch['seq_length'].shape

        noisy = self.harmonic_prior.sample(edges=batch['edges'], device=device)
        noisy = rmsdalign(batch['pseudo_beta'], noisy, weights=batch['pseudo_beta_mask']).detach()

        if t is None:
            t = torch.tensor([1.], device=device)
        elif isinstance(t, torch.Tensor):
            # Use tensors directly, flattened to 1-D so that t[:, None, None] below
            # broadcasts. Re-wrapping them with torch.tensor([...]) fails for
            # multi-element t.
            t = t.to(device=device, dtype=torch.float).reshape(-1)
        else:
            t = torch.tensor([t], dtype=torch.float, device=device)

        t_b = t[:, None, None]
        noisy_beta = (1 - t_b) * batch['pseudo_beta'] + t_b * noisy
        if len(batch_dims) == 0:
            # Unbatched input: drop the leading axis introduced by t.
            noisy_beta = noisy_beta[0]
        batch["input_pos"] = noisy_beta
        batch["sc_cb_pos"] = torch.zeros_like(noisy_beta)

        batch['t'] = t
        diffused_mask = batch['seq_mask'].clone()
        hint_mask = torch.zeros_like(batch['seq_mask'])
        hint_pos = torch.zeros_like(batch['pseudo_beta'])

        # Molecule tokens are not diffused. Index the last axis so that this also
        # works when entity_type has leading batch dims.
        is_mol = batch["entity_type"][..., ec.entity_type_order["molecule"]] == 1
        diffused_mask[is_mol] = 0

        batch["hint_pos"] = hint_pos
        batch["diffused_mask"] = diffused_mask
        batch["fixed_mask"] = (1 - diffused_mask) * batch["seq_mask"]
        batch["hint_mask"] = hint_mask
        return batch

    @torch.no_grad()
    def sample(self, batch, *, num_steps=100, max_attempts=None):
        """Generate a structure by integrating the flow from ``t = 1`` to ``t = 0``.

        A full trajectory is drawn from fresh prior noise and repeated until the
        ``violation`` score of its final step (see :meth:`calculate_loss`) is not
        greater than 1, or until ``max_attempts`` trajectories have been drawn.
        The method runs without autograd, because the trajectory is exported to
        NumPy.

        Args:
            batch: Feature dict; mutated in place (``input_pos``, ``t``, masks, ...).
            num_steps: Number of points in the time grid ``linspace(1, 0, num_steps)``.
                Each attempt makes ``num_steps - 1`` model calls. Must be >= 2.
            max_attempts: Upper bound on the number of trajectories. ``None``
                retries until the violation threshold is met.

        Returns:
            dict with ``"prot_traj"`` (stacked full-atom positions in reverse step
            order, i.e. final step first) and ``"denoise_traj"`` (list of dicts of
            per-step ``pseudo_beta`` / ``noisy`` NumPy arrays, initial state first).

        Raises:
            ValueError: If ``num_steps < 2`` or ``max_attempts < 1``.
        """
        if num_steps < 2:
            raise ValueError(f"num_steps must be >= 2, got {num_steps}")
        if max_attempts is not None and max_attempts < 1:
            raise ValueError(f"max_attempts must be >= 1 or None, got {max_attempts}")
        schedule = np.linspace(1, 0, num_steps)

        attempts = 0
        while True:
            attempts += 1
            all_bb_prots, denoise_traj, model_out = self._sample_trajectory(batch, schedule)
            violation = self.calculate_loss(batch, model_out)["violation"]
            # `not (v > 1)` instead of `v <= 1`, so that a NaN score also ends
            # the loop rather than retrying forever.
            if not violation > 1:
                break
            if max_attempts is not None and attempts >= max_attempts:
                break

        return {
            "prot_traj": np.flip(np.stack(all_bb_prots), (0,)),
            "denoise_traj": denoise_traj,
        }

    def _sample_trajectory(self, batch, schedule):
        """Run one denoising pass over ``schedule``.

        Returns (per-step full-atom positions, per-step denoise records, last model output).
        """
        batch_type = batch['seq_mask'].dtype
        device = batch['token_type'].device

        noisy = self.harmonic_prior.sample(edges=batch['edges'], device=device)
        batch["input_pos"] = noisy.to(batch_type)

        entity_type = batch["entity_type"].argmax(dim=-1)
        batch["t"] = torch.ones(1, device=device).to(batch_type)

        diffused_mask = batch['seq_mask'].clone()
        is_mol = batch["entity_type"][..., ec.entity_type_order["molecule"]] == 1

        batch["diffused_mask"] = diffused_mask.to(batch_type)
        batch["fixed_mask"] = ((1 - diffused_mask) * batch["seq_mask"]).to(batch_type)
        batch["sc_cb_pos"] = torch.zeros_like(batch["pseudo_beta"]).to(batch_type)

        n_idx = ec.atom_order["N"]
        ca_idx = ec.atom_order["CA"]
        c_idx = ec.atom_order["C"]
        mol_atom_idx = ec.atom_order["*MolAtom"]

        all_bb_prots = []
        denoise_traj = [{
            "pseudo_beta": batch["pseudo_beta"].cpu().numpy(),
            "noisy": noisy.cpu().numpy(),
        }]

        model_out = None
        for t, s in zip(schedule[:-1], schedule[1:]):
            model_out = self.model(batch)
            atom_full_pos = self._full_atom_positions(batch, model_out, entity_type)
            all_bb_prots.append(atom_full_pos.cpu().detach().numpy())

            # Pseudo-beta from the predicted backbone N/CA/C; molecule tokens take
            # the position of their *MolAtom slot instead.
            pseudo_beta = generate_Cbeta(
                N=atom_full_pos[..., n_idx, :],
                Ca=atom_full_pos[..., ca_idx, :],
                C=atom_full_pos[..., c_idx, :],
            )
            pseudo_beta[is_mol] = atom_full_pos[is_mol][..., mol_atom_idx, :]

            # Align the current state onto the prediction, then take an Euler step
            # from t to s towards it: x_s = (s/t) * x_t + (1 - s/t) * x0_pred.
            noisy = rmsdalign(pseudo_beta, noisy)
            noisy = (s / t) * noisy + (1 - s / t) * pseudo_beta

            batch['input_pos'] = noisy.to(batch_type)
            batch['t'] = (torch.ones(1, device=noisy.device) * s).to(batch_type)
            batch["sc_cb_pos"] = pseudo_beta.to(batch_type)

            denoise_traj.append({
                "pseudo_beta": pseudo_beta.detach().cpu().numpy(),
                "noisy": noisy.detach().cpu().numpy(),
            })

        return all_bb_prots, denoise_traj, model_out

    def _full_atom_positions(self, batch, model_out, entity_type):
        """Rebuild full-atom positions from predicted frames, with fixed tokens pinned to ground-truth frames."""
        gt_bb_rigid = Rigid.from_tensor_4x4(batch['rigidgroups_gt_frames'])[..., 0]
        fixed_mask = batch["fixed_mask"][..., None]
        bb_mask = batch['seq_mask']
        rigids_0 = gt_bb_rigid.to_tensor_7()
        rigids_pred = (
            model_out["final_rigids"] * (1 - fixed_mask) + rigids_0 * fixed_mask
        ) * bb_mask[..., None]
        bb_representations = utils.compute_backbone(
            Rigid.from_tensor_7(rigids_pred), model_out["psi"], entity_type
        )
        atom_full_pos = bb_representations[0].to(rigids_pred.device)

        # adjust_oxygen_pos is applied one batch item at a time on the first 37
        # atom slots. Integer indexing writes back into flat_pos in place.
        flat_pos = atom_full_pos.reshape(-1, *atom_full_pos.shape[-3:])
        flat_mask = bb_mask.reshape(-1, *bb_mask.shape[-1:])
        for i, mask in enumerate(flat_mask):
            flat_pos[i][:, :37, :] = utils.adjust_oxygen_pos(flat_pos[i][:, :37, :], mask)

        return flat_pos.reshape(*atom_full_pos.shape)

    def calculate_loss(self, batch, model_out):
        """Compute the structural-violation score of a model prediction.

        Predicted frames are used for non-fixed tokens and ground-truth frames for
        fixed tokens. Molecule frames go through ``utils.make_mol_rigid``. The
        rebuilt full-atom / atom14 positions are scored with the between-residue
        bond terms and a between-residue clash term.

        Args:
            batch: Feature dict; it is only read.
            model_out: Model output; ``model_out["outputs"]["frames"]`` and
                ``model_out["psi"]`` are read.

        Returns:
            dict with a single key ``"violation"``.
        """
        entity_type = batch["entity_type"].argmax(dim=-1)
        bb_mask = batch['seq_mask']
        is_mol = (entity_type == ec.entity_type_order["molecule"]) * batch["seq_mask"]
        gt_bb_rigid = Rigid.from_tensor_4x4(batch['rigidgroups_gt_frames'])[..., 0]
        gt_bb_rigid = utils.make_mol_rigid(gt_bb_rigid, batch["fape_frame_idx"], is_mol)

        fixed_mask = batch["fixed_mask"][..., None]
        rigids_0 = gt_bb_rigid.to_tensor_7()
        rigids_pred = (
            model_out["outputs"]["frames"] * (1 - fixed_mask) + rigids_0 * fixed_mask
        ) * bb_mask[..., None]

        rigids_pred_mol_4x4 = utils.make_mol_rigid(
            Rigid.from_tensor_7(rigids_pred), batch["fape_frame_idx"], is_mol
        ).to_tensor_4x4()
        bb_representations = utils.compute_backbone(
            Rigid.from_tensor_4x4(rigids_pred_mol_4x4), model_out["psi"], entity_type
        )
        atomFull_pos_pred = bb_representations[0].to(rigids_pred.device)
        atom14_pos_pred = bb_representations[3].to(rigids_pred.device)

        # The clash term only sees the first five atom14 slots, and only where
        # the atom exists.
        atom14_pos_pred_mask = (
            atom14_pos_pred.new_tensor([1.0] * 5 + [0.0] * (14 - 5))
            * batch["atom14_atom_exists"]
        )

        violation_cfg = self.cfg.gm.loss.violation

        # Between-residue bond loss for proteins.
        connection_violations = utils.between_residue_bond_loss(
            pred_atom_positions=atomFull_pos_pred,
            pred_atom_mask=batch["all_atom_mask"],
            residue_index=batch["token_index"],
            aatype=batch["token_type"],
            tolerance_factor_soft=violation_cfg.violation_tolerance_factor,
            tolerance_factor_hard=violation_cfg.violation_tolerance_factor,
        )

        # Van der Waals radius per atom type (the first letter of the atom name is
        # the element), gathered into atom14 layout and zeroed for missing atoms.
        atomtype_radius = atomFull_pos_pred.new_tensor(
            [ec.van_der_waals_radius[name[0]] for name in ec.atom_types]
        )
        atom14_atom_radius = (
            batch["atom14_atom_exists"]
            * atomtype_radius[batch["residx_atom14_to_atomFull"]]
        )

        # Between-residue clash loss; ligands get no internal clash loss.
        between_residue_clashes = utils.between_residue_clash_loss(
            atom14_pred_positions=atom14_pos_pred,
            atom14_atom_exists=atom14_pos_pred_mask,
            atom14_atom_radius=atom14_atom_radius,
            residue_index=batch["token_index"],
            overlap_tolerance_hard=violation_cfg.clash_overlap_tolerance,
            overlap_tolerance_soft=violation_cfg.clash_overlap_tolerance,
        )

        # Clash loss per existing atom plus the three bond/angle violation terms.
        clash_per_atom = torch.sum(between_residue_clashes["per_atom_loss_sum"]) / (
            1e-6 + torch.sum(batch["atom14_atom_exists"])
        )
        violation_loss = torch.mean(
            clash_per_atom
            + connection_violations["c_n_loss_mean"]
            + connection_violations["ca_c_n_loss_mean"]
            + connection_violations["c_n_ca_loss_mean"]
        )

        return {"violation": violation_loss}
