# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT

import argparse
import ctypes
import logging
import os
import random
from functools import wraps
from typing import Union, List, Dict

import numpy as np
import torch
#import torch.distributed as dist
from torch import Tensor
import se3_transformer.runtime.environment as envir

def hasnone(arg):
    """Count missing values (``None`` entries and NaN tensor elements) in ``arg``.

    - ``None`` itself counts as 1.
    - dicts are searched recursively through their values.
    - lists are searched recursively through their elements.
    - Tensors contribute the number of NaN entries (``torch.isnan(arg).sum()``).
    - Any other iterable (e.g. a tuple or a string) contributes the number of
      its elements that are ``None``; containers nested inside it are not searched.
    - Non-iterable leaves (ints, floats, arbitrary objects) count as 0.

    Returns:
        int: total number of ``None`` / NaN occurrences found.
    """
    if arg is None:
        # A bare None is itself a missing value. Without this check,
        # hasnone([None]) or hasnone({'k': None}) crashed, because the
        # fallback below tried to iterate over None.
        return 1
    if isinstance(arg, dict):
        return sum(hasnone(value) for value in arg.values())
    if isinstance(arg, list):
        return sum(hasnone(el) for el in arg)
    if isinstance(arg, Tensor):
        return torch.isnan(arg).sum().item()
    try:
        items = iter(arg)
    except TypeError:
        # Scalars and other non-iterable leaves cannot contain a None.
        return 0
    return sum(elem is None for elem in items)

def cfg_update(cfg, args, instr=''):
    '''
    Update the nested ``cfg`` dictionary in place from the flat ``args`` mapping.

    The flat key for a nested entry joins the path of ``cfg`` keys with
    underscores and is prefixed by ``instr`` when ``instr`` is non-empty
    (e.g. ``cfg['model']['lr']`` is looked up as ``args['model_lr']``).
    When that flat key is present in ``args``, the corresponding leaf of
    ``cfg`` is overwritten. Sub-dictionaries are always recursed into and are
    never replaced wholesale. Returns None.
    '''
    prefix = instr + '_' if len(instr) > 0 else ''
    for name, value in cfg.items():
        flat_key = prefix + name
        if isinstance(value, dict):
            cfg_update(value, args, flat_key)
        elif flat_key in args:
            cfg[name] = args[flat_key]
 

def aggregate_residual(feats1, feats2, method: str):
    """ Combine two fiber feature dicts degree by degree.

    The keys (degrees) of ``feats2`` define the output. For every degree that is
    also present in ``feats1``, the two features are merged:
    ``'add'``/``'sum'`` computes ``feats2[k] + feats1[k]``, and
    ``'cat'``/``'concat'`` concatenates ``[feats2[k], feats1[k]]`` along dim 1.
    Degrees found only in ``feats2`` are passed through unchanged. Degrees found
    only in ``feats1`` are dropped.

    Raises:
        ValueError: if ``method`` is not one of add/sum/cat/concat.
    """
    if method in ('add', 'sum'):
        def merge(new, old):
            return new + old
    elif method in ('cat', 'concat'):
        def merge(new, old):
            return torch.cat([new, old], dim=1)
    else:
        raise ValueError('Method must be add/sum or cat/concat')

    merged = {}
    for degree, feat in feats2.items():
        merged[degree] = merge(feat, feats1[degree]) if degree in feats1 else feat
    return merged


def degree_to_dim(degree: int) -> int:
    """Return ``2 * degree + 1``, the number of components of a feature of this degree."""
    doubled = degree * 2
    return doubled + 1


def unfuse_features(features: Tensor, degrees: List[int]) -> Dict[str, Tensor]:
    """Split a fused tensor along its last dimension into per-degree chunks.

    Chunk ``i`` has width ``degree_to_dim(degrees[i])`` and is stored under the
    key ``str(degrees[i])``.
    """
    widths = [degree_to_dim(d) for d in degrees]
    chunks = features.split(widths, dim=-1)
    return {str(d): chunk for d, chunk in zip(degrees, chunks)}


def str2bool(v: Union[bool, str]) -> bool:
    """Parse a command-line style boolean.

    Booleans are returned unchanged. Strings are matched case-insensitively
    against yes/true/t/y/1 and no/false/f/n/0.

    Raises:
        argparse.ArgumentTypeError: if the string matches neither set.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in {'yes', 'true', 't', 'y', '1'}:
        return True
    if lowered in {'no', 'false', 'f', 'n', '0'}:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


def to_cuda(x):
    """ Recursively move a Tensor, a collection of Tensors or a DGLGraph to CUDA.

    - Tensors are copied with ``non_blocking=True``.
    - tuples, lists and dicts are rebuilt with every element converted.
      Tuples come back as real tuples; they used to come back as a one-shot
      generator, which could be consumed only once and had no ``len``.
    - ints (including bools), floats, strings and ``None`` are returned unchanged.
    - Anything else is assumed to expose ``.to(device=...)`` (e.g. a DGLGraph)
      and is moved to ``torch.cuda.current_device()``.
    """
    if isinstance(x, Tensor):
        return x.cuda(non_blocking=True)
    if isinstance(x, tuple):
        return tuple(to_cuda(v) for v in x)
    if isinstance(x, list):
        return [to_cuda(v) for v in x]
    if isinstance(x, dict):
        return {k: to_cuda(v) for k, v in x.items()}
    if x is None or isinstance(x, (int, float, str)):
        # Plain scalars have no device; pass them through untouched.
        return x
    # DGLGraph or other objects that implement ``.to(device=...)``
    return x.to(device=torch.cuda.current_device())


def get_local_rank() -> int:
    """Return the local rank of this process, or 0 when it cannot be determined.

    In the ``'cluster'`` environment the value is read from the ``LOCAL_RANK``
    environment variable (default 0). Otherwise it is queried from SageMaker's
    ``smdistributed.dataparallel`` when that process group is initialised.
    """
    if envir.arg.envir != 'cluster':
        import smdistributed.dataparallel.torch.distributed as dist
        return dist.get_local_rank() if dist.is_initialized() else 0
    return int(os.environ.get('LOCAL_RANK', 0))
def get_rank() -> int:
    """Return the global rank of this process, or 0 if no process group is initialised.

    ``torch.distributed`` is used in the ``'cluster'`` environment; otherwise
    SageMaker's ``smdistributed.dataparallel`` is used.
    """
    if envir.arg.envir == 'cluster':
        import torch.distributed as backend
    else:
        import smdistributed.dataparallel.torch.distributed as backend
    return backend.get_rank() if backend.is_initialized() else 0

def init_distributed() -> bool:
    """Initialise the distributed process group for the active environment.

    In the ``'cluster'`` environment, a ``torch.distributed`` process group is
    created (``env://`` init) only when ``WORLD_SIZE`` > 1. NCCL is used when
    CUDA is available (and the CUDA device is set to the local rank);
    otherwise gloo is used and a warning is logged. Returns whether the
    process group was created.

    In any other environment, SageMaker's ``smdistributed.dataparallel``
    process group is initialised and its ``init_process_group()`` result is
    returned as-is.
    """
    print("------------------------------")
    print("-----------------------------")

    print(envir.arg.envir)
    if envir.arg.envir != 'cluster':
        import smdistributed.dataparallel.torch.distributed as dist
        # NOTE(review): this returns whatever smdistributed's init_process_group
        # returns, which may be None rather than a bool -- confirm callers do
        # not rely on it to detect distributed mode.
        return dist.init_process_group()

    import torch.distributed as dist

    world_size = int(os.environ.get('WORLD_SIZE', 1))
    if world_size <= 1:
        return False

    backend = 'nccl' if torch.cuda.is_available() else 'gloo'
    dist.init_process_group(backend=backend, init_method='env://')
    if backend == 'nccl':
        torch.cuda.set_device(get_local_rank())
    else:
        logging.warning('Running on CPU only!')
    assert dist.is_initialized()
    return True


def increase_l2_fetch_granularity():
    """Set the L2 fetch granularity limit of the current CUDA device to 128 bytes.

    Calls ``cudaDeviceSetLimit(cudaLimitMaxL2FetchGranularity, 128)`` from
    ``libcudart.so`` via ctypes, reads the limit back with
    ``cudaDeviceGetLimit``, and asserts that the read-back value is 128.
    """
    # maximum fetch granularity of L2: 128 bytes
    granularity = 128
    # cudaLimitMaxL2FetchGranularity = 0x05 (cudaLimit enum value)
    limit_kind = ctypes.c_int(0x05)
    libcudart = ctypes.CDLL('libcudart.so')
    # Both CUDA entry points use ``size_t`` for the limit value. Passing a
    # 4-byte c_int, as the code did before, lets cudaDeviceGetLimit write
    # 8 bytes into a 4-byte buffer, and may leave the upper half of the
    # value passed to cudaDeviceSetLimit unspecified.
    libcudart.cudaDeviceSetLimit(limit_kind, ctypes.c_size_t(granularity))
    value = ctypes.c_size_t(0)
    libcudart.cudaDeviceGetLimit(ctypes.byref(value), limit_kind)
    assert value.value == granularity


def seed_everything(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices).

    ``seed`` is converted with ``int()`` first, so numeric strings are accepted.
    """
    seed_value = int(seed)
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed_value)


def rank_zero_only(fn):
    """Decorate ``fn`` so that it only runs on the process with global rank 0.

    On every other rank the wrapper returns ``None`` without calling ``fn``.
    The rank comes from ``get_rank()``, which selects the torch.distributed or
    smdistributed backend from ``envir.arg.envir`` and reports 0 when no
    process group is initialised, so ``fn`` always runs in single-process mode.
    """
    @wraps(fn)
    def wrapped_fn(*args, **kwargs):
        # ``dist`` is never bound at module level in this file (the top-level
        # import is commented out), so the previous ``dist.is_initialized()``
        # raised NameError. Resolve the rank through get_rank() instead.
        if get_rank() == 0:
            return fn(*args, **kwargs)
        return None

    return wrapped_fn


def using_tensor_cores(amp: bool) -> bool:
    """Return whether Tensor Cores are expected to be used on the current CUDA device.

    True for compute capability 8.x or newer regardless of ``amp``. For
    compute capability 7.x, True only when ``amp`` is truthy. False otherwise.
    """
    major, _ = torch.cuda.get_device_capability()
    if major >= 8:
        return True
    return bool(amp) and major >= 7
