from functools import partial
import logging

import torch
import torch.nn as nn

from TopoDiff.model.encoder import Encoder_v1

from TopoDiff.model.diffuser.diffuser import SE3Diffuser
from TopoDiff.model.embedder2 import InputEmbedder2
from TopoDiff.model.backbone2 import Backbone2
# from TopoDiff.model.aux_head import AuxiliaryHeads

from TopoDiff.model.structure_module import StructureModuleHelper

from myopenfold.utils.rigid_utils import Rotation, Rigid

from TopoDiff.utils.debug import print_shape, log_var

from TopoDiff.utils.transforms import make_one_hot
from myopenfold.utils.tensor_utils import tensor_tree_map

# Fixed dotted name (not __name__), so records always go to the
# "TopoDiff.model.diffusion" logger regardless of the importing module path.
logger = logging.getLogger(name = "TopoDiff.model.diffusion")

class Diffusion(nn.Module):
    """Latent-conditioned diffusion model over backbone frames.

    Sub-modules are selected by ``config.Global``:

    - ``Encoder``: optional topology encoder (only ``'Encoder_v1'`` supported).
    - ``SE3Diffuser``: forward/reverse noising process (always created).
    - ``Embedder``: input embedder (only ``'Embedder_v2'`` supported).
    - ``Backbone``: clean-frame predictor (only ``'Backbone_v2'`` supported).
    """

    def __init__(self, config, depth = 0, log = False) -> None:
        super().__init__()

        self.depth = depth
        self.log = log

        self.config = config

        if 'Encoder' not in config.Global or config.Global.Encoder is None:
            logger.info('Not using topology encoder.')
            self.encoder = None
        elif config.Global.Encoder == 'Encoder_v1':
            self.encoder = Encoder_v1(config.Encoder_v1, depth = self.depth + 1, log = self.log)
            self.encoder_config = config.Encoder_v1
        else:
            raise NotImplementedError(f'Encoder {config.Global.Encoder} not implemented')

        self.diffuser = SE3Diffuser(config.Diffuser, depth = self.depth + 1, log = self.log)

        if config.Global.Embedder == 'Embedder_v2':
            self.embedder = InputEmbedder2(config.Embedder_v2, depth = self.depth + 1, log = self.log)
        else:
            raise NotImplementedError(f'Embedder {config.Global.Embedder} not implemented')

        if config.Global.Backbone == 'Backbone_v2':
            self.backbone = Backbone2(config.Backbone_v2, depth = self.depth + 1, log = self.log)
        else:
            raise NotImplementedError(f'Backbone {config.Global.Backbone} not implemented')

        # The auxiliary heads are currently disabled, so `self.aux_heads` is not
        # created here; methods that need it look it up with getattr and raise a
        # clear error when it is absent.
        # self.aux_heads = AuxiliaryHeads(config.Aux_head, depth = self.depth + 1, log = self.log)

        # Empty, non-trainable parameter whose only job is to expose the module's device.
        self.dummy_param = nn.Parameter(torch.empty(0), requires_grad = False)

    @property
    def device(self):
        """Device the module currently lives on (tracked via ``dummy_param``)."""
        return self.dummy_param.device

    def _log(self, text, tensor = 'None'):
        """Forward ``text``/``tensor`` to ``log_var`` when logging is enabled."""
        if self.log:
            log_var(text, tensor, depth = self.depth)


    def sample_step(self, feat: dict, frame_noised, prev = None, reverse_sample = False, inplace_safe = False, is_first = False, is_last = False,
                    intype = 'tensor_7', outtype = 'tensor_7'):
        """Run one denoising step: embed, predict clean frames, optionally reverse-sample.

        Args:
            feat: feature dict passed to the embedder and backbone; this method
                itself reads ``feat['timestep']`` and ``feat['frame_mask']``.
            frame_noised: current noisy frames (format given by ``intype``).
            prev: recycled ``[seq_emb, pair_emb, frame_hat_trans]`` from the
                previous step, or None.
            reverse_sample: if True, also compute ``frame_denoised`` with
                ``self.diffuser.reverse_sample``.
            inplace_safe, is_first, is_last: forwarded to the embedder
                (``inplace_safe`` is also forwarded to the backbone).
            intype, outtype: frame formats passed to ``reverse_sample``.

        Returns:
            dict. When ``is_last`` is False only 'frame_hat', 'seq_emb' and
            'pair_emb' are kept; otherwise the embedder's extra outputs and the
            full backbone output are merged in. 'frame_denoised' is added when
            ``reverse_sample`` is True.
        """
        result = {}

        # embed the input
        seq_emb, pair_emb, emb_output = self.embedder(feat, prev, inplace_safe, is_first, is_last)

        # predict the ground truth frame
        backbone_output = self.backbone(feat, frame_noised, seq_emb, pair_emb, inplace_safe)

        # Intermediate steps only need what the next step recycles; the last
        # step returns everything.
        if not is_last:
            result['frame_hat'] = backbone_output['frame_hat']
            result['seq_emb'] = backbone_output['seq_emb']
            result['pair_emb'] = backbone_output['pair_emb']
        else:
            result.update(emb_output)
            result.update(backbone_output)

        # Drop the reference so unused backbone outputs can be freed early.
        del backbone_output

        if reverse_sample:
            frame_denoised = self.diffuser.reverse_sample(frame_noised, result['frame_hat'], feat['timestep'], feat['frame_mask'], intype = intype, outtype = outtype)
            result['frame_denoised'] = frame_denoised

        return result

    def _init_feat(self, num_samples = 1, num_res = 150, timestep = 200):
        """Build the minimal sampling features for ``num_samples`` chains of ``num_res`` residues.

        Every residue gets sequence type 20 (presumably the unknown/mask type —
        confirm against the residue vocabulary), no chain breaks, full masks,
        and an identity 4x4 homogeneous frame in ``frame_gt``.
        """
        feat = {}
        feat['timestep'] = torch.ones(num_samples, dtype = torch.long) * timestep
        has_break = torch.zeros(num_samples, num_res, 1, dtype = torch.long)
        seq_type = torch.ones(num_samples, num_res, dtype = torch.long) * 20
        feat['seq_type'] = seq_type
        feat['seq_feat'] = torch.cat([has_break, make_one_hot(seq_type, 21)], dim = -1).float()
        feat['seq_idx'] = torch.arange(num_res, dtype = torch.long).unsqueeze(0).repeat(num_samples, 1)
        feat['seq_mask'] = torch.ones(num_samples, num_res, dtype = torch.bool)

        feat['frame_mask'] = torch.ones(num_samples, num_res, dtype = torch.bool)
        # Identity homogeneous transform per residue: I_3 rotation, zero
        # translation, and 1 in the bottom-right corner. The corner must be
        # indexed on the last two axes ([..., 3, 3]); indexing [:, 3, 3] would
        # hit residue 3's fourth row instead (and fail for num_res <= 3).
        frame_gt = torch.zeros(num_samples, num_res, 4, 4, dtype = torch.float)
        frame_gt[..., :3, :3] = torch.eye(3, dtype = torch.float)
        frame_gt[..., 3, 3] = 1.0
        feat['frame_gt'] = frame_gt

        return feat

    @staticmethod
    def _sample_num_res(length_pred, length_variance, min_len = 50, max_len = 350):
        """Draw a chain length around ``length_pred``, clamped when the prediction is out of range.

        The length is drawn with ``torch.randint`` from
        ``[int(length_pred * (1 - length_variance)), int(length_pred * (1 + length_variance)))``.
        The upper bound is raised to at least ``low + 1`` so the interval is never
        empty (e.g. ``length_variance == 0``). If the predicted length itself is
        above ``max_len`` or below ``min_len``, the corresponding bound is returned.
        """
        length_pred = torch.as_tensor(length_pred).reshape(-1)[0]
        low = int((length_pred * (1 - length_variance)).item())
        high = max(int((length_pred * (1 + length_variance)).item()), low + 1)
        num_res = torch.randint(low, high, (1,)).item()
        print('`num_res` is not provided. Will infer from `latent`: %d' % num_res)
        if length_pred > max_len:
            print('Warning: the length of the predicted protein is too long: %d, forcing to be %d' % (length_pred.item(), max_len))
            num_res = max_len
        if length_pred < min_len:
            print('Warning: the length of the predicted protein is too short: %d, forcing to be %d' % (length_pred.item(), min_len))
            num_res = min_len
        return num_res

    def _reconstruct_position(self, frame, seq_type):
        """Reconstruct backbone coordinates from tensor_7 frames with ``self.helper``."""
        return self.helper.reconstruct_backbone_position_without_torsion_wrap(
            frame_pred = frame, seq_type = seq_type, intype = 'tensor_7'
        )

    def sample_latent_conditional(self, latent, return_traj = False, num_res = None, timestep = 200,
                             return_frame = False, return_position = False, reconstruct_position = False, length_variance = 0.15,
                             **kwargs):
        """Generate one backbone conditioned on a topology latent.

        The method starts from frames produced by
        ``self.diffuser.forward_sample_marginal`` at ``timestep`` and then runs
        ``timestep`` reverse steps. Self-conditioning is applied unless
        ``config.Global.infer_no_recyc`` is set.

        Args:
            latent: latent code of a single sample; a leading batch axis is added here.
            return_traj: also return the step-by-step trajectory. Records are
                reversed, so index 0 holds the last recorded step.
            num_res: number of residues. If None, it is sampled around the
                length predicted from ``latent`` by ``self.aux_heads``, which
                must exist.
            timestep: number of reverse diffusion steps; must be >= 1.
            return_frame: with ``return_traj``, record frames (tensor_7).
            return_position: with ``return_traj``, record reconstructed backbone
                coordinates. Requires ``reconstruct_position``.
            reconstruct_position: also reconstruct coordinates of the final frames.
            length_variance: relative half-width of the range used to sample
                ``num_res`` from the predicted length.
            **kwargs: only ``num_samples`` is recognised. It is deprecated and
                ignored; the sample count is always 1.

        Sampling features (built by ``_init_feat``):
            - timestep [*]
            - seq_type [*, N_res]
            - seq_idx [*, N_res]
            - seq_mask [*, N_res]
            - frame_mask [*, N_res]
            - seq_feat [*, N_res, 22]

        Returns:
            A dict with these entries:

            - 'frame_denoised' and 'frame_hat' for the single sample, on CPU.
            - Trajectory records, when ``return_traj`` is set.
            - 'coord_denoised' (also under the legacy key 'coord_denosied') and
              'coord_hat', when ``reconstruct_position`` is set.

        Raises:
            ValueError: ``timestep < 1``, or ``return_traj`` without either
                ``return_frame`` or ``return_position``.
            NotImplementedError: ``return_traj`` with ``return_position`` but
                without ``reconstruct_position``, or ``num_res`` is None while
                the auxiliary heads are disabled.
        """
        if 'num_samples' in kwargs:
            logger.warning('`num_samples` is deprecated. Will be set to 1 in all cases.')
        num_samples = 1
        device = self.device

        # Validate arguments before any expensive work.
        if timestep < 1:
            raise ValueError(f'`timestep` must be >= 1, got {timestep}.')
        if return_traj:
            if not return_frame and not return_position:
                raise ValueError('return_traj is True but return_frame and return_position are both False.')
            if return_position and not reconstruct_position:
                raise NotImplementedError('return_position without reconstruct_position is not implemented yet.')
        record_frame = return_traj and return_frame
        record_position = return_traj and return_position

        if reconstruct_position:
            # Needed both for trajectory coordinates and for the final
            # reconstruction below, even when `return_traj` is False.
            self._prepare_helper()

        frame_noised_record, frame_hat_record = [], []
        coord_noised_record, coord_hat_record = [], []

        with torch.no_grad():
            latent = latent.to(device)[None].float()
            if num_res is None:
                aux_heads = getattr(self, 'aux_heads', None)
                if aux_heads is None:
                    raise NotImplementedError(
                        '`num_res` is not provided and the auxiliary heads needed to infer it are disabled.'
                    )
                # NOTE infer the length of latent
                length_pred = -aux_heads.infer_length(latent, scale=True)[:, 0]
                num_res = self._sample_num_res(length_pred, length_variance)

            feat = {'batch_idx': torch.zeros(num_samples, dtype = torch.long)}
            feat['latent_z'] = latent
            feat.update(self._init_feat(num_samples = num_samples, num_res = num_res, timestep = timestep))
            feat = tensor_tree_map(lambda x: x.to(device), feat)

            # forward sample: noise the identity frames to `timestep`
            frame_noised = self.diffuser.forward_sample_marginal(feat['frame_gt'], feat['timestep'], feat['frame_mask'], intype = 'tensor_4x4', outtype = 'tensor_7')

            if record_frame:
                frame_noised_record.append(frame_noised.detach().cpu())
            if record_position:
                coord_noised_record.append(self._reconstruct_position(frame_noised, feat['seq_type']).detach().cpu())

            is_first_tag = True
            prev = None

            # reverse diffusion with self-conditioning
            for _ in range(timestep):
                res_dict = self.sample_step(feat, frame_noised, prev, reverse_sample = True, inplace_safe = True, is_first = is_first_tag, intype = 'tensor_7', outtype = 'tensor_7')
                seq_emb, pair_emb = res_dict['seq_emb'], res_dict['pair_emb']
                frame_hat, frame_denoised = res_dict['frame_hat'], res_dict['frame_denoised']

                if self.config.Global.infer_no_recyc:
                    prev = None
                else:
                    # Recycle the embeddings plus the translation part of the predicted frames.
                    prev = [seq_emb, pair_emb, frame_hat[..., 4:]]
                frame_noised = frame_denoised
                feat['timestep'] -= 1
                is_first_tag = False

                if record_frame:
                    frame_hat_record.append(frame_hat.detach().cpu())
                    frame_noised_record.append(frame_denoised.detach().cpu())
                if record_position:
                    coord_hat_record.append(self._reconstruct_position(frame_hat, feat['seq_type']).detach().cpu())
                    coord_noised_record.append(self._reconstruct_position(frame_denoised, feat['seq_type']).detach().cpu())

                del res_dict, seq_emb, pair_emb

        result = {}
        result['frame_denoised'] = frame_noised[0].cpu()
        result['frame_hat'] = frame_hat[0].cpu()

        # Concatenate along the first dimension in reverse order (last step first).
        if return_traj:
            if return_frame:
                result['frame_noised_record'] = torch.cat(frame_noised_record[::-1], dim = 0)
                result['frame_hat_record'] = torch.cat(frame_hat_record[::-1], dim = 0)
            if return_position:
                result['coord_noised_record'] = torch.cat(coord_noised_record[::-1], dim = 0)
                result['coord_hat_record'] = torch.cat(coord_hat_record[::-1], dim = 0)
            # NOTE(review): frame_hat_record[k] was predicted while the feature
            # timestep was k + 1, but is labelled k here — confirm the intended convention.
            result['noised_timestep_record'] = torch.arange(timestep+1, dtype = torch.long, device = device)
            result['hat_timestep_record'] = torch.arange(timestep, dtype = torch.long, device = device)

        if reconstruct_position:
            with torch.no_grad():
                coord_denoised = self._reconstruct_position(frame_noised, feat['seq_type'])
                coord_hat = self._reconstruct_position(frame_hat, feat['seq_type'])
            result['coord_denoised'] = coord_denoised[0].cpu()
            # Legacy misspelled key, kept for existing callers.
            result['coord_denosied'] = result['coord_denoised']
            result['coord_hat'] = coord_hat[0].cpu()

        return result


    def _prepare_helper(self):
        """Create ``self.helper`` (a StructureModuleHelper on this module's device) once.

        NOTE(review): if StructureModuleHelper is an nn.Module, this assignment
        registers it as a submodule, so it appears in state_dict() after the
        first call — confirm that is acceptable.
        """
        if not hasattr(self, 'helper'):
            self.helper = StructureModuleHelper().to(self.device)
        return

    def encode_topology(self, feat: dict):
        """Encode topology features into a reparameterized latent.

        Args:
            feat:
                features for VAE encoder:
                    encoder_feats: [*, N_res, C_encoder]
                        The input features of the encoder
                    encoder_coords: [*, N_res, 3]
                        The input CA coordinates of the encoder
                    encoder_mask: [*, N_res]
                        The mask of the encoder input
                    encoder_adj_mat: [*, N_res, N_res]
                        The adjacency matrix of the encoder input

        Returns:
            dict with 'latent_mu', 'latent_logvar' (the raw second half of the
            encoder output) and 'latent_z', each [B, latent_dim].

        Raises:
            RuntimeError: if the model was built without a topology encoder.
        """
        if getattr(self, 'encoder', None) is None or not hasattr(self, 'encoder_config'):
            raise RuntimeError('Topology encoder is not configured (`config.Global.Encoder` is unset).')

        latent_dim = self.encoder_config.latent_dim

        # Gradients flow through the encoder only when it is configured as trainable.
        with torch.set_grad_enabled(self.encoder_config.trainable):
            # [B, latent_dim * 2]
            enc_output = self.encoder(feat)

            # reparameterization, each [B, latent_dim]
            latent_mu = enc_output[:, :latent_dim]
            latent_log_scale = enc_output[:, latent_dim:]
            # NOTE(review): the second half is exponentiated directly into a
            # standard deviation, yet it is returned below as `latent_logvar`;
            # for a true log-variance sigma would be exp(0.5 * logvar).
            # Behaviour kept unchanged — confirm against the KL term that uses it.
            latent_sigma = torch.exp(latent_log_scale) + self.encoder_config.eps
            latent_z = latent_mu + latent_sigma * torch.randn_like(latent_mu) * self.encoder_config.temperature

        res_dict = {}
        res_dict['latent_mu'] = latent_mu
        res_dict['latent_logvar'] = latent_log_scale
        res_dict['latent_z'] = latent_z

        return res_dict

    def infer_latent(self, feat: dict):
        """Run the auxiliary heads' latent inference.

        Args:
            feat:
                latent_z: [*, latent_dim]
                    The latent code

        Returns:
            res_dict as returned by ``self.aux_heads.infer_latent`` (expected
            to contain ``length_logits`` [*, logit_dim]).

        Raises:
            NotImplementedError: if the auxiliary heads are disabled.
        """
        aux_heads = getattr(self, 'aux_heads', None)
        if aux_heads is None:
            raise NotImplementedError('`infer_latent` requires the auxiliary heads, which are disabled in this model.')
        return aux_heads.infer_latent(feat)
        









