'''
 *
 *     ICTP: Irreducible Cartesian Tensor Potentials
 *
 *        File:  representations.py
 *
 *     Authors: Deleted for purposes of anonymity 
 *
 *     Proprietor: Deleted for purposes of anonymity --- PROPRIETARY INFORMATION
 * 
 * The software and its source code contain valuable trade secrets and shall be maintained in
 * confidence and treated as confidential information. The software may only be used for 
 * evaluation and/or testing purposes, unless otherwise explicitly stated in the terms of a
 * license agreement or nondisclosure agreement with the proprietor of the software. 
 * Any unauthorized publication, transfer to third parties, or duplication of the object or
 * source code---either totally or in part---is strictly prohibited.
 *
 *     Copyright (c) 2024 Proprietor: Deleted for purposes of anonymity
 *     All Rights Reserved.
 *
 * THE PROPRIETOR DISCLAIMS ALL WARRANTIES, EITHER EXPRESS OR 
 * IMPLIED, INCLUDING BUT NOT LIMITED TO IMPLIED WARRANTIES OF MERCHANTABILITY 
 * AND FITNESS FOR A PARTICULAR PURPOSE AND THE WARRANTY AGAINST LATENT 
 * DEFECTS, WITH RESPECT TO THE PROGRAM AND ANY ACCOMPANYING DOCUMENTATION. 
 * 
 * NO LIABILITY FOR CONSEQUENTIAL DAMAGES:
 * IN NO EVENT SHALL THE PROPRIETOR OR ANY OF ITS SUBSIDIARIES BE 
 * LIABLE FOR ANY DAMAGES WHATSOEVER (INCLUDING, WITHOUT LIMITATION, DAMAGES
 * FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF INFORMATION, OR
 * OTHER PECUNIARY LOSS AND INDIRECT, CONSEQUENTIAL, INCIDENTAL,
 * ECONOMIC OR PUNITIVE DAMAGES) ARISING OUT OF THE USE OF OR INABILITY
 * TO USE THIS PROGRAM, EVEN IF the proprietor HAS BEEN ADVISED OF
 * THE POSSIBILITY OF SUCH DAMAGES.
 * 
 * For purposes of anonymity, the identity of the proprietor is not given herewith. 
 * The identity of the proprietor will be given once the review of the 
 * conference submission is completed. 
 *
 * THIS HEADER MAY NOT BE EXTRACTED OR MODIFIED IN ANY WAY.
 *
'''
from typing import List, Any

import torch
import torch.nn as nn

from src.o3.linear_transform import LinearTransform
from src.o3.cartesian_harmonics import CartesianHarmonics
from src.nn.layers import RadialEmbeddingLayer, RealAgnosticResidualInteractionLayer, ProductBasisLayer
from src.utils.torch_geometric import Data


class CartesianMACE(nn.Module):
    """MACE (semi-)local atomic representation in Cartesian basis.

    Args:
        r_cutoff (float): Cutoff radius.
        n_basis (int): Number of radial basis functions.
        n_polynomial_cutoff (int): Polynomial order for the cutoff function.
        n_species (int): Number of atomic species.
        n_hidden_feats (int): Number of hidden features.
        n_product_feats (int): Number of product basis features.
        coupled_product_feats (bool): If True, use mixed features when computing the product basis.
        symmetric_product (bool): If True, exploit symmetry of the tensor product to reduce
                                  the number of possible tensor contractions.
        l_max_hidden_feats (int): Maximal rotational order/rank of the Cartesian tensor for the hidden features.
        l_max_edge_attrs (int): Maximal rotational order/rank of the Cartesian tensor for the Cartesian harmonics.
        avg_n_neighbors (float): Average number of neighbors. It is used to normalize messages.
        correlation (int): Correlation order, i.e., number of contracted Cartesian tensors.
        n_interactions (int): Number of interaction layers; must be at least 1.
        radial_MLP (List[int]): List of hidden features for the radial embedding network.
        **config (Any): Additional keyword arguments; accepted so that a full configuration
                        can be passed through, and otherwise ignored.

    Raises:
        ValueError: If ``n_interactions`` is smaller than 1.
    """
    def __init__(self,
                 r_cutoff: float,
                 n_basis: int,
                 n_polynomial_cutoff: int,
                 n_species: int,
                 n_hidden_feats: int,
                 n_product_feats: int,
                 coupled_product_feats: bool,
                 symmetric_product: bool,
                 l_max_hidden_feats: int,
                 l_max_edge_attrs: int,
                 avg_n_neighbors: float,
                 correlation: int,
                 n_interactions: int,
                 radial_MLP: List[int],
                 **config: Any):
        super().__init__()
        if n_interactions < 1:
            # Previously a single layer was silently built in this case, which does not
            # match the requested configuration.
            raise ValueError(f'n_interactions must be at least 1, got {n_interactions}.')

        # Used in ``forward`` to slice the per-layer output features.
        self.n_hidden_feats = n_hidden_feats

        # One-hot-like species attributes -> rank-0 (l_max = 0) hidden node features.
        self.node_embedding = LinearTransform(in_l_max=0, out_l_max=0, in_features=n_species, out_features=n_hidden_feats)

        self.radial_embedding = RadialEmbeddingLayer(r_cutoff=r_cutoff, n_basis=n_basis, n_polynomial_cutoff=n_polynomial_cutoff)

        self.cartesian_harmonics = CartesianHarmonics(l_max_edge_attrs)

        self.interactions = torch.nn.ModuleList()
        self.products = torch.nn.ModuleList()
        for i in range(n_interactions):
            # The first layer consumes the rank-0 node embedding; later layers consume
            # hidden features of rank up to ``l_max_hidden_feats``.
            l_max_in_feats = 0 if i == 0 else l_max_hidden_feats
            # The last layer only produces rank-0 features.
            l_max_out_feats = 0 if i == n_interactions - 1 else l_max_hidden_feats

            # Modules are created in interleaved order (interaction, then product) so that
            # parameter initialization consumes the RNG in the same order as before.
            self.interactions.append(
                RealAgnosticResidualInteractionLayer(l_max_node_feats=l_max_in_feats, l_max_edge_attrs=l_max_edge_attrs,
                                                     l_max_target_feats=l_max_edge_attrs, l_max_hidden_feats=l_max_out_feats,
                                                     n_basis=n_basis, n_species=n_species, in_features=n_hidden_feats,
                                                     out_features=n_product_feats, avg_n_neighbors=avg_n_neighbors,
                                                     radial_MLP=radial_MLP))
            self.products.append(
                ProductBasisLayer(l_max_node_feats=l_max_edge_attrs, l_max_target_feats=l_max_out_feats,
                                  in_features=n_product_feats, out_features=n_hidden_feats, n_species=n_species,
                                  correlation=correlation, coupled_feats=coupled_product_feats,
                                  symmetric_product=symmetric_product, use_sc=True))

    def forward(self, graph: Data) -> List[torch.Tensor]:
        """Computes node features after every interaction/product layer.

        Args:
            graph (Data): Atomic graph data. Must provide ``edge_index``, ``positions``,
                          ``shifts`` and ``node_attrs``.

        Returns:
            List[torch.Tensor]: One tensor per interaction layer, each holding the first
                                ``n_hidden_feats`` feature columns of that layer's node features.
        """
        idx_i, idx_j = graph.edge_index[0, :], graph.edge_index[1, :]
        # Edge vectors: position at idx_i minus position at idx_j, minus the periodic shift.
        # NOTE(review): the sign convention of ``shifts`` must match how the neighbor list
        # builds them — confirm against the data pipeline.
        vectors = graph.positions.index_select(0, idx_i) - graph.positions.index_select(0, idx_j) - graph.shifts
        lengths = torch.norm(vectors, dim=-1, keepdim=True)

        node_feats = self.node_embedding(graph.node_attrs)
        edge_feats = self.radial_embedding(lengths)
        edge_attrs = self.cartesian_harmonics(vectors)  # vectors are normalized when computing Cartesian harmonics

        node_feats_list = []
        for interaction, product in zip(self.interactions, self.products):
            node_feats, sc = interaction(node_attrs=graph.node_attrs, node_feats=node_feats,
                                         edge_attrs=edge_attrs, edge_feats=edge_feats,
                                         idx_i=idx_i, idx_j=idx_j)

            node_feats = product(node_feats=node_feats, sc=sc, node_attrs=graph.node_attrs)

            # Keep only the leading ``n_hidden_feats`` columns — presumably the rank-0
            # (scalar) block of the Cartesian feature layout; TODO confirm.
            node_feats_list.append(node_feats[:, :self.n_hidden_feats])

        return node_feats_list
