# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT

import logging
from typing import Optional, Literal, Dict

import torch
import torch.nn as nn
from dgl import DGLGraph
from torch import Tensor

from se3_transformer.model.basis import get_basis, update_basis_with_fused
from se3_transformer.model.layers.attention import AttentionBlockSE3
from se3_transformer.model.layers.convolution import ConvSE3, ConvSE3FuseLevel
from se3_transformer.model.layers.linear import LinearSE3
from se3_transformer.model.layers.norm import NormSE3
from se3_transformer.model.layers.pooling import GPooling
from se3_transformer.runtime.utils import str2bool
from se3_transformer.model.fiber import Fiber


class Sequential(nn.Sequential):
    """ Sequential module with arbitrary forward args and kwargs. Used to pass graph, basis and edge features. """
    def forward(self, input, *args, **kwargs):
        for module in self:
            input = module(input, *args, **kwargs)
        return input


def get_populated_edge_features(relative_pos: Tensor, edge_features: Optional[Dict[str, Tensor]] = None):
    """
    Add (norm of) relative positions to existing edge features.

    :param relative_pos:   Relative positions; the norm is taken over the last dimension
    :param edge_features:  Optional dict of existing edge features. It is shallow-copied, never mutated
    :return:               New dict whose '0' entry holds the norms, concatenated along dim 1 after any
                           pre-existing '0' features; all other entries are carried over as-is
    """
    populated = edge_features.copy() if edge_features else {}
    # Norm over the last axis (kept as size 1), plus one extra trailing axis
    distances = relative_pos.norm(dim=-1, keepdim=True).unsqueeze(-1)
    populated['0'] = torch.cat([populated['0'], distances], dim=1) if '0' in populated else distances
    return populated


class SE3Transformer(nn.Module):
    """
    Stack of SE(3) attention blocks followed by a final SE(3) convolution.

    The graph part consists of ``num_layers`` ``AttentionBlockSE3`` modules (each optionally followed by
    a ``NormSE3``), then one ``ConvSE3`` with self-interaction that maps to ``fiber_out``. When
    ``pooling`` is set, a ``GPooling`` module is applied to the features of degree ``return_type``.
    """

    # Diagnostics go through logging instead of print() so that constructing or running the model
    # does not write to stdout; enable DEBUG on this logger to see them.
    _log = logging.getLogger(__name__)

    def __init__(self,
                 num_layers: int,
                 fiber_in: Fiber,
                 fiber_hidden: Fiber,
                 fiber_out: Fiber,
                 num_heads: int,
                 channels_div: int,
                 fiber_edge: Fiber = Fiber({}),
                 return_type: Optional[int] = None,
                 pooling: Optional[Literal['avg', 'max']] = None,
                 norm: bool = True,
                 use_layer_norm: bool = True,
                 tensor_cores: bool = False,
                 low_memory: bool = False,
                 **kwargs):
        """
        :param num_layers:          Number of attention layers
        :param fiber_in:            Input fiber description
        :param fiber_hidden:        Hidden fiber description
        :param fiber_out:           Output fiber description
        :param fiber_edge:          Input edge fiber description
        :param num_heads:           Number of attention heads
        :param channels_div:        Channels division before feeding to attention layer
        :param return_type:         Return only features of this type
        :param pooling:             'avg' or 'max' graph pooling before MLP layers
        :param norm:                Apply a normalization layer after each attention block
        :param use_layer_norm:      Apply layer normalization between MLP layers
        :param tensor_cores:        True if using Tensor Cores (affects the use of fully fused convs, and padded bases)
        :param low_memory:          If True, will use slower ops that use less memory
        :param kwargs:              Extra keyword arguments are accepted and ignored
        """
        super().__init__()
        self._log.debug("............................SE3_Transformer..................................")
        self.num_layers = num_layers
        self.fiber_edge = fiber_edge
        self.num_heads = num_heads
        self.channels_div = channels_div
        self.return_type = return_type
        self.pooling = pooling
        self.max_degree = max(*fiber_in.degrees, *fiber_hidden.degrees, *fiber_out.degrees)
        self.tensor_cores = tensor_cores
        self.low_memory = low_memory
        if tensor_cores:
            self._log.info("Using Tensor Cores")

        if low_memory:
            self.fuse_level = ConvSE3FuseLevel.NONE
        else:
            # Fully fused convolutions when using Tensor Cores (and not low memory mode)
            # It is the max level of fusion permitted. Changes inside to lower if degrees do not fit.
            self.fuse_level = ConvSE3FuseLevel.FULL if tensor_cores else ConvSE3FuseLevel.PARTIAL

        graph_modules = []
        for _ in range(num_layers):
            graph_modules.append(AttentionBlockSE3(fiber_in=fiber_in,
                                                   fiber_out=fiber_hidden,
                                                   fiber_edge=fiber_edge,
                                                   num_heads=num_heads,
                                                   channels_div=channels_div,
                                                   use_layer_norm=use_layer_norm,
                                                   max_degree=self.max_degree,
                                                   fuse_level=self.fuse_level,
                                                   low_memory=low_memory))
            if norm:
                graph_modules.append(NormSE3(fiber_hidden))
            # Every block after the first consumes the hidden fiber
            fiber_in = fiber_hidden
        # The final convolution consumes fiber_hidden if num_layers > 0, else the original fiber_in.
        graph_modules.append(ConvSE3(fiber_in=fiber_in,
                                     fiber_out=fiber_out,
                                     fiber_edge=fiber_edge,
                                     self_interaction=True,
                                     use_layer_norm=use_layer_norm,
                                     max_degree=self.max_degree))
        self.graph_modules = Sequential(*graph_modules)

        if pooling is not None:
            self._log.debug("Time to pool")
            assert return_type is not None, 'return_type must be specified when pooling'
            self.pooling_module = GPooling(pool=pooling, feat_type=return_type)

    def forward(self, graph: DGLGraph, node_feats: Dict[str, Tensor],
                edge_feats: Optional[Dict[str, Tensor]] = None,
                basis: Optional[Dict[str, Tensor]] = None):
        """
        Run the graph modules and select the output.

        :param graph:       Input graph; ``graph.edata['rel_pos']`` is read for the bases and edge features
        :param node_feats:  Node features, keyed by degree
        :param edge_feats:  Optional edge features, keyed by degree
        :param basis:       Optional precomputed bases; computed here when missing or empty
        :return:            The pooled result if pooling is configured; otherwise the features of degree
                            ``return_type`` if it is set; otherwise the full dict of node features
        """
        # Compute bases in case they weren't precomputed as part of the data loading
        basis = basis or get_basis(graph.edata['rel_pos'], max_degree=self.max_degree, compute_gradients=False,
                                   use_pad_trick=self.tensor_cores and not self.low_memory,
                                   amp=torch.is_autocast_enabled())
        # Add fused bases (per output degree, per input degree, and fully fused) to the dict
        basis = update_basis_with_fused(basis, self.max_degree, use_pad_trick=self.tensor_cores and not self.low_memory,
                                        fully_fused=self.fuse_level == ConvSE3FuseLevel.FULL)

        # Only build the shape list when somebody is actually listening at DEBUG level
        if edge_feats and self._log.isEnabledFor(logging.DEBUG):
            self._log.debug("Edge features in SE3Transf bef: %s", [edge_feats[d].shape for d in edge_feats])
        edge_feats = get_populated_edge_features(graph.edata['rel_pos'], edge_feats)

        node_feats = self.graph_modules(node_feats, edge_feats, graph=graph, basis=basis)

        if self.pooling is not None:
            return self.pooling_module(node_feats, graph=graph)

        if self.return_type is not None:
            return node_feats[str(self.return_type)]

        return node_feats

    @staticmethod
    def add_argparse_args(parser):
        """ Register the transformer hyper-parameters on ``parser`` and return it. """
        parser.add_argument('--num_layers', type=int, default=8, #8
                            help='Number of stacked Transformer layers')
        parser.add_argument('--num_heads', type=int, default=2, #8
                            help='Number of heads in self-attention')
        parser.add_argument('--channels_div', type=int, default=2,
                            help='Channels division before feeding to attention layer')
        parser.add_argument('--pooling', type=str, default=None, const=None, nargs='?', choices=['max', 'avg'],
                            help='Type of graph pooling')
        parser.add_argument('--norm', type=str2bool, nargs='?', const=True, default=False,
                            help='Apply a normalization layer after each attention block')
        parser.add_argument('--use_layer_norm', type=str2bool, nargs='?', const=True, default=False,
                            help='Apply layer normalization between MLP layers')
        parser.add_argument('--low_memory', type=str2bool, nargs='?', const=True, default=False,
                            help='If true, will use fused ops that are slower but that use less memory '
                                 '(expect 25 percent less memory). '
                                 'Only has an effect if AMP is enabled on Volta GPUs, or if running on Ampere GPUs')

        return parser

class IOSE3Transformer(nn.Module):
    """
    Two chained SE3Transformers ("i" and "o") followed by an MLP head on the degree-0 output.

    ``forward`` runs in one of two modes selected by ``for_flag``: the "i" pass returns the feature dict
    of the first transformer; any other flag runs the "o" pass, which concatenates previously computed
    "i" features with linearly projected "o" features, runs the second transformer and applies the MLP
    to the rows belonging to the "o" features.
    """

    # Diagnostics go through logging instead of print(); enable DEBUG/WARNING on this logger to see them.
    _log = logging.getLogger(__name__)

    def __init__(self,
                 i_fiber_in: Fiber,
                 i_fiber_out: Fiber,
                 i_fiber_edge: Fiber,
                 o_fiber_in: Fiber,
                 o_fiber_out: Fiber,
                 o_fiber_edge: Fiber,
                 i_num_degrees: int,
                 i_num_channels: int,
                 o_num_degrees:int,
                 o_num_channels:int,
                 output_dim: Optional[int] = 1,
                 **kwargs):
        """
        :param i_fiber_in:      Input fiber of the "i" transformer
        :param i_fiber_out:     Output fiber of the "i" transformer; also the input fiber of the "o" transformer
        :param i_fiber_edge:    Edge fiber of the "i" transformer
        :param o_fiber_in:      Fiber of the raw "o" features, projected to ``i_fiber_out`` by a LinearSE3
        :param o_fiber_out:     Output fiber of the "o" transformer
        :param o_fiber_edge:    Edge fiber of the "o" transformer
        :param i_num_degrees:   Number of degrees of the "i" hidden fiber
        :param i_num_channels:  Number of channels of the "i" hidden fiber
        :param o_num_degrees:   Number of degrees of the "o" hidden fiber
        :param o_num_channels:  Number of channels of the "o" hidden fiber
        :param output_dim:      Output size of the final MLP layer
        :param kwargs:          Forwarded to both SE3Transformers; must contain 'i_num_layers' and
                                'o_num_layers'. 'pooling' and 'num_layers' are overwritten here
        """
        super().__init__()
        self._log.debug("............................SE3_Transformer_IO.............................")

        # Both inner transformers return the full per-degree dict: no pooling, no return_type.
        kwargs['pooling'] = None
        kwargs['num_layers'] = kwargs['i_num_layers']
        self.i_transformer = SE3Transformer(
            fiber_in=i_fiber_in,
            fiber_hidden=Fiber.create(i_num_degrees, i_num_channels),
            fiber_out=i_fiber_out,
            fiber_edge=i_fiber_edge,
            return_type=None,
            **kwargs
        )

        kwargs['num_layers'] = kwargs['o_num_layers']
        self.o_transformer = SE3Transformer(
            fiber_in=i_fiber_out,
            fiber_hidden=Fiber.create(o_num_degrees, o_num_channels),
            fiber_out=o_fiber_out,
            fiber_edge=o_fiber_edge,
            return_type=None,
            **kwargs
        )
        o_n_out_features = o_fiber_out.num_features
        self.mlp = nn.Sequential(
            nn.Linear(o_n_out_features, o_n_out_features),
            nn.ReLU(),
            nn.Linear(o_n_out_features, output_dim),
        )
        self.linear = LinearSE3(o_fiber_in, i_fiber_out)

    def forward(self, batch, basis=None, for_flag='i_forw'):
        """
        :param batch:     Dict-like batch. The "i" pass reads 'input_graph' and 'node_feats'. The "o" pass
                          reads 'i_feats', 'forw_key', and the '<forw_key>_graph' / '<forw_key>_feats' entries
        :param basis:     Optional precomputed bases; only used by the "i" pass
        :param for_flag:  'i_forw' selects the "i" pass, any other value selects the "o" pass
        :return:          "i" pass: feature dict of the first transformer. "o" pass: MLP output for the
                          rows that came from the "o" features, with a trailing size-1 dim squeezed
        """
        edge_feats = None

        if for_flag == 'i_forw':
            i_graph = batch['input_graph']
            node_feats = batch['node_feats']
            return self.i_transformer(i_graph, node_feats, edge_feats, basis)

        i_feats = batch['i_feats']
        forw_key = batch['forw_key']
        o_graph = batch[forw_key + '_graph']
        # Project the raw "o" features so they can be stacked with the "i" features
        o_feats = self.linear(batch[forw_key + '_feats'])

        io_feats = {}
        o_feats_batch_size = 0
        for degree, i_feat in i_feats.items():
            try:
                o_feat = o_feats[degree]
                io_feats[degree] = torch.cat((i_feat, o_feat))
            except (KeyError, RuntimeError) as err:
                # Best effort: a degree that is missing from o_feats, or whose shapes cannot be
                # concatenated, is left out. Report which degree and why instead of hiding it.
                self._log.warning("Exception In IOSE3TRANSFORMER FORWARD! Skipping degree %s: %r", degree, err)
                continue
            o_feats_batch_size = o_feat.shape[0]

        # The "o" pass always recomputes its bases (basis=None), whatever was passed in
        io_feats = self.o_transformer(o_graph, io_feats, edge_feats, basis=None)

        # Keep only the trailing rows, i.e. those appended from o_feats. An explicit start index is
        # used because x[-0:] would select every row instead of none when there are no "o" rows.
        scalar_feats = io_feats['0']
        first_o_row = max(scalar_feats.shape[0] - o_feats_batch_size, 0)
        feats = scalar_feats[first_o_row:].squeeze(-1)
        return self.mlp(feats).squeeze(-1)

    @staticmethod
    def add_argparse_args(parent_parser):
        """
        Register the model arguments in a "Model architecture" group and return ``parent_parser``.

        NOTE(review): this registers --num_degrees / --num_channels while the constructor takes
        i_/o_-prefixed variants; presumably the caller maps them — confirm.
        """
        parser = parent_parser.add_argument_group("Model architecture")
        SE3Transformer.add_argparse_args(parser)
        parser.add_argument('--num_degrees',
                            help='Number of degrees to use. Hidden features will have types [0, ..., num_degrees - 1]',
                            type=int, default=2)
        parser.add_argument('--num_channels', help='Number of channels for the hidden features', type=int, default=16)
        parser.add_argument('--i_num_layers', help='Number of layers for i_transformer', type=int, default=4)
        parser.add_argument('--o_num_layers', help='Number of layers for o_transformer', type=int, default=4)
        return parent_parser


class SE3TransformerPooled(nn.Module):
    """
    SE3Transformer with graph pooling on degree-0 features, followed by a two-layer MLP head.
    """

    # Diagnostics go through logging instead of print(); enable DEBUG on this logger to see them.
    _log = logging.getLogger(__name__)

    def __init__(self,
                 fiber_in: Fiber,
                 fiber_out: Fiber,
                 fiber_edge: Fiber,
                 num_degrees: int,
                 num_channels: int,
                 output_dim: int,
                 **kwargs):
        """
        :param fiber_in:      Input fiber description
        :param fiber_out:     Output fiber description; its ``num_features`` sizes the MLP
        :param fiber_edge:    Input edge fiber description
        :param num_degrees:   Number of degrees of the hidden fiber
        :param num_channels:  Number of channels of the hidden fiber
        :param output_dim:    Output size of the final MLP layer
        :param kwargs:        Forwarded to SE3Transformer. 'pooling' falls back to 'max' when it is
                              missing or None
        """
        super().__init__()
        self._log.debug("............................SE3_Transformer Pooled.............................")
        # .get() so that omitting 'pooling' means "use the default" rather than raising KeyError
        kwargs['pooling'] = kwargs.get('pooling') or 'max'
        fiber_hidden = Fiber.create(num_degrees, num_channels)
        self._log.debug("Fibers: In, Out, Edge, Hidden")
        self._log.debug("%s %s %s %s", fiber_in, fiber_out, fiber_edge, fiber_hidden)
        self.transformer = SE3Transformer(
            fiber_in=fiber_in,
            fiber_hidden=fiber_hidden,
            fiber_out=fiber_out,
            fiber_edge=fiber_edge,
            return_type=0,
            **kwargs
        )

        n_out_features = fiber_out.num_features
        self.mlp = nn.Sequential(
            nn.Linear(n_out_features, n_out_features),
            nn.ReLU(),
            nn.Linear(n_out_features, output_dim)
        )

    def forward(self, graph, node_feats, edge_feats, basis=None):
        """
        :param graph:       Input graph, passed to the transformer
        :param node_feats:  Node features, keyed by degree
        :param edge_feats:  Edge features, keyed by degree; may be None
        :param basis:       Optional precomputed bases
        :return:            MLP output on the pooled transformer features, trailing size-1 dim squeezed
        """
        if self._log.isEnabledFor(logging.DEBUG):
            self._log.debug("SE3 Pooled forward............")
            self._log.debug("Node features: %s", [node_feats[d].shape for d in node_feats])
            # edge_feats may legitimately be None; the diagnostics must not crash on it
            self._log.debug("Edge features: %s", [edge_feats[d].shape for d in edge_feats or {}])
        feats = self.transformer(graph, node_feats, edge_feats, basis).squeeze(-1)
        # NOTE(review): original author's open question — a plain MLP is presumably only valid for
        # type '0' features pooled graph-wise with 'max'; validity for 'avg' pooling is unconfirmed.
        y = self.mlp(feats).squeeze(-1)
        return y

    @staticmethod
    def add_argparse_args(parent_parser):
        """ Register the model arguments in a "Model architecture" group and return ``parent_parser``. """
        parser = parent_parser.add_argument_group("Model architecture")
        SE3Transformer.add_argparse_args(parser)
        parser.add_argument('--num_degrees',
                            help='Number of degrees to use. Hidden features will have types [0, ..., num_degrees - 1]',
                            type=int, default=1)
        parser.add_argument('--num_channels', help='Number of channels for the hidden features', type=int, default=32)
        return parent_parser
