import torch, time, os
import torch.nn as nn
from torch import Tensor
import numpy as np
from ..rigid_utils import Rigid, Rotation
from ..geometry import  frames_torsions_to_atom14
import torch.nn.functional as F
from SDE_model.model.latent_model import LatentMDGenModel

__all__ = ['ProteinSDE']


def freeze_ipa(model):
    """Turn off gradients for every parameter inside ``model.ipa_layers``.

    Models that have no ``ipa_layers`` attribute are left untouched.

    Args:
        model: module that may expose an iterable ``ipa_layers`` of submodules.
    """
    if not hasattr(model, 'ipa_layers'):
        return
    for ipa_layer in model.ipa_layers:
        for weight in ipa_layer.parameters():
            weight.requires_grad_(False)
                

def mean_flat(x, mask):
    """Masked mean of ``x`` over every dimension except the first (batch) one.

    Computes ``sum(x * mask) / sum(mask)``, where both sums run over
    dims ``1 .. x.dim() - 1``.

    Args:
        x: tensor to average.
        mask: weights broadcastable against ``x``; reduced over the same dims.

    Returns:
        Tensor with one value per entry along dim 0.

    NOTE(review): a row whose mask sums to zero yields nan/inf.
    """
    reduce_dims = list(range(1, x.dim()))
    weighted_total = torch.sum(x * mask, dim=reduce_dims)
    weight_total = torch.sum(mask, dim=reduce_dims)
    return weighted_total / weight_total


class ProteinSDE(nn.Module):
    """Noise-and-denoise wrapper around a ``LatentMDGenModel`` network.

    Latents built by ``prep_batch`` have 13 channels per residue: 7 frame
    channels from ``Rigid.to_tensor_7`` followed by 6 torsion values
    (3 torsions x 2 components).
    """

    def __init__(
        self,
        config
    ):
        super().__init__()

        self.config = config
        # Fall back to CPU so the module can still be built on machines without CUDA.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.network = LatentMDGenModel(self.config, self.config.latent_dim).to(self.device)

        total_params = sum(p.numel() for p in self.network.parameters())
        trainable_params = sum(p.numel() for p in self.network.parameters() if p.requires_grad)
        # These are parameter *counts*, reported in millions (not a memory size).
        print(f'{self.network._get_name()} #Params: {total_params / 1e6:.2f} M, '
              f'Trainable Params: {trainable_params / 1e6:.2f} M')

    def brownian_forward(self, protein: Tensor, noise_scale) -> "tuple[Tensor, Tensor]":
        """Add Gaussian noise of a given scale.

        Args:
            protein (Tensor): input latents (any shape).
            noise_scale: multiplier applied to ``torch.randn_like(protein)``.

        Returns:
            (noisy_protein, noise), where ``noise = noise_scale * randn_like(protein)``
            and ``noisy_protein = protein + noise``.
        """
        noise = noise_scale * torch.randn_like(protein)
        return protein + noise, noise

    def sde_forward(self,
                x,
                noise_scale,
                t,
                mask,
                start_frames,
                end_frames,
                x_cond,
                x_cond_mask,
                aatype,
                mode) -> Tensor:
        """Noise ``x``, run the network on it and return its prediction.

        The network is called positionally with the noised input, ``t``,
        ``mask``, the start/end frames, the conditioning tensors, ``aatype``
        and the sampled noise itself.

        Returns:
            The raw network output when ``mode == 'gradient'``; otherwise the
            network output added to the clean input ``x`` (residual update).
        """
        noise_input, noise = self.brownian_forward(x, noise_scale)
        pred = self.network(noise_input, t, mask, start_frames, end_frames,
                            x_cond, x_cond_mask, aatype, noise)
        if mode == 'gradient':
            return pred
        return pred + x  # default: treat the network output as an increment

    def inference(self, batch, totoal_frames, threshold, mode, noise_scale=5):
        """Autoregressively roll out a trajectory window by window and decode it.

        Starts from the first ``threshold`` frames of the prepared latents and
        repeatedly feeds each predicted window back in as the next input. The
        first frame of every predicted window becomes the next ``start_frames``.

        Args:
            batch: raw batch dict, converted with ``prep_batch``.
            totoal_frames: target trajectory length (misspelled name kept for
                caller compatibility).
            threshold: window length, also used as the stride of the schedule.
            mode: forwarded to ``sde_forward``.
            noise_scale: noise level, used both as the noise multiplier and as
                the value of ``t`` (defaults to the previously hard-coded 5).

        Returns:
            (atom14_pred, batch_loss): output of ``frames_torsions_to_atom14``
            for the concatenated trajectory, and a scalar zero tensor (no loss
            is accumulated here).

        NOTE(review): no ``torch.no_grad()`` is applied, so an autograd graph is
        kept across all windows; wrap the call if gradients are not needed.
        """
        prep = self.prep_batch(batch)
        aatype = prep['model_kwargs']['aatype']
        x = prep['latents'][:, 0:threshold, ...]
        mask = prep['model_kwargs']['mask'][:, 0:threshold, ...]
        start_frames = prep['model_kwargs']['start_frames']
        device = x.device

        # Window schedule; only its length matters (one model call per entry).
        window_starts = torch.arange(
            0, min(totoal_frames - threshold * 2 + 1, totoal_frames + 1), threshold)

        generated_traj = [x]
        batch_loss = torch.tensor(0.0, device=device)
        # Loop-invariant noise-level tensor, one entry per batch element.
        t = torch.full((x.size(0),), noise_scale, device=device)

        for _ in window_starts:
            pred = self.sde_forward(
                x=x,
                noise_scale=noise_scale,
                t=t,
                mask=mask,
                start_frames=start_frames,
                end_frames=None,
                x_cond=None,
                x_cond_mask=None,
                aatype=aatype,
                mode=mode,
            )
            generated_traj.append(pred)
            x = pred
            # Frames of the first frame of the predicted window; equivalent to
            # building the full-window Rigid and indexing it at [:, 0].
            start_frames = Rigid.from_tensor_7(pred[:, 0, ..., :7], normalize_quats=True)

        generated_traj = torch.cat(generated_traj, dim=1)  # (B, T_total, L, C)
        offsets_pred = generated_traj[..., :7]
        torsions_pred = generated_traj[..., 7:]

        frames_pred = Rigid.from_tensor_7(offsets_pred, normalize_quats=True)
        B, T, L, _ = offsets_pred.shape

        # 6 torsion channels are zero-padded with 8 more to form 7 pairs.
        torsions_pred = F.pad(torsions_pred, (0, 8), mode='constant', value=0)

        atom14_pred = frames_torsions_to_atom14(
            frames_pred,
            torsions_pred.view(B, T, L, 7, 2),
            aatype[:, None].expand(B, T, L),
        )
        return atom14_pred, batch_loss

    def prep_batch(self, batch):
        """Convert a raw batch dict into latents, masks and model kwargs.

        Side effects: overwrites ``batch['torsions']`` / ``batch['torsion_mask']``
        with their first 3 torsions and ``batch['arr']`` with indices 0-3 along
        dim -2.

        Returns:
            dict with ``arr``, ``rigids``, ``latents`` (B, T, L, 13),
            ``loss_mask`` (B, T, L, 13), ``name``, ``seqres``, ``protein_len``
            and ``model_kwargs`` (start/end frames, frames, mask, aatype and
            frame-0 conditioning).
        """
        rigids = Rigid(
            trans=batch['trans'],
            rots=Rotation(rot_mats=batch['rots'])
        )  # B, T, L
        B, T, L = rigids.shape

        offsets = rigids.to_tensor_7()  # (B, T, L, 7)
        # Flip the first 4 channels wherever channel 0 is negative so channel 0
        # is non-negative (presumably quaternion sign canonicalisation).
        offsets[..., :4] *= torch.where(offsets[:, :, :, 0:1] < 0, -1, 1)
        frame_loss_mask = batch['mask'].unsqueeze(-1).expand(-1, -1, 7)  # (B, L, 7)

        # Keep only the first 3 torsions.
        batch['torsions'] = batch['torsions'][..., :3, :]
        batch['torsion_mask'] = batch['torsion_mask'][..., :3]

        torsion_loss_mask = batch['torsion_mask'].unsqueeze(-1).expand(-1, -1, -1, 2).reshape(B, L, 6)

        latents = torch.cat([offsets, batch['torsions'].view(B, T, L, 6)], -1)

        loss_mask = torch.cat([frame_loss_mask, torsion_loss_mask], -1)
        loss_mask = loss_mask.unsqueeze(1).expand(-1, T, -1, -1)

        # Only frame 0 is marked as conditioning.
        cond_mask = torch.zeros(B, T, L, dtype=int, device=offsets.device)
        cond_mask[:, 0] = 1

        backbone_indices = torch.tensor([0, 1, 2, 3], device=batch['arr'].device)
        batch['arr'] = torch.index_select(batch['arr'], dim=-2, index=backbone_indices)

        # NOTE(review): counts non-zero mask entries over the WHOLE batch, not
        # per sample — confirm this is intended.
        protein_len = torch.sum(batch['mask'] != 0)

        # The mask is all ones, so every position keeps its seqres value; the
        # fallback index 20 is never selected as written.
        aatype_mask = torch.ones_like(batch['seqres'])
        aatype = torch.where(aatype_mask.bool(), batch['seqres'], 20)

        return {
            'arr': batch['arr'],
            'rigids': rigids,
            'latents': latents,
            'loss_mask': loss_mask,
            'name': batch['name'],
            'seqres': batch['seqres'],
            'protein_len': protein_len,

            'model_kwargs': {
                'start_frames': rigids[:, 0],
                'frames': rigids,
                'end_frames': rigids[:, -1],
                'mask': batch['mask'].unsqueeze(1).expand(-1, T, -1),
                'aatype': aatype,
                'x_cond': torch.where(cond_mask.unsqueeze(-1).bool(), latents, 0.0),
                'x_cond_mask': cond_mask, }}
    