# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#
# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES
# SPDX-License-Identifier: MIT

import logging
import pathlib
from pathlib import Path
from typing import List
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import torch
#import torch.distributed as dist
import torch.nn as nn
import os

from torch.optim import Adam
from apex.optimizers import FusedAdam, FusedLAMB
from torch.nn.modules.loss import _Loss
#from torch.nn.parallel import DistributedDataParallel
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm

#from se3_transformer.data_loading import QM9DataModule
from se3_transformer.model import IOSE3Transformer
from se3_transformer.model.fiber import Fiber
from se3_transformer.runtime import gpu_affinity
from se3_transformer.runtime.arguments import PARSER
from se3_transformer.runtime.callbacks import QM9MetricCallback, QM9LRSchedulerCallback, BaseCallback,PerformanceCallback, ReconMetric
from se3_transformer.runtime.recon_inference import evaluate
from se3_transformer.runtime.loggers import LoggerCollection, DLLogger, WandbLogger, Logger, TensorBoardLogger
from se3_transformer.runtime.utils import to_cuda, get_rank, get_local_rank, init_distributed, seed_everything, \
    using_tensor_cores, increase_l2_fetch_granularity

from se3_transformer.data_loading import ShapeNetModule
from types import SimpleNamespace
from im2mesh.common import compute_iou
import warnings

from se3_transformer.runtime.utils import hasnone,cfg_update

import se3_transformer.runtime.environment as envir
from se3_transformer.runtime.utils import hasnone

from se3_transformer.data_loading.shapenet import _get_relative_pos
import dgl

from im2mesh import config

# cuDNN stays on, and benchmark mode lets it auto-select the fastest kernels
# for the tensor shapes actually seen at runtime.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

# Silence every Python warning for the remainder of the process.
warnings.filterwarnings("ignore")

# Reference configuration kept for documentation only (it is never used; the
# real config is loaded from args.config in the __main__ section):
#
# CFG = {'method': 'onet',
#        'data': {'dataset': 'Shapes3D',
#                 'path': '/Datasets/ShapeNet',
#                 'classes': None,
#                 'train_split': 'train',
#                 'val_split': 'val',
#                 'test_split': 'test',
#                 'points_subsample': 1024,
#                 'points_file': 'points.npz',
#                 'voxels_file': None,
#                 'points_unpackbits': True,
#                 'input_type': 'pointcloud',
#                 'with_transforms': False,
#                 'pointcloud_n': 300,
#                 'pointcloud_target_n': 1024,
#                 'pointcloud_noise': 0.0,
#                 'pointcloud_file': 'pointcloud.npz',
#                 'points_iou_file': 'points.npz'},
#        'model': {'use_camera': False}}
def save_state(model: nn.Module, optimizer: Optimizer, epoch: int, best_iou: float, path: pathlib.Path, callbacks: List[BaseCallback]):
    """Save model, optimizer, epoch and best-IoU state to ``path``.

    Only one process writes the file:
    - ``envir.arg.envir == 'cluster'``: the process with local rank 0, i.e. once per node.
    - ``envir.arg.envir == 'aws'``: the process with global rank 0. A second copy is also written
      to ``/opt/ml/checkpoints/trained_model.pth``.

    Args:
        model: Model to save. If it is wrapped in DistributedDataParallel, the inner
            ``model.module`` state is stored so the checkpoint loads into an unwrapped model.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        epoch: Epoch number recorded in the checkpoint.
        best_iou: Best validation IoU seen so far, recorded as ``'best_iou'``.
        path: Destination file.
        callbacks: Each callback's ``on_checkpoint_save`` hook receives the checkpoint dict
            before it is written (presumably so it can add its own state; confirm per callback).
    """
    # Pick the DDP class matching the active distributed backend, so that the
    # isinstance check below recognises wrapped models.
    if envir.arg.envir == 'cluster':
        from torch.nn.parallel import DistributedDataParallel
    else:
        from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel

    is_writer = (envir.arg.envir == 'cluster' and get_local_rank() == 0) or \
                (envir.arg.envir == 'aws' and get_rank() == 0)
    if not is_writer:
        return

    state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
    checkpoint = {
        'state_dict': state_dict,
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
        'best_iou': best_iou
    }

    for callback in callbacks:
        callback.on_checkpoint_save(checkpoint)

    torch.save(checkpoint, str(path))
    if envir.arg.envir == 'aws':
        # Extra copy in the SageMaker checkpoint directory.
        torch.save(checkpoint, '/opt/ml/checkpoints/trained_model.pth')

    print("Saving Checkpoint!")
    logging.info(f'Saved checkpoint to {str(path)}')


def load_state(model: nn.Module, optimizer: Optimizer, path: pathlib.Path, callbacks: List[BaseCallback]):
    """Load model, optimizer and callback state from the checkpoint at ``path``.

    Loading is best-effort. On any failure (missing file, key mismatch, etc.) the traceback is
    logged, a message is printed, and ``(0, 0)`` is returned so training starts from scratch.
    The model or optimizer may already be partially updated in that case.

    Tensors saved on ``cuda:0`` are remapped to this process's local GPU.

    Returns:
        ``(epoch, best_iou)`` taken from the checkpoint. ``best_iou`` is 0 for checkpoints written
        without that key. Returns ``(0, 0)`` if loading failed.
    """
    # Import the DDP class for the active backend locally, instead of relying on a
    # module global that only exists when this file runs as __main__.
    if envir.arg.envir == 'cluster':
        from torch.nn.parallel import DistributedDataParallel
    else:
        from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel

    try:
        checkpoint = torch.load(str(path), map_location={'cuda:0': f'cuda:{get_local_rank()}'})
        if isinstance(model, DistributedDataParallel):
            model.module.load_state_dict(checkpoint['state_dict'])
            print("SUCCESSFULLY LOADED PRETRAINED MODEL")
        else:
            model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        for callback in callbacks:
            callback.on_checkpoint_load(checkpoint)

        logging.info(f'Loaded checkpoint from {str(path)}')
        best_iou = checkpoint.get('best_iou', 0)
        return checkpoint['epoch'], best_iou
    except Exception:
        # Deliberate best-effort fallback, but keep the cause visible instead of
        # swallowing it (and no longer trap KeyboardInterrupt/SystemExit).
        logging.exception(f'Failed to load checkpoint from {str(path)}')
        print("UNSUCCESSFUL LOAD. INITILIZING MODEL")
        return 0, 0


def train_epoch(model, train_dataloader, val_dataloader, loss_fn, epoch_idx, grad_scaler, optimizer, local_rank, callbacks,device, args, best_iou):
    """Run one training epoch and return ``(mean_batch_loss, best_iou)``.

    Each batch's query points come as a list ``batch['points']`` and are
    processed one entry ("mini" chunk) at a time. For each chunk:

    1. An output graph ``batch['o_graph']`` is built. Each query point is linked
       to the input-graph neighbours of its nearest input point.
    2. The model runs twice: ``for_flag='i_forw'`` produces ``batch['i_feats']``,
       then ``for_flag='o_forw'`` produces per-query-point predictions.
    3. ``loss_fn`` is applied against ``batch['points.occ']``.

    Gradients accumulate across chunks and batches. The optimizer steps every
    ``args.accumulate_grad_batches`` batches and on the final batch.

    Every ``args.eval_interval`` global iterations (unless ``args.benchmark``),
    validation runs via ``evaluate``. If a callback reports a mean IoU above
    ``best_iou``, the checkpoint is saved to ``args.save_ckpt_path``.

    Args:
        model: Network to train; may be wrapped in DistributedDataParallel.
        train_dataloader / val_dataloader: Training and validation loaders.
        loss_fn: Called as ``loss_fn(predictions, occupancy_targets)``. Predictions are
            treated as logits elsewhere in this function (threshold at 0, sigmoid for IoU).
        epoch_idx: Current epoch index, used for the global iteration counter.
        grad_scaler: AMP GradScaler, used only when ``args.amp`` is set.
        optimizer: Optimizer stepped after gradient accumulation.
        local_rank: Local process rank; only rank 0 shows progress and prints.
        callbacks: Notified per batch and after validation.
        device: CUDA device index used to build index tensors.
        args: Parsed command-line arguments.
        best_iou: Best validation IoU seen so far.

    Returns:
        Tuple of the mean per-batch loss and the (possibly updated) best IoU.
        Each loss is already divided by ``args.accumulate_grad_batches``.
    """
    # Select the distributed backend (torch.distributed on 'cluster',
    # SageMaker smdistributed otherwise).
    if envir.arg.envir=='cluster':
        #global dist
        import torch.distributed as dist
        #global DistributedDataParallel 
        from torch.nn.parallel import DistributedDataParallel
    else:
        #global dist
        import smdistributed.dataparallel.torch.distributed as dist
        #global  DistributedDataParallel 
        from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel

    # Per-batch running metrics for this epoch.
    losses = []
    ious = []
    accs = []
    total_iter=len(train_dataloader)
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    
    #loss_fn = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([10.],device=device))

    for i, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader), unit='batch',
                         desc=f'Epoch {epoch_idx}', disable=(args.silent or local_rank != 0)):
        #*inputs, target = to_cuda(batch)
        batch = to_cuda(batch)
        batch['forw_key'] = 'o' #in eval 'iou'
        
        for callback in callbacks:
            callback.on_batch_start()

        # Per-chunk metrics for this batch.
        #pred = []
        loss = []
        iou = []
        acc = []
        
        with torch.cuda.amp.autocast(enabled=args.amp):

            for ind_mini in range(len(batch['points'])):
                
                #print("i={}, ind_mini={}, device={}, rank={}".format(i,ind_mini,device,local_rank))

                 
                # Create o_graph
                # batch['inputs'] is indexed as (batch, points, 3); this chunk's query
                # points are reshaped to (-1, 3) below.
                batch_size = batch['inputs'].shape[0]
                pcld_size = batch['inputs'].shape[1]
                qpoints_size = batch['points'][ind_mini].shape[1]
                
                # Neighbours per input node: all input points for 'full', else int(args.kmode).
                int_kmode=pcld_size if args.kmode=='full' else int(args.kmode)
                # Row n = source nodes of input node n. NOTE(review): assumes the
                # input_graph edges are grouped by destination with exactly int_kmode
                # edges each; confirm against the data loader.
                src_neighbors=batch['input_graph'].edges()[0].reshape(-1,int_kmode)

                # Pairwise offsets (batch, query, input, 3): input point minus query point.
                diffs = batch['inputs'].unsqueeze(1) - batch['points'][ind_mini].unsqueeze(2)  
                
                distance = torch.norm(diffs,dim=3,p=None)  
                #_,knn = torch.topk(distance,args.knnq,dim=2,largest=False)
                # Index of the nearest input point for every query point: (batch, query).
                _,knn = torch.min(distance,dim=2)
                
                #inds = knn.reshape(-1) + diffs.shape[2] * torch.arange(diffs.shape[0]*diffs.shape[1],device=device)
                #diffs.reshape(-1,3)[inds,None]
               
                
                #src = (knn + (torch.arange(batch_size,device=device)*pcld_size).reshape(-1,1,1)).reshape(-1) 
                # Global (batched-graph) node id of each query point's nearest input point.
                src_k1 = (knn + (torch.arange(batch_size,device=device)*pcld_size).reshape(-1,1)).reshape(-1) 
                # Edge sources: the int_kmode input-graph neighbours of that nearest point.
                src=src_neighbors[src_k1].reshape(-1)
                
                # Convert global ids back to per-sample input-point indices, then to flat
                # indices into diffs.reshape(-1, 3) (one row per (sample, query, input)).
                relative_src=src-torch.arange(batch_size,device=device).repeat_interleave(int_kmode*qpoints_size)*pcld_size
                inds=relative_src+diffs.shape[2]*torch.arange(diffs.shape[0]*diffs.shape[1],device=device).repeat_interleave(int_kmode)

                # Query-point nodes are numbered after all batch_size*pcld_size input nodes;
                # each receives int_kmode incoming edges.
                dst = torch.arange(qpoints_size*batch_size,device=device).repeat_interleave(int_kmode) + batch_size*pcld_size
                batch['o_graph'] = dgl.graph((src,dst))
                # Node positions: all input points first, then all query points (matches ids above).
                batch['o_graph'].ndata['pos'] = torch.cat((batch['inputs'].reshape(-1,3),batch['points'][ind_mini].reshape(-1,3)))
                batch['o_graph'].edata['rel_pos'] = _get_relative_pos(batch['o_graph'])
                
                #batch['o_feats'] = batch['o_feats_list'][ind_mini]
                #batch['o_feats'] = {'1': diffs.reshape(-1,3)[inds,None]}
                # Degree-1 feature per query point: mean offset to its int_kmode source points.
                avg_diffs=torch.mean(diffs.reshape(-1,3)[inds].reshape(-1,int_kmode,3),1,keepdim=True)
                batch['o_feats'] = {'1': avg_diffs}
                
                # Forward Passes
                # NOTE(review): the 'i_forw' pass is recomputed for every chunk of the same batch.
                batch['i_feats'] = model(batch,for_flag='i_forw')
                mini_pred = model(batch,for_flag='o_forw')
                
                # On invalid predictions, dump weights/batch/prediction for debugging;
                # training then continues with this chunk anyway.
                if hasnone(mini_pred)>0:
                    print("Something nasty happened during training. ")
                    state_dict = model.module.state_dict() if isinstance(model, DistributedDataParallel) else model.state_dict()
                    checkpoint = {'state_dict': state_dict,
                                  'batch': batch,
                                  'prediction': mini_pred,
                                 }
                    torch.save(checkpoint,Path(os.path.join(args.base_dir,'NanTrainCheckpoint.pth')))
                               
                #print("Nans Detected in OFORW:", torch.isnan(mini_pred).sum())

                
                # Scale so that accumulated gradients average over accumulate_grad_batches.
                mini_loss = loss_fn(mini_pred, batch['points.occ'][ind_mini].reshape(-1)) / args.accumulate_grad_batches
                

                if local_rank == 0 and ind_mini==0 and i % 20 ==0:
                    print("POINTS OCC GROUND TRUTH:", batch['points.occ'][ind_mini].reshape(-1).mean().item())
                    print("PREDICTIONS:", (mini_pred>0).float().mean().item())
                #torch.distributed.all_reduce(mini_loss)
                

                # Backward per chunk; gradients accumulate until the optimizer step below.
                mini_loss.backward() if not args.amp else grad_scaler.scale(mini_loss).backward()
                
                #pred.append(mini_pred.detach().clone())
                loss.append(mini_loss.item())
 
        
                with torch.no_grad():
                     
                    pred_class = (mini_pred)>0 #mini_pred in logits
                    corr_class = (batch['points.occ'][ind_mini].reshape(-1))>0.5
                    

                    # Accuracy averaged across ranks (kept as a tensor, not a float).
                    mini_acc = torch.mean(torch.eq(pred_class,corr_class).float())
                    if dist.is_initialized():
                        dist.all_reduce(mini_acc) 
                        mini_acc /= world_size

                    # NOTE(review): torch.mean never returns None, so this check is dead code.
                    if mini_acc is None:
                        print("Mini acc is None:", local_rank,device)
                    
                    acc.append(mini_acc)
                    
                    # IoU on sigmoid probabilities, averaged across ranks.
                    mini_iou=compute_iou(torch.sigmoid(mini_pred).reshape(1,-1),batch['points.occ'][ind_mini].reshape(1,-1))
                    if dist.is_initialized():
                        dist.all_reduce(mini_iou)
                        mini_iou /= world_size
                    iou.append(mini_iou.item())
                    
        # gradient accumulation
        # Step every accumulate_grad_batches batches and always on the last batch.
        if (i + 1) % args.accumulate_grad_batches == 0 or (i + 1) == len(train_dataloader):
            if args.gradient_clip:
                # Under AMP, gradients must be unscaled before clipping by norm.
                if args.amp:
                    grad_scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.gradient_clip)
            
            #print("Nan Detected in MODEL before step:", hasnone(model.parameters()))

            optimizer.step() if not args.amp else grad_scaler.step(optimizer)
            #print("Nan Detected in MODEL after step:", hasnone(model.parameters()))
            if args.amp:
                grad_scaler.update()
            model.zero_grad(set_to_none=True)

        # Per-batch means over chunks.
        # NOTE(review): these divide by zero if batch['points'] is empty.
        losses.append(sum(loss)/len(loss))
        ious.append(sum(iou)/len(iou))

        accs.append(sum(acc)/len(acc))
        
        # Periodic validation, keyed on the global iteration count across epochs.
        if not args.benchmark and (
                    (args.eval_interval > 0 and (epoch_idx*total_iter+i+1) % args.eval_interval == 0) ):
            evaluate(model, val_dataloader, callbacks,'MeanIoU/Val', args)
            #evaluate_voxels(model, val_dataloader, callbacks, args)
            model.train()

            for callback in callbacks:
                mean_iou = callback.on_validation_end(epoch_idx*total_iter+i)
                if mean_iou > best_iou:
                    best_iou = mean_iou
                    if args.save_ckpt_path is not None and not args.benchmark:
                        print("FOUND BETTER MODEL WITH IOU:{}. SAVING.....".format(best_iou))
                        # Called on every rank; save_state itself restricts writing to one process.
                        save_state(model, optimizer, epoch_idx, best_iou, args.save_ckpt_path, callbacks)


        if i % 20  ==0 and local_rank==0:
            print("Train Loss:{}".format(sum(losses)/len(losses)))
            print("Train IOU:{}".format(ious[-1])) 
            print("Train Acc:{}".format(sum(accs)/len(accs)))
        
    return np.mean(losses), best_iou


def train(model: nn.Module,
          loss_fn: _Loss,
          train_dataloader: DataLoader,
          val_dataloader: DataLoader,
          callbacks: List[BaseCallback],
          logger: Logger,
          args):
    """Train ``model`` for ``args.epochs`` epochs.

    Setup:
    - The model is moved to the current CUDA device and, when a process group is initialised,
      wrapped in the backend-appropriate DistributedDataParallel.
    - The optimizer is built (FusedAdam, FusedLAMB, or SGD depending on ``args.optimizer``).
    - Training optionally resumes from ``args.load_ckpt_path``.

    During training:
    - A OneCycle LR schedule is stepped once per epoch.
    - Validation and best-model checkpointing happen inside ``train_epoch``.
    - When ``args.save_ckpt_path`` is set (and not benchmarking), a final checkpoint is written
      after the last epoch.
    """
    if envir.arg.envir == 'cluster':
        import torch.distributed as dist
        from torch.nn.parallel import DistributedDataParallel
    else:
        import smdistributed.dataparallel.torch.distributed as dist
        from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel

    device = torch.cuda.current_device()
    local_rank = get_local_rank()

    model = model.to(device=device)
    if dist.is_initialized():
        if envir.arg.envir == 'cluster':
            model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)
        else:
            model = DistributedDataParallel(model)
            model.cuda(local_rank)

    print("Local rank:", local_rank)
    world_size = dist.get_world_size() if dist.is_initialized() else 1

    model.train()
    grad_scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
    if args.optimizer == 'adam':
        optimizer = FusedAdam(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999),
                              weight_decay=args.weight_decay)
    elif args.optimizer == 'lamb':
        optimizer = FusedLAMB(model.parameters(), lr=args.learning_rate, betas=(args.momentum, 0.999),
                              weight_decay=args.weight_decay)
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    # load_state returns (epoch, best_iou). Without a checkpoint, start fresh. The previous
    # `... else 0` tried to unpack a bare int and crashed whenever no checkpoint was given.
    if args.load_ckpt_path:
        epoch_start, best_iou = load_state(model, optimizer, args.load_ckpt_path, callbacks)
    else:
        epoch_start, best_iou = 0, 0

    # One scheduler step per epoch; last_epoch resumes the schedule at epoch_start.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=args.learning_rate, total_steps=args.epochs,
                                                    pct_start=0.001, anneal_strategy='linear', cycle_momentum=False,
                                                    final_div_factor=10000, last_epoch=epoch_start - 1)
    print("Starting from epoch:", epoch_start)

    for callback in callbacks:
        callback.on_fit_start(optimizer, args)

    for epoch_idx in range(epoch_start, args.epochs):
        if isinstance(train_dataloader.sampler, DistributedSampler):
            train_dataloader.sampler.set_epoch(epoch_idx)

        loss, best_iou = train_epoch(model, train_dataloader, val_dataloader, loss_fn, epoch_idx, grad_scaler,
                                     optimizer, local_rank, callbacks, device, args, best_iou)
        if dist.is_initialized():
            # Average the epoch loss across ranks.
            loss = torch.tensor(loss, dtype=torch.float, device=device)
            dist.all_reduce(loss)
            loss = (loss / world_size).item()

        logging.info(f'Train loss: {loss}')
        logger.log_metrics({'train loss': loss}, epoch_idx)
        print("Train Loss:{}......".format(loss))

        for callback in callbacks:
            callback.on_epoch_end()

        # Validation and best-model checkpointing are done inside train_epoch.
        scheduler.step()

    if args.save_ckpt_path is not None and not args.benchmark:
        # best_iou was previously omitted here, shifting every later argument and raising TypeError.
        # NOTE(review): this writes the final-epoch weights to the same path used for the
        # best-IoU checkpoint; confirm that overwrite is intended.
        save_state(model, optimizer, args.epochs, best_iou, args.save_ckpt_path, callbacks)

    for callback in callbacks:
        callback.on_fit_end()


def print_parameters_count(model):
    """Log, at INFO level, how many of ``model``'s parameters require gradients."""
    trainable_params = [param for param in model.parameters() if param.requires_grad]
    total = sum(param.numel() for param in trainable_params)
    logging.info(f'Number of trainable parameters: {total}')


if __name__ == '__main__':
    # Script entry point:
    # 1. Parse args and initialise the environment/distributed backend.
    # 2. Build the config, output directories, loggers, data module and model.
    # 3. Run train().
    args = PARSER.parse_args()
    envir.init(args)
    # Bind the backend-specific `dist` / `DistributedDataParallel` at module scope
    # (torch.distributed on 'cluster', SageMaker smdistributed otherwise).
    if envir.arg.envir == 'cluster':
        import torch.distributed as dist
        from torch.nn.parallel import DistributedDataParallel
    else:
        import smdistributed.dataparallel.torch.distributed as dist
        from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel
    is_distributed = init_distributed()
    local_rank = get_local_rank()
    print("local rank:", local_rank)

    # Load the config referenced by args.config, then overlay command-line args (cfg_update).
    CFG = config.load_config(args.config, None)
    cfg_update(CFG, vars(args))
    if envir.arg.envir == 'aws':
        # On SageMaker the dataset location comes from the training channel env var.
        CFG['data']['path'] = os.environ['SM_CHANNEL_TRAINING']
    rank = get_rank()
    print("rank:", rank)

    if envir.arg.envir == 'aws':
        torch.cuda.set_device(local_rank)

    # Only global rank 0 logs at INFO (unless --silent); every other rank logs CRITICAL only.
    logging.getLogger().setLevel(logging.CRITICAL if rank != 0 or args.silent else logging.INFO)

    logging.info('====== SE(3)-Transformer ======')
    logging.info('|      Training procedure     |')
    logging.info('===============================')

    if args.seed is not None:
        logging.info(f'Using seed {args.seed}')
        seed_everything(args.seed)

    # Make possible relative paths absolute. If already absolute paths base_dir ignored.
    if args.base_dir:
        print("----------------------------------")

        # os.makedirs accepts both str and Path; Path.mkdir() failed when base_dir was a str.
        os.makedirs(args.base_dir, exist_ok=True)
        print("Check point directory", os.listdir(args.base_dir))

        args.log_dir = Path(os.path.join(args.base_dir, args.log_dir, "")) if args.log_dir is not None else None
        args.vis_dir = Path(os.path.join(args.base_dir, args.vis_dir, "")) if args.vis_dir is not None else None
        if envir.arg.envir == 'cluster':
            args.save_ckpt_path = Path(os.path.join(args.base_dir, args.save_ckpt_path)) if args.save_ckpt_path is not None else None
        else:
            # SageMaker uploads whatever ends up in SM_MODEL_DIR.
            args.save_ckpt_path = Path(os.path.join(os.environ["SM_MODEL_DIR"], args.save_ckpt_path)) if args.save_ckpt_path is not None else None

        args.load_ckpt_path = Path(os.path.join(args.base_dir, args.load_ckpt_path)) if args.load_ckpt_path is not None else None

    loggers = []
    if args.wandb:
        print("WANDB")
        loggers.append(WandbLogger(name=f'QM9({args.task})', save_dir=args.log_dir, project='se3-transformer'))
    if args.tensorboard:
        print("Tensorboard")
        loggers.append(TensorBoardLogger(name='Tesnorboard', save_dir=args.log_dir))

    logger = LoggerCollection(loggers)

    datamodule = ShapeNetModule(cfg=CFG, **vars(args))

    if local_rank == 0:
        print("''''''''''''''''''''''''''DATA MODULE FINISHED'''''''''''''''''''''''")
        print(" ''''''''''''''''''''''''' Parameters '''''''''''''''''''''''''''''''")

        # Record the effective arguments next to the run outputs. The context manager
        # closes the file even if writing fails.
        argsfile = str(args.base_dir) + '/args.txt'
        with open(argsfile, 'wt') as f:
            for arg in vars(args):
                print(arg, getattr(args, arg))
                f.write(str(arg) + ":" + str(getattr(args, arg)) + '\n')
        print("''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''''")
        print(".................CFG.......................")
        print(CFG)

    model = IOSE3Transformer(
        i_fiber_in = Fiber({1: datamodule.NODE_FEATURE_DIM}),
        i_fiber_out = Fiber({1: args.num_degrees * args.num_channels}),
        o_fiber_in = Fiber({1: datamodule.NODE_FEATURE_DIM}),
        o_fiber_out = Fiber({0: 32}),
        i_fiber_edge = Fiber({}),
        o_fiber_edge = Fiber({}),
        i_num_degrees = args.num_degrees,
        i_num_channels = args.num_channels,
        o_num_degrees = args.num_degrees,
        o_num_channels = args.num_channels,
        tensor_cores = using_tensor_cores(args.amp),  # use Tensor Cores more effectively
        output_dim=1,
        **vars(args)
    )

    # The model outputs raw logits; BCEWithLogitsLoss applies the sigmoid internally.
    loss_fn = nn.BCEWithLogitsLoss()

    if is_distributed and envir.arg.envir == 'cluster':
        gpu_affinity.set_affinity(gpu_id=get_local_rank(), nproc_per_node=torch.cuda.device_count())

    callbacks = [ReconMetric(logger, torch.cuda.current_device(), args, metric_name='MeanIoU/Val')]

    print_parameters_count(model)
    train(model,
          loss_fn,
          datamodule.train_dataloader(),
          datamodule.val_dataloader(),
          callbacks,
          logger,
          args)

    logging.info('Training finished successfully')
