import os
import torch
from torch import nn
import numpy as np

from TopoDiff.utils import so3_utils as so3

#. utils
from TopoDiff.utils.debug import log_var

class SO3Diffuser(nn.Module):
    """Diffusion process on per-residue orientations (rotation matrices).

    Noise is injected by right-multiplying each rotation matrix with a random
    rotation built from an axis-angle vector: the axis is a normalized Gaussian
    draw and the angle ``omega`` is drawn by inverse-CDF lookup in cached
    IGSO3 tables (see ``_load_cache`` / ``sample_omega_np``). The noise scale
    per time step comes from ``sigma_schedule`` (marginal, 0 -> t) or
    ``sigma_interval_schedule`` (single step, t-1 -> t).

    Config fields read here: ``reverse_strategy``, ``cache_dir``, ``suffix``,
    ``schedule``, ``sigma_1``, ``sigma_T``, ``T`` and, for the
    ``'hat_and_noise'`` strategy, ``hat_and_noise.noise_scale_schedule`` and
    ``hat_and_noise.noise_scale``.
    """

    def __init__(self, config, log = False, depth = 0, debug = False):
        super().__init__()

        self.config = config
        self.log = log      # when True, _log forwards to log_var
        self.depth = depth  # passed through to log_var as `depth`
        self.debug = debug  # when True, sampling methods also return the noise rotation

        self.reverse_strategy = self.config.reverse_strategy
        self.reverse_parm = self.config[self.reverse_strategy]

        self._load_cache()
        self._precompute_time_schedule()

    def _log(self, text, tensor = 'None', additional = None):
        """Forward `text` / `tensor` to log_var when logging is enabled."""
        if self.log:
            log_var(text, tensor, depth = self.depth, additional = additional)

    # ------------------------------------------------------------------ #
    # Shared helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _resolve_mask(rot_mask, rot):
        """Return a boolean [*, N_res] mask on `rot`'s device.

        A missing mask means every residue is updated. Non-bool masks are
        converted with ``.bool()`` semantics (non-zero -> True). Without this,
        ``~`` on an integer mask is a bitwise NOT (``~1 == -2``) and would
        silently corrupt the masked composition.
        """
        if rot_mask is None:
            return torch.ones(rot.shape[:-2], dtype=torch.bool, device=rot.device)
        return rot_mask.to(device=rot.device, dtype=torch.bool)

    @staticmethod
    def _lookup(schedule, t):
        """Index a per-time-step schedule with `t` (tensor or int).

        In torch 2.0, indexing a CPU schedule with a CUDA index tensor fails, so
        `t` is first moved to the schedule's device.
        """
        return schedule[torch.as_tensor(t, device=schedule.device)]

    @staticmethod
    def _compose(rot_keep, rot_new, rot_mask):
        """Take `rot_new` where `rot_mask` is True and `rot_keep` elsewhere.

        ``torch.where`` replaces ``a * ~m + b * m`` so that non-finite values in
        the unselected operand (e.g. padded residues) cannot leak into the
        result through ``nan * 0``.
        """
        return torch.where(rot_mask[..., None, None], rot_new, rot_keep)

    def _sample_noise_vec(self, sigma, rot):
        """Sample axis-angle noise for every residue of `rot`.

        Args:
            sigma (torch.Tensor): [*,] noise scale per sample.
            rot (torch.Tensor): [*, N_res, 3, 3] used for its shape and device.

        Returns:
            torch.Tensor: [*, N_res, 3] noise vectors on ``rot.device``.
        """
        #. [*,] -> [*, N_res]
        sigma = sigma[..., None].expand(rot.shape[:-2])
        #. [*, N_res, 3]
        noise_vec = self.sample_vec(sigma).to(rot.device)
        self._log('noise_vec', noise_vec, str(noise_vec.requires_grad) + str(noise_vec.is_leaf) + str(noise_vec.device))  # DEBUG1
        return noise_vec

    def _perturb(self, rot, schedule, t, rot_mask):
        """Right-multiply masked rotations by noise of scale ``schedule[t]``.

        Returns:
            (rot_t, noise_rot): both [*, N_res, 3, 3].
        """
        rot_mask = self._resolve_mask(rot_mask, rot)
        #. [*,]
        sigma = self._lookup(schedule, t)
        #. [*, N_res, 3, 3]  (so3_Exp maps the [*, N_res, 3] vectors to matrices)
        noise_rot = so3.so3_Exp(self._sample_noise_vec(sigma, rot))
        #. [*, N_res, 3, 3]
        rot_t = self._compose(rot, rot @ noise_rot, rot_mask)
        return rot_t, noise_rot

    # ------------------------------------------------------------------ #
    # Forward / reverse processes
    # ------------------------------------------------------------------ #
    def forward_sample(self, rot, t, rot_mask = None):
        """Sample from the one-step forward distribution, t-1 -> t.

        Args:
            rot (torch.Tensor): [*, N_res, 3, 3]
                The orientation in rotation matrix form at time_step = t - 1.
            t (torch.Tensor or int): [*,]
                The current time step (moved to the schedule's device for indexing).
            rot_mask (torch.Tensor, optional): [*, N_res]
                Residues to perturb (converted to bool). Defaults to all residues.

        Returns:
            torch.Tensor: [*, N_res, 3, 3]
                The orientation in rotation matrix form at time_step = t, or
                ``(rot_t, noise_rot)`` when ``self.debug`` is set.
        """
        self._log('In SO3Diffuser.forward_sample: start')  # DEBUG1
        self._log('rot', rot)  # DEBUG1
        self._log('t', t)  # DEBUG1
        self._log('rot_mask', rot_mask)  # DEBUG1

        rot_t, noise_rot = self._perturb(rot, self.sigma_interval_schedule, t, rot_mask)

        if self.debug:
            return rot_t, noise_rot
        return rot_t

    def forward_sample_marginal(self, rot, t, rot_mask = None):
        """Sample from the marginal forward distribution, 0 -> t.

        Args:
            rot (torch.Tensor): [*, N_res, 3, 3]
                The orientation in rotation matrix form at time_step = 0.
            t (torch.Tensor or int): [*,]
                The current time step (moved to the schedule's device for indexing).
            rot_mask (torch.Tensor, optional): [*, N_res]
                Residues to perturb (converted to bool). Defaults to all residues.

        Returns:
            torch.Tensor: [*, N_res, 3, 3]
                The orientation in rotation matrix form at time_step = t, or
                ``(rot_t, noise_rot)`` when ``self.debug`` is set.
        """
        self._log('In SO3Diffuser.forward_sample_marginal: start')  # DEBUG1
        self._log('rot', rot)  # DEBUG1
        self._log('t', t)  # DEBUG1
        self._log('rot_mask', rot_mask)  # DEBUG1

        rot_t, noise_rot = self._perturb(rot, self.sigma_schedule, t, rot_mask)

        if self.debug:
            return rot_t, noise_rot
        return rot_t

    def sample_from_noise(self, rot, rot_mask):
        """Sample orientations from the terminal noise distribution (t = T).

        Args:
            rot (torch.Tensor or None): [*, N_res, 3, 3]
                The orientation in rotation matrix form at time_step = 0. When
                None, identity rotations are used.
            rot_mask (torch.Tensor): [*, N_res]
                The mask of the rotation matrix (converted to bool).

        Returns:
            torch.Tensor: [*, N_res, 3, 3]
                The orientation in rotation matrix form at time_step = T, or
                ``(rot_t, noise_rot)`` when ``self.debug`` is set.
        """
        self._log('In SO3Diffuser.sample_from_noise: start')
        self._log('rot', rot)
        self._log('rot_mask', rot_mask)

        rot_mask = rot_mask.bool()

        if rot is None:
            #. [*, N_res, 3, 3] identity rotations
            rot = torch.eye(3, device=rot_mask.device).expand(*rot_mask.shape, 3, 3).clone()

        #. [*,] every sample is pushed to the last time step
        t = torch.full(rot_mask.shape[:-1], self.config.T, dtype=torch.long)

        return self.forward_sample_marginal(rot, t, rot_mask)

    def reverse_sample(self, rot_t, rot_0_hat, t, rot_mask = None):
        """Sample from the reverse distribution, t -> t-1.

        This implementation is similar to 'Antigen-Specific Antibody Design and
        Optimization with Diffusion-Based Generative Models for Protein
        Structures': each new orientation is the predicted `rot_0_hat` with
        additional noise of scale
        ``sigma_interval_schedule[t] * reverse_noise_schedule[t]``. No noise is
        added on the final step (t == 1).

        Args:
            rot_t (torch.Tensor): [*, N_res, 3, 3]
                The orientation in rotation matrix form at time_step = t.
                Residues outside `rot_mask` keep this value.
            rot_0_hat (torch.Tensor): [*, N_res, 3, 3]
                The estimated orientation in rotation matrix form at time_step = 0.
            t (torch.Tensor or int): [*,]
                The current time step.
            rot_mask (torch.Tensor, optional): [*, N_res]
                Residues to update (converted to bool). Defaults to all residues.

        Returns:
            torch.Tensor: [*, N_res, 3, 3]
                The orientation in rotation matrix form at time_step = t-1, or
                ``(rot, noise_rot)`` when ``self.debug`` is set.

        Raises:
            NotImplementedError: if ``reverse_strategy`` is not 'hat_and_noise'.
        """
        self._log('In SO3Diffuser.reverse_sample: start')  # DEBUG1
        self._log('rot_t', rot_t)  # DEBUG1
        self._log('rot_0_hat', rot_0_hat)  # DEBUG1
        self._log('t', t)  # DEBUG1
        self._log('rot_mask', rot_mask)  # DEBUG1

        if self.reverse_strategy != 'hat_and_noise':
            raise NotImplementedError('Unsupported reverse strategy: %s' % self.reverse_strategy)

        # Default mask is created on rot_t's device (previously always on CPU).
        rot_mask = self._resolve_mask(rot_mask, rot_t)

        #. [*,]  (torch 2.0: index tensors must live on the schedule's device)
        t_idx = torch.as_tensor(t, device=self.sigma_interval_schedule.device)
        sigma = self.sigma_interval_schedule[t_idx] * self.reverse_noise_schedule[t_idx]
        #, no additional noise for the final step t = 1 -> t = 0
        sigma = sigma.masked_fill(t_idx == 1, 0.)

        #. [*, N_res, 3, 3]
        noise_rot = so3.so3_Exp(self._sample_noise_vec(sigma, rot_t))

        #. [*, N_res, 3, 3]
        rot = self._compose(rot_t, rot_0_hat @ noise_rot, rot_mask)

        if self.debug:
            return rot, noise_rot
        return rot

    # ------------------------------------------------------------------ #
    # Cache / schedules
    # ------------------------------------------------------------------ #
    def _load_cache(self):
        """Load the precomputed omega / sigma / CDF tables from ``config.cache_dir``.

        A zero entry is prepended to ``omega_cache`` and ``sigma_cache``.
        ``cdf_cache`` gains a leading column of zeros (CDF at omega = 0) and a
        leading row of ones, which is the row selected for sigma == 0 and maps
        every uniform draw to omega = 0.
        """
        print('Loading cache from %s' % (self.config.cache_dir))

        cache_dir = self.config.cache_dir
        suffix = self.config.suffix
        self.omega_cache = np.concatenate([np.array([0]), np.load(os.path.join(cache_dir, 'omega_ary.npy'))])
        self.sigma_cache = np.concatenate([np.array([0]), np.load(os.path.join(cache_dir, 'sigma_ary%s.npy' % (suffix)))])
        cdf = np.load(os.path.join(cache_dir, 'cdf_ary%s.npy' % (suffix)))

        n_sigma, n_omega = cdf.shape
        cdf = np.concatenate([np.zeros((n_sigma, 1)), cdf], axis = 1)
        self.cdf_cache = np.concatenate([np.ones((1, n_omega + 1)), cdf], axis = 0)

    def _precompute_time_schedule(self):
        """Build per-time-step noise scales of length T + 1 (index 0 holds 0).

        sigma_schedule[t]          : scale of the marginal 0 -> t process.
        sigma_interval_schedule[t] : one-step scale sqrt(sigma_t^2 - sigma_{t-1}^2).
        reverse_noise_schedule[t]  : ('hat_and_noise' only) multiplier applied
                                     to the one-step scale in ``reverse_sample``.

        Raises:
            NotImplementedError: for an unknown ``schedule`` or
                ``hat_and_noise.noise_scale_schedule``.
        """
        T = self.config.T
        sigma_1, sigma_T = self.config.sigma_1, self.config.sigma_T

        if self.config.schedule == 'linear':
            sigmas = torch.linspace(sigma_1, sigma_T, T)
        elif self.config.schedule == 'log':
            sigmas = torch.exp(torch.linspace(np.log(sigma_1), np.log(sigma_T), T))
        elif self.config.schedule == 'exp':
            sigmas = torch.log(torch.linspace(np.exp(sigma_1), np.exp(sigma_T), T))
        else:
            raise NotImplementedError('Unsupported sigma schedule: %s' % self.config.schedule)

        self.sigma_schedule = torch.cat([torch.tensor([0.]), sigmas])
        self.sigma_interval_schedule = torch.cat([torch.tensor([0.]), torch.sqrt(self.sigma_schedule[1:]**2 - self.sigma_schedule[:-1]**2)])

        if self.config.reverse_strategy == 'hat_and_noise':
            hat_cfg = self.config.hat_and_noise
            if hat_cfg.noise_scale_schedule == 'linear':
                self.reverse_noise_schedule = torch.cat([torch.tensor([0.]), torch.linspace(hat_cfg.noise_scale[0], hat_cfg.noise_scale[1], T)])
            else:
                raise NotImplementedError('Unsupported noise scale schedule: %s' % hat_cfg.noise_scale_schedule)

    # ------------------------------------------------------------------ #
    # IGSO3 sampling from the cached tables
    # ------------------------------------------------------------------ #
    def _sigma_index(self, sigma):
        """Map every entry of `sigma` (ndarray) to a row of ``cdf_cache``.

        The row is the number of cached sigmas strictly below the query (the
        first cached sigma >= query when ``sigma_cache`` is ascending), clipped
        to the last row for out-of-range queries.

        Returns:
            ndarray: [sigma.size] flattened integer row indices.
        """
        #. sigma_cache -> [1, n_sigma], sigma -> [n_sample, 1]
        below = self.sigma_cache[None] < sigma.reshape(-1)[..., None]
        #. [n_sample, n_sigma] -> [n_sample,]
        return below.sum(axis=-1).clip(max = len(self.sigma_cache) - 1)

    def _get_cdf_np(self, sigma):
        """Return the cached CDF row for every sigma.

        Args:
            sigma: scalar or array-like of arbitrary shape.

        Returns:
            ndarray: [*sigma.shape, cdf_cache.shape[1]]
        """
        sigma = np.asarray(sigma)
        n_omega = self.cdf_cache.shape[1]
        # Explicit trailing size (not -1) keeps empty inputs reshapeable.
        return self.cdf_cache[self._sigma_index(sigma)].reshape(*sigma.shape, n_omega)

    def sample_omega_np(self, sigma):
        """Sample rotation angles omega by inverse-CDF lookup in the cached tables.

        For every sigma, one uniform number u ~ U[0, 1) is drawn and mapped
        through the selected ``cdf_cache`` row onto ``omega_cache`` with
        ``np.interp``.

        Args:
            sigma: scalar or array-like of arbitrary shape; the current noise scale.

        Returns:
            ndarray: float64 angles with the shape of `sigma`. sigma == 0
            selects the prepended all-ones row and yields omega = 0 (assuming
            the loaded sigma table is positive).
        """
        sigma = np.asarray(sigma)

        #. [n_sample,]
        rows = self._sigma_index(sigma)
        #. [n_sample,] one uniform draw per sample
        u = np.random.rand(*rows.shape)

        #. [n_sample,] -> [*sigma.shape]
        omega = np.array(
            [np.interp(u_i, self.cdf_cache[r], self.omega_cache) for u_i, r in zip(u, rows)],
            dtype = np.float64,
        )
        return omega.reshape(sigma.shape)

    def sample_vec_np(self, sigma):
        """Sample axis-angle rotation vectors with omega drawn from the IGSO3 tables.

        The axis is a normalized standard-normal draw, and its length is the
        angle from ``sample_omega_np``.

        Args:
            sigma: scalar or array-like of arbitrary shape; the current noise scale.

        Returns:
            ndarray: [*sigma.shape, 3]
        """
        sigma = np.asarray(sigma)

        #. [*sigma.shape, 3] unit directions
        direction = np.random.randn(*sigma.shape, 3)
        direction /= np.linalg.norm(direction, axis=-1, keepdims=True)

        #. [*sigma.shape, 1]
        omega = self.sample_omega_np(sigma)[..., None]
        return direction * omega

    def sample_omega(self, sigma):
        """Torch wrapper around ``sample_omega_np``.

        Args:
            sigma: ndarray, tensor or scalar of arbitrary shape (tensors are
                detached and copied to CPU).

        Returns:
            torch.Tensor: float32 angles with the shape of `sigma` (on CPU).
        """
        if isinstance(sigma, torch.Tensor):
            sigma = sigma.detach().cpu().numpy()
        return torch.from_numpy(self.sample_omega_np(sigma)).float()

    def sample_vec(self, sigma):
        """Torch wrapper around ``sample_vec_np``.

        Args:
            sigma: ndarray, tensor or scalar of arbitrary shape (tensors are
                detached and copied to CPU).

        Returns:
            torch.Tensor: float32 [*sigma.shape, 3] on CPU.
        """
        if isinstance(sigma, torch.Tensor):
            sigma = sigma.detach().cpu().numpy()
        return torch.from_numpy(self.sample_vec_np(sigma)).float()