import torch
from torch import nn
from torch_scatter import scatter_add, scatter_mean
from torch_geometric.data import Data
import numpy as np
from numpy import pi as PI
from tqdm.auto import tqdm

from utils.chem import BOND_TYPES
from ..common import MultiLayerPerceptron, assemble_atom_pair_feature, generate_symmetric_edge_noise, extend_graph_order_radius
from ..encoder import SchNetEncoder, get_edge_encoder
from ..geometry import get_distance, get_angle, get_dihedral, convert_score_d


class ConformationEpsNetwork(nn.Module):
    """Score network over graph edges with a global and a local output head.

    Edges are embedded by an edge encoder, node features are extracted by a
    ``SchNetEncoder``, and two MLP heads map the assembled pairwise features
    to one scalar per edge (a score w.r.t. the edge length). The module also
    owns the noise schedule ``self.sigmas`` used for training (`get_loss`)
    and sampling (`langevin_dynamics_sample`).
    """

    def __init__(self, config):
        """Build the encoders, the two score heads and the noise schedule.

        Args:
            config: Configuration object. This constructor reads
                ``hidden_dim``, ``num_convs``, ``cutoff``, ``smooth_conv``,
                ``mlp_act``, ``sigma_begin``, ``sigma_end`` and
                ``num_noise_level`` and passes the whole object to
                ``get_edge_encoder``. Other methods additionally read
                ``edge_order`` and ``edge_encoder``.
        """
        super().__init__()
        self.config = config

        # Edge encoder: takes both edge type and edge length as input and
        # outputs a vector. Node embedding is done inside SchNetEncoder.
        self.edge_encoder = get_edge_encoder(config)

        # The graph neural network that extracts node-wise features.
        self.model = SchNetEncoder(
            hidden_channels=config.hidden_dim,
            num_filters=config.hidden_dim,
            num_interactions=config.num_convs,
            edge_channels=self.edge_encoder.out_channels,
            cutoff=config.cutoff,
            smooth=config.smooth_conv,
        )

        # Both heads take the assembled pairwise feature (width 2 * hidden_dim)
        # and output gradients w.r.t. edge_length (out_dim = 1).
        self.grad_global_dist_mlp = MultiLayerPerceptron(
            2 * config.hidden_dim,
            [config.hidden_dim, config.hidden_dim // 2, 1],
            activation=config.mlp_act
        )

        self.grad_local_dist_mlp = MultiLayerPerceptron(
            2 * config.hidden_dim,
            [config.hidden_dim, config.hidden_dim // 2, 1],
            activation=config.mlp_act
        )

        # Geometric sequence of noise levels running from sigma_begin to
        # sigma_end; stored as a non-trainable parameter.
        sigmas = torch.tensor(
            np.exp(np.linspace(np.log(config.sigma_begin), np.log(config.sigma_end),
                               config.num_noise_level)), dtype=torch.float32)
        self.sigmas = nn.Parameter(sigmas, requires_grad=False)  # (num_noise_level)

    def forward(self, atom_type, pos, bond_index, bond_type, batch, sigma_edge, edge_index=None, edge_type=None, edge_length=None, return_edges=False, extend_order=True):
        """Predict the global and local edge scores.

        Args:
            atom_type:  Types of atoms, (N, ).
            pos:        Node positions; only used to build the edge set and
                        edge lengths when those are not supplied.
            bond_index: Indices of bonds (not extended, not radius-graph), (2, E).
            bond_type:  Bond types, (E, ).
            batch:      Node index to graph index, (N, ).
            sigma_edge: Noise level that both head outputs are divided by.
                        `get_loss` passes an (E, 1) tensor, while
                        `langevin_dynamics_sample` passes one element of
                        ``self.sigmas``.
            edge_index, edge_type, edge_length: Optional pre-computed edges.
                        If any of the three is None, all three are rebuilt
                        from ``pos`` and the bond graph.
            return_edges: If True, also return the edges that were used.
            extend_order: Forwarded to ``extend_graph_order_radius``.

        Returns:
            ``(score_d_global, score_d_local)``, each (E, 1); followed by
            ``edge_index, edge_type, edge_length`` when ``return_edges`` is True.
        """
        N = atom_type.size(0)
        if edge_index is None or edge_type is None or edge_length is None:
            edge_index, edge_type = extend_graph_order_radius(
                num_nodes=N,
                pos=pos,
                edge_index=bond_index,
                edge_type=bond_type,
                batch=batch,
                order=self.config.edge_order,
                cutoff=self.config.cutoff,
                extend_order=extend_order,
            )
            edge_length = get_distance(pos, edge_index).unsqueeze(-1)   # (E, 1)
        local_edge_mask = is_local_edge(edge_type)  # (E, )

        # Encoding
        edge_attr = self.edge_encoder(
            edge_length=edge_length,
            edge_type=edge_type
        )   # Embed edges
        node_attr = self.model(
            z=atom_type,
            edge_index=edge_index,
            edge_length=edge_length,
            edge_attr=edge_attr,
        )

        # Assemble pairwise features
        h_pair = assemble_atom_pair_feature(
            node_attr=node_attr,
            edge_index=edge_index,
            edge_attr=edge_attr,
        )    # (E, 2H)

        # Score of edges (radius graph, global)
        score_d_global = self.grad_global_dist_mlp(h_pair) * (1.0 / sigma_edge)    # (E, 1)

        # Score of edges (bond graph, local); zeroed on non-local edges
        score_d_local = self.grad_local_dist_mlp(h_pair) * (1.0 / sigma_edge) * local_edge_mask.unsqueeze(-1) # (E, 1)

        if return_edges:
            return score_d_global, score_d_local, edge_index, edge_type, edge_length
        else:
            return score_d_global, score_d_local

    def get_loss(self, atom_type, pos, bond_index, bond_type, batch, num_nodes_per_graph, num_graphs, anneal_power=2.0, return_unreduced_loss=False):
        """Compute the per-graph denoising score-matching loss on edge lengths.

        One noise level is sampled per graph, edge lengths are perturbed with
        symmetric edge noise scaled by that level, and both heads are trained
        to predict ``(d_gt - d_perturbed) / sigma ** 2``.

        Args:
            atom_type:  Types of atoms, (N, ).
            pos:        Ground-truth node positions used to measure edge lengths.
            bond_index: Indices of bonds, (2, E).
            bond_type:  Bond types, (E, ).
            batch:      Node index to graph index, (N, ).
            num_nodes_per_graph: Forwarded to ``generate_symmetric_edge_noise``.
            num_graphs: Number of graphs in the batch (G); sets how many noise
                        levels are drawn and the length of the returned loss.
            anneal_power: Each edge term is weighted by ``sigma ** anneal_power``.
            return_unreduced_loss: If True, also return the two loss parts.

        Returns:
            ``loss`` of shape (G, ), or ``(loss, loss_global, loss_local)``
            when ``return_unreduced_loss`` is True.
        """
        N = atom_type.size(0)
        edge_index, edge_type = extend_graph_order_radius(
            num_nodes=N,
            pos=pos,
            edge_index=bond_index,
            edge_type=bond_type,
            batch=batch,
            order=self.config.edge_order,
            cutoff=self.config.cutoff * 1.5,    # "*1.5": learns to push points out of the boundary
        )
        edge_length = get_distance(pos, edge_index).unsqueeze(-1)   # (E, 1)

        node2graph = batch
        edge2graph = node2graph[edge_index[0]]
        local_edge_mask = is_local_edge(edge_type)  # (E, )

        # Sample noise levels (sigmas)
        noise_levels = self.sigmas[torch.randint(0, self.sigmas.size(0), size=(num_graphs, ))]  # (G, )
        sigmas_edge = noise_levels[edge2graph].unsqueeze(-1)  # (E, 1)

        # Perturb edge (global)
        d_gt = edge_length
        d_noise = generate_symmetric_edge_noise(num_nodes_per_graph, edge_index, edge2graph, edge_length.device)
        d_perturbed = d_gt + d_noise * sigmas_edge
        if self.config.edge_encoder == 'gaussian':
            # Distances must be greater than 0: take the absolute value and floor it at 0.01
            d_sgn = torch.sign(d_perturbed)
            d_perturbed = torch.clamp(d_perturbed * d_sgn, min=0.01, max=float('inf'))
        d_target = 1. / (sigmas_edge ** 2) * (d_gt - d_perturbed)   # (E, 1), denoising direction

        # Estimate scores
        d_score_global, d_score_local = self(
            atom_type = atom_type,
            pos = pos,
            bond_index = None,
            bond_type = None,
            batch = batch,
            sigma_edge = sigmas_edge,
            edge_index = edge_index,    # `edge_index` is pre-computed, so `bond_index` will not be used
            edge_type = edge_type,      # `edge_type` is pre-computed, so `bond_type` will not be used
            edge_length = d_perturbed,  # Input the perturbed distances
            return_edges=False,
        )

        # Loss for edge score
        loss_d_global = 0.5 * ((d_score_global - d_target) ** 2) * (sigmas_edge ** anneal_power)
        # Only edges that are local, or whose perturbed length is within the
        # cutoff, contribute to the global loss.
        loss_d_global = torch.where(
            torch.logical_or(d_perturbed <= self.config.cutoff, local_edge_mask.unsqueeze(-1)),
            loss_d_global,
            torch.zeros_like(loss_d_global)
        )

        loss_d_local = 0.5 * ((d_score_local - d_target) ** 2) * (sigmas_edge ** anneal_power) * local_edge_mask.unsqueeze(-1)

        # `squeeze(-1)` keeps the edge axis even when E == 1, and `dim_size`
        # guarantees one entry per graph even if trailing graphs have no edges.
        loss_global = scatter_add(loss_d_global.squeeze(-1), edge2graph, dim=0, dim_size=num_graphs)  # (G, )
        loss_local = scatter_add(loss_d_local.squeeze(-1), edge2graph, dim=0, dim_size=num_graphs)    # (G, )
        loss = loss_global + loss_local

        if return_unreduced_loss:
            return loss, loss_global, loss_local
        else:
            return loss

    def langevin_dynamics_sample(self, atom_type, pos_init, bond_index, bond_type, batch, extend_order, n_steps_each=100, step_lr=0.0000024, clip=1e6, min_sigma=0, w_radius=1.0, w_reg=1.0):
        """Sample positions by iterating over the noise schedule.

        For every noise level in ``self.sigmas`` (in stored order),
        ``n_steps_each`` updates of the form
        ``pos + step_size * score + sqrt(2 * step_size) * noise`` are applied,
        with ``step_size = step_lr * (sigma / sigmas[-1]) ** 2``. Runs under
        ``torch.no_grad()``.

        Args:
            atom_type:  Types of atoms, (N, ).
            pos_init:   Starting positions.
            bond_index: Indices of bonds, (2, E).
            bond_type:  Bond types, (E, ).
            batch:      Node index to graph index, (N, ).
            extend_order: Forwarded to `forward`.
            n_steps_each: Number of update steps per noise level.
            step_lr:    Base step size.
            clip:       Upper bound on the norm of each positional score row.
            min_sigma:  Sampling stops at the first noise level below this value.
            w_radius:   Weight forwarded to ``reweight_score_d``.
            w_reg:      Weight of the bond-length regularisation term.

        Returns:
            ``(pos, pos_traj)``: the final centered positions and a list of
            CPU copies of the positions after every step. The recorded
            positions are taken before centering.
        """
        sigmas = self.sigmas
        pos_traj = []

        with torch.no_grad():
            pos = pos_init
            for sigma in tqdm(sigmas, desc='sample'):
                if sigma < min_sigma:
                    break
                step_size = step_lr * (sigma / sigmas[-1]) ** 2
                for _ in range(n_steps_each):
                    score_d_global, score_d_local, edge_index, edge_type, edge_length = self(
                        atom_type=atom_type,
                        pos=pos,
                        bond_index=bond_index,
                        bond_type=bond_type,
                        batch=batch,
                        sigma_edge=sigma,
                        return_edges=True,
                        extend_order=extend_order,
                    )
                    reg_d_bond = regularize_bond_length(edge_type, edge_length)
                    # Local edges use the local head, all other edges the global head
                    mask = is_local_edge(edge_type).unsqueeze(-1).float()
                    score_d = score_d_local * mask + score_d_global * (1 - mask) + reg_d_bond * w_reg
                    score_d = reweight_score_d(score_d, edge_type=edge_type, edge_length=edge_length, w_radius=w_radius)
                    score_pos = convert_score_d(score_d, pos, edge_index, edge_length)
                    score_pos = clip_norm(score_pos, clip)
                    noise = torch.randn_like(pos) * torch.sqrt(step_size*2)
                    pos_next = pos + step_size * score_pos + noise
                    pos = center_pos(pos_next, batch)
                    pos_traj.append(pos_next.clone().cpu())

        return pos, pos_traj



def is_bond(edge_type):
    """Return a boolean mask of edge types strictly between 0 and ``len(BOND_TYPES)``."""
    below_bond_limit = edge_type < len(BOND_TYPES)
    above_zero = edge_type > 0
    return torch.logical_and(below_bond_limit, above_zero)


def is_angle_edge(edge_type):
    """Return a boolean mask of edges whose type equals ``len(BOND_TYPES)``."""
    angle_type = len(BOND_TYPES)
    return edge_type == angle_type


def is_dihedral_edge(edge_type):
    """Return a boolean mask of edges whose type equals ``len(BOND_TYPES) + 1``."""
    dihedral_type = len(BOND_TYPES) + 1
    return edge_type == dihedral_type


def is_radius_edge(edge_type):
    """Return a boolean mask of edges whose type is exactly 0."""
    radius_mask = edge_type == 0
    return radius_mask


def is_local_edge(edge_type):
    """Return a boolean mask of edges whose type is greater than 0."""
    local_mask = edge_type > 0
    return local_mask


def regularize_bond_length(edge_type, edge_length, rng=5.0):
    """Return a non-positive penalty for bond edges longer than ``rng``.

    For edges selected by ``is_bond`` the result is ``-(edge_length - rng)``
    wherever the length exceeds ``rng``; it is zero for every other edge.
    """
    bond_column = is_bond(edge_type).float().reshape(-1, 1)
    # Only the part of the length beyond `rng` is penalised
    overshoot = (edge_length - rng).clamp(min=0.0)
    return -overshoot * bond_column


def reweight_score_d(score_d, edge_type, edge_length, w_bond=1.0, w_angle=1.0, w_dihedral=1.0, w_radius=1.0, cutoff=10.0):
    """Scale edge scores by a weight that depends on the edge type.

    Bond, angle and dihedral edges are multiplied by ``w_bond``, ``w_angle``
    and ``w_dihedral``. Radius edges are multiplied by ``w_radius`` times a
    cosine envelope ``0.5 * (cos(pi * d / cutoff) + 1)`` that is set to zero
    for lengths outside ``[0, cutoff]``.
    """
    def _as_column(flags):
        # Boolean mask -> float column vector
        return flags.float().reshape(-1, 1)

    bond_col = _as_column(is_bond(edge_type))
    angle_col = _as_column(is_angle_edge(edge_type))
    dihedral_col = _as_column(is_dihedral_edge(edge_type))
    radius_col = _as_column(is_radius_edge(edge_type))

    # Each factor is 1 everywhere except on the edges of its own category
    bond_factor = (1 - bond_col) + bond_col * w_bond
    angle_factor = (1 - angle_col) + angle_col * w_angle
    dihedral_factor = (1 - dihedral_col) + dihedral_col * w_dihedral

    envelope = 0.5 * (torch.cos(edge_length * PI / cutoff) + 1.0)
    within_cutoff = edge_length <= cutoff
    non_negative = edge_length >= 0.0
    envelope = envelope * within_cutoff * non_negative     # Modification: cutoff
    radius_factor = (1 - radius_col) + radius_col * envelope.reshape(-1, 1) * w_radius

    return score_d * bond_factor * angle_factor * dihedral_factor * radius_factor
    

def center_pos(pos, batch):
    """Subtract from every row of ``pos`` the mean of the rows sharing its ``batch`` index."""
    group_means = scatter_mean(pos, batch, dim=0)
    return pos - group_means[batch]


def clip_norm(vec, limit, p=2):
    """Rescale vectors whose p-norm exceeds ``limit`` so the norm equals ``limit``.

    The norm is taken along the last dimension of ``vec``. Vectors whose norm
    does not exceed ``limit`` are returned unchanged.

    Args:
        vec: Tensor of vectors to clip.
        limit: Maximum allowed norm.
        p: Order of the norm (defaults to 2).

    Returns:
        A tensor with the same shape as ``vec``.
    """
    norm = torch.norm(vec, dim=-1, p=p, keepdim=True)
    # Entries within the limit get a factor of exactly 1
    scale = torch.where(norm > limit, limit / norm, torch.ones_like(norm))
    return vec * scale
