# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT

from typing import List
import dgl
import torch
import pathlib
from pathlib import Path
import os
import torch.nn as nn
#from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
from tqdm import tqdm
from apex.optimizers import FusedAdam, FusedLAMB
from torch.optim import Optimizer

from se3_transformer.runtime import gpu_affinity

from se3_transformer.runtime.arguments import PARSER
from se3_transformer.runtime.callbacks import BaseCallback, ReconMetric
from se3_transformer.runtime.loggers import DLLogger, WandbLogger, LoggerCollection, TensorBoardLogger
from se3_transformer.runtime.utils import to_cuda, get_local_rank, using_tensor_cores,get_rank
from se3_transformer.data_loading.shapenet import _get_relative_pos
from im2mesh.common import make_3d_grid

from se3_transformer.runtime.utils import cfg_update
from im2mesh import config
import se3_transformer.runtime.environment as envir
from se3_transformer.runtime.utils import hasnone

# NOTE: the triple-quoted string below is never assigned or used, so it has no
# runtime effect. It appears to be kept as a reference copy of the expected
# config layout; the CFG actually used is loaded via config.load_config() in
# the __main__ section.
'''
CFG={'method':'onet',
     'data':{'dataset':'Shapes3D',
             'path':'/Datasets/ShapeNet',
             'classes':None,
             'train_split':'train',
             'val_split':'val',
             'test_split':'test',
             'points_subsample':1024,
             'points_file':'points.npz',
             'voxels_file':None,
             'points_unpackbits':True,
             'input_type':'pointcloud',
             'with_transforms':False,
             'pointcloud_n':300,
             'pointcloud_target_n':1024,
             'pointcloud_noise':0.0,
             'pointcloud_file':'pointcloud.npz',
             'points_iou_file':'points.npz'},
     'model':{'use_camera':False}}
'''
def load_state(model: nn.Module, optimizer: Optimizer, path: pathlib.Path, callbacks: List[BaseCallback]):
    """Restore model, optimizer and callback state from a checkpoint file.

    Loading is best-effort: if anything goes wrong (missing/unreadable file,
    mismatched state dict, missing checkpoint keys, ...) the reason is printed
    and 0 is returned so the caller can continue with the current weights.

    Args:
        model: Model to load ``checkpoint['state_dict']`` into. If it is a
            ``DistributedDataParallel`` wrapper, the weights are loaded into
            its inner ``module``.
        optimizer: Optimizer restored from ``checkpoint['optimizer_state_dict']``.
        path: Checkpoint location handed to ``torch.load``. Tensors saved on
            ``cuda:0`` are remapped to this process's local GPU.
        callbacks: Each gets ``on_checkpoint_load(checkpoint)`` after the model
            and optimizer have been restored.

    Returns:
        ``checkpoint['epoch']`` on success, otherwise 0.
    """
    # The script only imports ``logging`` inside its __main__ section; import
    # it here so this function also works when the module is imported.
    import logging

    # ``DistributedDataParallel`` is bound at module level by the __main__
    # section (torch or smdistributed flavour depending on the environment).
    # Look it up lazily instead of raising NameError when it is absent.
    ddp_cls = globals().get('DistributedDataParallel')

    try:
        checkpoint = torch.load(str(path), map_location={'cuda:0': f'cuda:{get_local_rank()}'})
        if ddp_cls is not None and isinstance(model, ddp_cls):
            model.module.load_state_dict(checkpoint['state_dict'])
            print("SUCCESSFULLY LOADED PRETRAINED MODEL")
        else:
            model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        for callback in callbacks:
            callback.on_checkpoint_load(checkpoint)

        logging.info(f'Loaded checkpoint from {str(path)}')

        return checkpoint['epoch']
    except Exception as err:
        # Deliberately best-effort, but no longer a bare ``except``: Ctrl-C and
        # SystemExit propagate, and the cause of the failure is reported.
        print("UNSUCCESSFUL LOAD. INITILIZING MODEL")
        print(f"  reason: {type(err).__name__}: {err}")
        return 0

def _build_query_graph(inputs, query_points, src_neighbors, int_kmode, device):
    """Build the input-point -> query-point graph and query features for one chunk.

    Each query point is connected to the ``int_kmode`` source nodes stored in
    the ``src_neighbors`` row of its nearest input point.

    Args:
        inputs: Input point cloud, indexed here as (batch, n_points, 3).
        query_points: Query points of this chunk, indexed as (batch, n_query, 3).
        src_neighbors: Input-graph source node ids reshaped to (-1, int_kmode);
            row ``i`` is used as the neighbour list of global input node ``i``.
        int_kmode: Number of neighbours taken per query point.
        device: Device on which the index tensors are created.

    Returns:
        ``(graph, feats)``. ``graph`` is a DGL graph whose first
        ``batch * n_points`` nodes are the input points and whose remaining
        nodes are the query points, with ``ndata['pos']`` and
        ``edata['rel_pos']`` filled in. ``feats`` is the mean of
        ``input - query`` offsets over the selected neighbours, shaped
        (batch * n_query, 1, 3).
    """
    batch_size = inputs.shape[0]
    pcld_size = inputs.shape[1]
    qpoints_size = query_points.shape[1]

    # Pairwise offsets (batch, query, pcld, 3) and nearest input point per query.
    diffs = inputs.unsqueeze(1) - query_points.unsqueeze(2)
    distance = torch.norm(diffs, dim=3, p=None)
    _, knn = torch.min(distance, dim=2)

    batch_ids = torch.arange(batch_size, device=device)
    # Global (batch-flattened) id of the nearest input node, then its neighbours.
    src_k1 = (knn + (batch_ids * pcld_size).reshape(-1, 1)).reshape(-1)
    src = src_neighbors[src_k1].reshape(-1)

    # Flat row indices into diffs.reshape(-1, 3) for every (query, neighbour) pair.
    relative_src = src - batch_ids.repeat_interleave(int_kmode * qpoints_size) * pcld_size
    inds = relative_src + diffs.shape[2] * torch.arange(
        diffs.shape[0] * diffs.shape[1], device=device).repeat_interleave(int_kmode)
    # Query nodes are numbered after all input nodes.
    dst = torch.arange(qpoints_size * batch_size, device=device).repeat_interleave(int_kmode) + batch_size * pcld_size

    iou_graph = dgl.graph((src, dst))
    iou_graph.ndata['pos'] = torch.cat((inputs.reshape(-1, 3), query_points.reshape(-1, 3)))
    iou_graph.edata['rel_pos'] = _get_relative_pos(iou_graph)

    avg_diffs = torch.mean(diffs.reshape(-1, 3)[inds].reshape(-1, int_kmode, 3), 1, keepdim=True)
    return iou_graph, avg_diffs


@torch.inference_mode()
def evaluate(model: nn.Module,
             dataloader: DataLoader,
             callbacks: List[BaseCallback],
             metric_name: str,
             args):
    """Run occupancy prediction over ``dataloader`` and feed the metric callback.

    For every batch, ``model(batch, for_flag='i_forw')`` is run once and stored
    as ``batch['i_feats']``. The query points ``batch['points_iou']`` are then
    processed in ``batch['val_mini_batch']`` chunks, each through
    ``model(batch, for_flag='o_forw')`` followed by a sigmoid. The concatenated
    predictions go to the callback's ``on_validation_step``.

    Args:
        model: Model called with a batch dict and a ``for_flag`` keyword.
        dataloader: Yields batch dicts. The keys read here are 'inputs',
            'input_graph', 'points_iou', 'points_iou.occ' and 'val_mini_batch'.
        callbacks: Must contain a callback whose ``metric_name`` equals
            ``metric_name``; the first such callback is used.
        metric_name: Selects the callback that receives predictions.
        args: Reads ``amp``, ``kmode``, ``silent``, ``benchmark``,
            ``visualize_interval`` and ``base_dir``.

    Raises:
        ValueError: If no callback matches ``metric_name``.

    Note:
        The batch dict is mutated. 'forw_key', 'i_feats', 'iou_graph' and
        'iou_feats' are written, and 'points_iou.occ' is replaced by a
        one-element list holding the current chunk's occupancy slice.
    """
    matching = [callback for callback in callbacks if callback.metric_name == metric_name]
    if not matching:
        raise ValueError(f'No callback with metric_name={metric_name!r}')
    mioucallback = matching[0]

    # Bound at module level only when the script runs as __main__.
    ddp_cls = globals().get('DistributedDataParallel')

    model.eval()
    device = torch.cuda.current_device()

    # No extra torch.no_grad(): the inference_mode decorator already disables autograd.
    for i, batch in tqdm(enumerate(dataloader), total=len(dataloader), unit='batch', desc='Evaluation',
                         leave=False, disable=(args.silent or get_local_rank() != 0)):
        batch = to_cuda(batch)
        batch['forw_key'] = 'iou'
        mioucallback.on_batch_start()

        pred = []
        val_mini_batch = int(batch['val_mini_batch'])
        n_iou = batch['points_iou'].shape[1]
        iou_mini_size = n_iou // val_mini_batch
        points_iou_occ_mini = batch['points_iou.occ']

        with torch.cuda.amp.autocast(enabled=args.amp):
            batch['i_feats'] = model(batch, for_flag='i_forw')

        for ind_mini in range(val_mini_batch):
            start = iou_mini_size * ind_mini
            # The last chunk also takes the remainder so that no query point is
            # dropped and gather_pred matches points_iou_occ_mini in width.
            end = n_iou if ind_mini == val_mini_batch - 1 else start + iou_mini_size
            if end == start:
                continue

            miniou_points = batch['points_iou'][:, start:end]
            batch['points_iou.occ'] = [points_iou_occ_mini[:, start:end]]

            batch_size = batch['inputs'].shape[0]
            pcld_size = batch['inputs'].shape[1]
            int_kmode = pcld_size if args.kmode == 'full' else int(args.kmode)
            src_neighbors = batch['input_graph'].edges()[0].reshape(-1, int_kmode)

            iou_graph, avg_diffs = _build_query_graph(
                batch['inputs'], miniou_points, src_neighbors, int_kmode, device)
            batch['iou_feats'] = {'1': avg_diffs}
            batch['iou_graph'] = iou_graph

            with torch.cuda.amp.autocast(enabled=args.amp):
                mini_pred = model(batch, for_flag='o_forw')
                mini_pred = torch.sigmoid(mini_pred)
            pred.append(mini_pred.reshape(batch_size, -1))

            if hasnone(mini_pred) > 0:
                # Dump everything needed to reproduce the bad prediction.
                print("Something nasty happened in evaluate.")
                is_ddp = ddp_cls is not None and isinstance(model, ddp_cls)
                state_dict = model.module.state_dict() if is_ddp else model.state_dict()
                checkpoint = {'state_dict': state_dict,
                              'batch': batch,
                              'prediction': mini_pred,
                              }
                torch.save(checkpoint, Path(os.path.join(args.base_dir, 'NanEvalCheckpoint.pth')))

        gather_pred = torch.cat(pred, dim=1)

        mioucallback.on_validation_step(batch, points_iou_occ_mini, gather_pred)

        should_visualize = ((args.visualize_interval > 0 and (i + 1) % args.visualize_interval == 0)
                            or i + 1 == len(dataloader))
        if not args.benchmark and should_visualize:
            mioucallback.on_visualize_eval(batch, model, args)
            


# Turn on cuDNN and its benchmark mode (algorithm autotuning) for this process.
torch.backends.cudnn.enabled, torch.backends.cudnn.benchmark = True, True
           
if __name__ == '__main__':
    # Stand-alone evaluation entry point: build the data module and model,
    # restore a checkpoint and report the 'MeanIoU/Test' metric on the test split.
    from se3_transformer.runtime.callbacks import QM9MetricCallback, PerformanceCallback
    from se3_transformer.runtime.utils import init_distributed, seed_everything
    from se3_transformer.model import IOSE3Transformer, Fiber
    from se3_transformer.data_loading import ShapeNetModule, QM9DataModule
    # NOTE: these imports execute at module scope, so ``logging``, ``dist`` and
    # ``DistributedDataParallel`` become module globals; load_state() and
    # evaluate() refer to them.
    import logging
    import sys

    args = PARSER.parse_args()
    envir.init(args)
    if envir.arg.envir == 'cluster':
        import torch.distributed as dist
        from torch.nn.parallel import DistributedDataParallel
    else:
        # Any non-cluster environment uses the smdistributed data-parallel backend.
        import smdistributed.dataparallel.torch.distributed as dist
        from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel
    is_distributed = init_distributed()
    local_rank = get_local_rank()
    print("local rank:", local_rank)

    # Load the config given by --config, then override it with the CLI arguments.
    CFG = config.load_config(args.config, None)
    cfg_update(CFG, vars(args))

    if envir.arg.envir == 'aws':
        # Presumably the SageMaker training-channel mount point; confirm with the launcher.
        CFG['data']['path'] = os.environ['SM_CHANNEL_TRAINING']
    rank = get_rank()
    print("rank:", rank)

    if envir.arg.envir == 'aws':
        torch.cuda.set_device(local_rank)

    # Only rank 0 logs (unless --silent); other ranks are muted.
    logging.getLogger().setLevel(logging.CRITICAL if local_rank != 0 or args.silent else logging.INFO)

    logging.info('====== SE(3)-Transformer ======')
    logging.info('|  Inference on the test set  |')
    logging.info('===============================')

    if not args.benchmark and args.load_ckpt_path is None:
        logging.error('No load_ckpt_path provided, you need to provide a saved model to evaluate')
        sys.exit(1)

    if args.benchmark:
        logging.info('Running benchmark mode with one warmup pass')

    if args.seed is not None:
        seed_everything(args.seed)

    if args.base_dir:
        if not os.path.exists(args.base_dir):
            # Wrap in Path() so this works whether the parser produced a str or a Path.
            Path(args.base_dir).mkdir(parents=True, exist_ok=True)
        # Every output/input location below is resolved relative to base_dir.
        args.log_dir = Path(os.path.join(args.base_dir, args.log_dir, "")) if args.log_dir is not None else None
        args.vis_dir = Path(os.path.join(args.base_dir, args.vis_dir, "")) if args.vis_dir is not None else None
        args.save_ckpt_path = Path(os.path.join(args.base_dir, args.save_ckpt_path)) if args.save_ckpt_path is not None else None
        args.load_ckpt_path = Path(os.path.join(args.base_dir, args.load_ckpt_path)) if args.load_ckpt_path is not None else None

    major_cc, minor_cc = torch.cuda.get_device_capability()

    loggers = []
    if args.wandb:
        loggers.append(WandbLogger(name=f'QM9({args.task})', save_dir=args.log_dir, project='se3-transformer'))
    if args.tensorboard:
        print("Tensorboard")
        loggers.append(TensorBoardLogger(name='Tesnorboard', save_dir=args.log_dir))

    logger = LoggerCollection(loggers)

    if local_rank == 0:
        print("''''''''''''''''''''''''''DATA MODULE FINISHED'''''''''''''''''''''''")
        print(" ''''''''''''''''''''''''' Parameters '''''''''''''''''''''''''''''''")
        for arg in vars(args):
            print(arg, getattr(args, arg))
        print("''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''")
        print(".................CFG.......................")
        print(CFG)

    datamodule = ShapeNetModule(cfg=CFG, **vars(args))

    model = IOSE3Transformer(
        i_fiber_in=Fiber({1: datamodule.NODE_FEATURE_DIM}),
        i_fiber_out=Fiber({1: args.num_degrees * args.num_channels}),
        o_fiber_in=Fiber({1: datamodule.NODE_FEATURE_DIM}),
        o_fiber_out=Fiber({0: 32}),
        i_fiber_edge=Fiber({}),
        o_fiber_edge=Fiber({}),
        i_num_degrees=args.num_degrees,
        i_num_channels=args.num_channels,
        o_num_degrees=args.num_degrees,
        o_num_channels=args.num_channels,
        tensor_cores=using_tensor_cores(args.amp),  # use Tensor Cores more effectively
        output_dim=1,
        **vars(args)
    )

    model.to(device=torch.cuda.current_device())
    if args.load_ckpt_path is not None:
        # Load the weights before DDP wrapping, with no error handling, so a
        # bad checkpoint fails loudly here rather than only inside load_state().
        checkpoint = torch.load(str(args.load_ckpt_path), map_location={'cuda:0': f'cuda:{local_rank}'})
        model.load_state_dict(checkpoint['state_dict'])

    if is_distributed and envir.arg.envir == 'cluster':
        gpu_affinity.set_affinity(gpu_id=get_local_rank(), nproc_per_node=torch.cuda.device_count())

    callbacks = [ReconMetric(logger, torch.cuda.current_device(), args, metric_name='MeanIoU/Test')]

    if dist.is_initialized():
        if envir.arg.envir == 'cluster':
            model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
        else:
            model = DistributedDataParallel(model)
            model.cuda(local_rank)
    # evaluate() calls model.eval() itself before running inference.
    model.train()

    test_dataloader = datamodule.test_dataloader()

    # The optimizer is only built so that load_state() can restore its state.
    if args.optimizer == 'adam':
        optimizer = FusedAdam(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999),
                              weight_decay=args.weight_decay)
    elif args.optimizer == 'lamb':
        optimizer = FusedLAMB(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999),
                              weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    epoch_start = load_state(model, optimizer, args.load_ckpt_path, callbacks)

    print("Running evaluation on model trained for {} epochs".format(epoch_start))

    evaluate(model,
             test_dataloader,
             callbacks, 'MeanIoU/Test',
             args)

    for callback in callbacks:
        callback.on_validation_end(epoch_start)

