import torch
torch.multiprocessing.set_sharing_strategy("file_system")
# from torchvision.models import segmentation
# import timm
import torch.nn as nn
import sys
import copy
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torch.autograd import Variable
from einops import rearrange
import warnings
import scipy.spatial		
from tqdm import tqdm	
import pickle as pkl
# from vit.vit_pytorch.vit import ViT, MViT
import torchvision
# https://htmlcolorcodes.com/color-names/
from colors import color2rgb, color2rgba


import torchvision.transforms as transforms

import argparse
import numpy as np
import time
import os

# from models.mvcnn import *
from util import *
from ops import *
from models.mvtn import *
from models.multi_view import *
from models.voint import *
from models.renderer import *
# from models.segformer import segformer
# from models.swinsegment import  swinsegment


# from logger import Logger
from torch.utils.tensorboard import SummaryWriter
from custom_dataset import ModelNet40, collate_fn, ShapeNetCore, ScanObjectNN, PartNormalDataset  # , ModelNet40


# Dataset sample indices; presumably the samples picked for plotting
# (they are not referenced in this part of the file).
PLOT_SAMPLE_NBS_SEG = [222, 357, 1402, 1984, 1057, 2201, 1355, 1875, 368]
PLOT_SAMPLE_NBS_CLS = [222, 330, 110]


# Named colors used when visualizing segmentation labels; the list position is
# the label slot. COLOR_LABEL_VALUES_LIST holds the matching RGB values
# (converted with colors.color2rgb), in the same order.
color_label_list = [
    "red", "green", "blue", "brown", "purple", "orange",
    "yellow", "black", "red3", "green3", "blue3",
]
COLOR_LABEL_VALUES_LIST = list(map(color2rgb, color_label_list))


# Command-line options. Values read from the YAML `--config_file` are merged on
# top of these further below (config values win on key clashes).
parser = argparse.ArgumentParser(description='MVCNN-PyTorch')
parser.add_argument('--run_mode', '-rmode', default="train_cls", choices=["train_cls", "train_part", "train_ssl", "train_scene", "train_city", "test_cls", "train_pls", "test_pls", "test_part", "train_score", "test_retr", "test_scene", "test_city", "test_rot", "test_occ", "test_score", "train_cls_retr", "train_point", "test_point", "test_views"],
                    help='The mode of running the code: train, test classification, test retrieval, test rotation robustness, or test occlusion robustness, or part segmentation, or scene segmentation, or city segmentation. You have to train before testing')
parser.add_argument('--data_dir', required=True, help='path to 3D dataset')
parser.add_argument('--mvnetwork', '-m', default='resnet', choices=['resnet', 'mvcnn', "vit", "mvit", "wvit", "nvit", "uvit", "viewgcn", "segformer", "fcn", "deeplab", "swinsegment", "clsdeeplab", "PointNet", "DGCNN"],
                    help='the pretrained multi-view network (default: resnet)')
parser.add_argument('--exp_set', type=str, default='00', help='the experiment set identifier')
parser.add_argument('--exp_id', type=str, default='random', help='the experiment identifier')
parser.add_argument('--nb_views', default=4, type=int,
                    help='number of views in MV CNN')
parser.add_argument('--image_size', default=224, type=int,
                    help='the size of the images rendered by the differentiable renderer (another possible value is 384)')
parser.add_argument('--gpu', type=int,
                    default=0, help='index of the GPU to run on')
parser.add_argument('--epochs', default=100, type=int, help='number of total epochs to run (default: 100)')
parser.add_argument('-b', '--batch_size', default=20, type=int,
                    help='mini-batch size (default: 20)')
parser.add_argument('--config_file', '-cfg', default="config.yaml",
                    help='the config yaml file for more options.')

# parser.add_argument('--image_data',required=False,  help='path to 2D dataset')
parser.add_argument('--canonical_elevation', default=30.0, type=float,
                    help='the elevation angle of the canonical view points')
parser.add_argument('--canonical_distance', default=1.0, type=float,
                    help='the distance of the view points from the center of the object, default 1.0 for point clouds and 2.2 for meshes')
parser.add_argument('--views_config', '-s', default="circular", choices=["circular", "random", "learned_offset", "learned_direct", "spherical", "learned_spherical", "learned_random", "learned_transfer", "custom"],
                    help='the selection type of views ')
parser.add_argument('--transform_distance', dest='transform_distance',
                    action='store_true', help='use randomized distance to the object')
parser.add_argument('--plot_freq', default=50, type=int,
                    help='the frequency of plotting the renderings and camera positions')
parser.add_argument('--depth', type=int, default=2,
                    help='resnet depth (default: resnet18)')
parser.add_argument('--input_view_noise', default=0.0, type=float,
                    help='the variance of the gaussian noise (before normalization with parameter range) added to the azim,elev,dist inputs to the MVTN ... this option is valid only if `learned_offset` or `learned_direct` options are selected')


parser.add_argument('--lr', '--learning_rate', default=0.00001, type=float,
                    help='initial learning rate (default: 0.00001)')
parser.add_argument('--weight_decay', default=0.3, type=float,
                    help='weight decay of the multi-view network optimizer (default: 0.3)')
parser.add_argument('-r', '--resume', dest='resume',
                    action='store_true', help='continue training from the `setup[weights_file]` checkpoint')
parser.add_argument('--save_all', dest='save_all',
                    action='store_true', help='save the checkpoint and results at every epoch. By default only the epoch with the best test accuracy is saved')
parser.add_argument('--lr_decay_freq', default=30, type=float,
                    help='how often the learning rate is decayed (default: 30)')
parser.add_argument('--lr_decay', default=0.1, type=float,
                    help='learning rate decay factor (default: 0.1)')
## point cloud rendering
# parser.add_argument('--pc_rendering', dest='pc_rendering',
#                     action='store_true', help='use point cloud renderer instead of mesh renderer  ')
parser.add_argument('--points_radius', default=0.006, type=float,
                    help='the size of the rendered points if `pc_rendering` is True  ')
parser.add_argument('--points_per_pixel', default=1, type=int,
                    help='max number of points in every rendered pixel if `pc_rendering` is True ')
parser.add_argument('--dset_variant', '-dsetp', default="obj_only", choices=["obj_only", "with_bg", "hardest"],
                    help='the variant of the ScanObjectNN dataset')
parser.add_argument('--nb_points', default=2048, type=int, help='number of points in the 3D dataset sampled from the meshes')
parser.add_argument('--object_color', '-clr', default="white", choices=["white", "random", "black", "red", "green", "blue", "learned", "custom"],
                    help='the color of the rendered object')
parser.add_argument('--background_color', '-bgc', default="white", choices=["white", "random", "black", "red", "green", "blue", "learned", "custom"],
                    help='the color of the background of the rendered images')
parser.add_argument('--augment_training', dest='augment_training',
                    action='store_true', help='augment the training of the CNN by scaling , rotation , translation , etc ')
parser.add_argument('--rotated_test', dest='rotated_test',
                    action='store_true', help='test on rotation noise on the meshes from ModelNet40 to make it realistic')
parser.add_argument('--rotated_train', dest='rotated_train',
                    action='store_true', help='train on rotation noise on the meshes from ModelNet40 to make it realistic')

#### MV Transformer
parser.add_argument('--patch_size', default=16, type=int, help='patch size of the MVIT ')
# parser.add_argument('--feat_dim', default=768, type=int,
#                     help='te dimension of the multi-view transfomerer features')
parser.add_argument('--mlp_dim', default=2048, type=int, help='the dimension of the MLP for classification')
parser.add_argument('--mvit_heads', default=12, type=int, help='the number of heads in multihead self-attention in the mvit')
parser.add_argument('--mvit_dropout', default=0.0, type=float, help='the dropout at the main mvit')
parser.add_argument('--emb_dropout', default=0.1, type=float, help='the dropout at the embedding of the mvit')
parser.add_argument('--mv_agr_type', '-mvaggr', default="max",
                    choices=["max", "mean"], help='pool type of the multi-view setup')
parser.add_argument('--vit_agr_type', '-vitaggr', default="cls",
                    choices=["mean", "cls"], help='pool type must be either cls (cls token) or mean (mean pooling)')
parser.add_argument('--nb_windows', default=1, type=int,
                    help='the number of windows if `mvnetwork` == `wvit`. if it is 1 it collapses to mvit , if it is = nb_views, it collapses to vit ')
parser.add_argument('--vit_variant', default="vit",
                    choices=["vit", "swin", "vit_deit"], help='the type of the vision transformer used: vanilla vit or swin or ...')
parser.add_argument('--vit_model_size', default="base",
                    choices=["base", "small", "tiny", "large", "huge"], help='the size of the vision transformer model')
parser.add_argument('--swin_window_size', default=7, type=int,
                    help='window size of Swin transformer (if `vit_variant` == `swin`) ')
# parser.add_argument('--pretrained_21k', dest='pretrained_21k',
#                     action='store_true', help='use the pretrained weights on ImageNet 22K if available , else the regular 1K imageNEt  ')
# parser.add_argument('--pretrained', dest='pretrained',
#                     action='store_true', help='use pre-trained 2D network')
parser.add_argument('--pretraining_mode', '-prmode', default="imagenet", choices=["imagenet", "imagenet21k", "ssl", "fsl", "scratch", "coco", "modelnet"],
                    help='The mode of pretraining the 2D network used in the pipeline ssl:Self-supervised learning , fsl:Fully-supervised learning ')
# part Segmentation
# parser.add_argument('--post_process_iters', default=1, type=int,
#                     help='the number of post processing iterations with nearest neighbor propogation , if 0 : no post processing in evaluation ')
# parser.add_argument('--post_process_k', default=10, type=int,
#                     help='the number of K neightbor used in post processing grows with power of k every iteration  if 0 : no post processing in evaluation ')
# parser.add_argument('--parallel_head', dest='parallel_head',
#                     action='store_true', help='do segmntation as parallel  heads whwere each head is focused on one class ')
parser.add_argument('--lifting_method', default="mode",
                    choices=["mode", "mlp", "mean", "view_attention", "pixel_attention", "point_attention", "max", "gcn", "transformer", "gat"], help='the type of operation used to lift the 2d predictions to 3d predictions')
parser.add_argument('--record_extra_metrics', dest='record_extra_metrics',
                    action='store_true', help='record_extra_metrics like the percentage of every class in points/pixels to its IOU')
# parser.add_argument('--balanced_object_loss', dest='balanced_object_loss',
#                     action='store_true', help='do focal loss on the part segmentation based on the class frequency')
parser.add_argument('--clip_grads', dest='clip_grads',
                    action='store_true', help='clip the gradients with L2 = `clip_grads_value`')
parser.add_argument('--clip_grads_value', default=50.0, type=float,
                    help='the clip value for the L2 of the gradients')
parser.add_argument('--lambda_l2d', default=1.0, type=float,
                    help='the 2D CE loss coefficient on the segmentation pipeline  ')
parser.add_argument('--lambda_l3d', default=0.003, type=float,
                    help='the 3D CE loss coefficient on the segmentation pipeline')
parser.add_argument('--use_mlp_classifier', dest='use_mlp_classifier',
                    action='store_true', help='use shared mlp for segmentation after getting the 2D features ')
parser.add_argument('--use_xyz', dest='use_xyz',
                    action='store_true', help='use xyz appended to lifted features. only if `use_mlp_classifier` == True ')
parser.add_argument('--use_global_feats', dest='use_global_feats',
                    action='store_true', help='use global features by max pooling in segmentation. only if `use_mlp_classifier` == True')
parser.add_argument('--freeze_2d_net', dest='freeze_2d_net',
                    action='store_true', help='freeze the segmentation 2D network and only train the 3D part')
parser.add_argument('--mlp_learning_rate', default=0.001, type=float,
                    help='initial learning rate of the lifting module optimizer (default: 0.001)')
parser.add_argument('--balanced_2d_loss_alpha', default=0.0, type=float,
                    help='do focal loss on the part segmentation based on the 2D label class frequency. when alpha=0.0 , no balance happens')
parser.add_argument('--balanced_3d_loss_alpha', default=0.0, type=float,
                    help='do focal loss on the part segmentation based on the 3D label class frequency. when alpha=0.0 , no balance happens')
parser.add_argument('--extra_net', default="none",
                    choices=["none", "PointNet", "DGCNN", ], help='the last point network used in the segmentation pipeline (optional). default = none')


# parser.add_argument('--run_full_eval', dest='run_full_eval',
#                     action='store_true', help='run the full evaluation with fixed spherical views' )
# parser.add_argument('--full_eval_views', default=30, type=int,
#                     help='the number of views in test if `run_full_eval` == True ')

# parser.add_argument('--mixed_precision', dest='mixed_precision',
#                     action='store_true', help='use mixed precision operations: WARNING: it can destroy the training but make inference faster ' )
# voints
parser.add_argument('--use_view_info', dest='use_view_info',
                    action='store_true', help='use the view information (azim, elev) in learning the voints')
parser.add_argument('--view_embedding_dim', default=24, type=int,
                    help='the dimension of embedding the views if `use_view_info` == True ')
# NOTE: the option name keeps its historical spelling ("embeddgin") so existing
# command lines and code reading setup["view_embeddgin_type"] keep working.
parser.add_argument('--view_embeddgin_type', default="none",
                    choices=["none", "zeros", "sin", "learned", "fourier"], help='the type view embedding used in voint learning if `use_view_info` == True ')
parser.add_argument('--voint_aggr', default="max",
                    choices=["max", "mean"], help='the type aggregation used to convert voints to points ')
parser.add_argument('--use_cls_voint', dest='use_cls_voint',
                    action='store_true', help='use the aggregated voints (classification virtual voint) and append it to all voints feats ')
parser.add_argument('--feat_dim', default=64, type=int,
                    help='the dimension of the multi-view transformer features')
# NOTE: "leanred" is the historical spelling of this option; kept for compatibility.
parser.add_argument('--leanred_cls_token', dest='leanred_cls_token',
                    action='store_true', help='use a learned cls view token instead of zeros as initialization when `lifting_method` in [`gcn`,`transformer`, `gat`]')
parser.add_argument('--use_voint_xyz', dest='use_voint_xyz',
                    action='store_true', help='use xyz appended to lifted features. only if `lifting_method` in [`mlp`,`gcn`,`gat`] ')
parser.add_argument('--voint_depth', type=int, default=4,
                    help='resnet depth (default: resnet18)')
parser.add_argument('--use_early_voint_feats', dest='use_early_voint_feats',
                    action='store_true', help='use early voint segmentation features instead of the logits')
parser.add_argument('--voint_out_size', default=64, type=int,
                    help='the dimension of the output of the vointnets (default: 64). It is also the input size of the mlp classifier if `use_mlp_classifier` == True')

# scene segmentation 
def reinitilize_setup(setup, models_bag, new_nb_views=24, views_config="spherical"):
    """Reload trained weights and rebuild the view selector and renderer for a new view count.

    Steps:
    1. Sets ``setup["nb_views"] = new_nb_views``.
    2. Loads ``setup["weights_file"]`` into ``models_bag``, ignoring the
       optimizer state.
    3. Replaces ``models_bag["mvtn"]`` and ``models_bag["mvrenderer"]`` with
       freshly constructed CUDA modules configured for the new number of views
       and ``views_config``.

    NOTE(review): the MVTN is rebuilt *after* the checkpoint is loaded. Any
    MVTN weights restored by ``load_checkpoint`` are therefore discarded.

    Args:
        setup: the run configuration dict; mutated in place.
        models_bag: dict of models/optimizers; mutated in place.
        new_nb_views: number of views to use from now on.
        views_config: view-selection type for the rebuilt MVTN.

    Returns:
        The same ``(setup, models_bag)`` objects.
    """
    setup["nb_views"] = new_nb_views
    load_checkpoint(setup, models_bag, setup["weights_file"], ignore_optimizer=True)
    models_bag["mvtn"] = MVTN(setup["nb_views"], views_config=views_config,
                              canonical_elevation=setup["canonical_elevation"], canonical_distance=setup["canonical_distance"],
                              shape_features_size=setup["features_size"], transform_distance=setup["transform_distance"],
                              input_view_noise=setup["input_view_noise"], shape_extractor=setup["shape_extractor"],
                              # Read the flag from the setup. Fall back to True (the truthy value
                              # this call used to receive unconditionally) when the key is absent.
                              screatch_feature_extractor=setup.get("screatch_feature_extractor", True)).cuda()
    models_bag["mvrenderer"] = MVRenderer(nb_views=setup["nb_views"], image_size=setup["image_size"], pc_rendering=setup["pc_rendering"], object_color=setup["object_color"], background_color=setup["background_color"],
                                          faces_per_pixel=setup["faces_per_pixel"], points_radius=setup["points_radius"], points_per_pixel=setup["points_per_pixel"], light_direction=setup["light_direction"], cull_backfaces=setup["cull_backfaces"]).cuda()

    return setup, models_bag

def reduce_part_seg_prediction(outputs, parts_nb, parts_range):
    """Pick, per point, the best-scoring part among each sample's own part slots.

    ``outputs`` is indexed as ``(batch, class_slots, points)``. For sample
    ``b`` only the slots ``parts_range[b] : parts_range[b] + parts_nb[b]`` are
    considered. The returned index is relative to ``parts_range[b]``.

    Returns:
        A ``torch.long`` tensor of shape ``(batch, points)`` on ``outputs.device``.
    """
    batch_size, _, nb_points = outputs.shape
    result = outputs.new_zeros((batch_size, nb_points), dtype=torch.long)
    for idx in range(batch_size):
        start = parts_range[idx]
        stop = start + parts_nb[idx]
        window = outputs[idx, start:stop, :].detach()
        result[idx] = window.argmax(dim=0)
    return result

args = parser.parse_args()
args = vars(args)
config = read_yaml(args["config_file"], flatten=True, ignore_hierarchy=True)
# values from the YAML config override command-line values on key clashes
setup = {**args, **config}
initialize_setup(setup)

print('Loading data')

transform = None
scaler = torch.cuda.amp.GradScaler()  # loss scaler used when setup["mixed_precision"] is enabled

torch.cuda.set_device(int(setup["gpu"]))
# The dataset type is chosen from a substring of the (lower-cased) data directory path.
if "modelnet" in setup["data_dir"].lower():
    dset_train = ModelNet40(setup["data_dir"], "train", nb_points=setup["nb_points"], simplified_mesh=setup["simplified_mesh"], cleaned_mesh=setup["cleaned_mesh"], dset_norm=setup["dset_norm"], return_points_saved=setup["return_points_saved"],
                            is_rotated=setup["rotated_train"])
    dset_val = ModelNet40(setup["data_dir"], "test", nb_points=setup["nb_points"], simplified_mesh=setup["simplified_mesh"], cleaned_mesh=setup["cleaned_mesh"], dset_norm=setup["dset_norm"], return_points_saved=setup["return_points_saved"],
                          is_rotated=setup["rotated_test"])
    classes = dset_train.classes
    parts_per_class = [1] * len(classes)

elif "shapenetcore" in setup["data_dir"].lower():
    dset_train = ShapeNetCore(setup["data_dir"], ("train",), setup["nb_points"], load_textures=False, dset_norm=setup["dset_norm"], simplified_mesh=setup["simplified_mesh"])
    dset_val = ShapeNetCore(setup["data_dir"], ("test",), setup["nb_points"], load_textures=False, dset_norm=setup["dset_norm"], simplified_mesh=setup["simplified_mesh"])
    classes = dset_val.classes
    parts_per_class = [1] * len(classes)

elif "scanobjectnn" in setup["data_dir"].lower():
    dset_train = ScanObjectNN(setup["data_dir"], 'train', setup["nb_points"],
                              variant=setup["dset_variant"], dset_norm=setup["dset_norm"])
    dset_val = ScanObjectNN(setup["data_dir"], 'test', setup["nb_points"], variant=setup["dset_variant"], dset_norm=setup["dset_norm"])
    classes = dset_train.classes
    parts_per_class = [1] * len(classes)

elif "part" in setup["data_dir"].lower():
    dset_train = PartNormalDataset(root=setup["data_dir"], npoints=setup["nb_points"], split='trainval', class_choice=None, normal_channel=setup["use_normals"], is_rotated=setup["rotated_train"])
    dset_val = PartNormalDataset(root=setup["data_dir"], npoints=setup["nb_points"], split='test',
                                 class_choice=None, normal_channel=setup["use_normals"], is_rotated=setup["rotated_test"])
    parts_per_class = dset_train.parts_per_class
    classes = sorted(list(dset_train.seg_classes.keys()))

else:
    # Fail early with a clear message instead of a NameError on dset_train below.
    raise ValueError(
        "unsupported dataset in data_dir={!r}: the path must contain one of "
        "'modelnet', 'shapenetcore', 'scanobjectnn' or 'part'".format(setup["data_dir"]))

train_loader = DataLoader(dset_train, batch_size=setup["batch_size"],
                          shuffle=True, num_workers=4, collate_fn=collate_fn, drop_last=True)

val_loader = DataLoader(dset_val, batch_size=int(setup["batch_size"]),
                        shuffle=False, num_workers=4, collate_fn=collate_fn)

print("classes nb:", len(classes), "number of train models: ", len(
    dset_train), "number of test models: ", len(dset_val), classes)

# Build the networks. The part count is the largest number of parts of any class
# (1 for the classification datasets).
mvnetwork = get_mvnetwork(setup, num_classes=len(classes), num_parts=max(parts_per_class))
mvnetwork.cuda()
lifting_net = get_liftingnet(setup, num_classes=len(classes), num_parts=max(parts_per_class))
cudnn.benchmark = True

print('Running on ' + str(torch.cuda.current_device()))


# Loss and Optimizer
lr = setup["lr"]
lr_mlp = setup["mlp_learning_rate"]
n_epochs = setup["epochs"]
mvtn = MVTN(setup["nb_views"], views_config=setup["views_config"],
            canonical_elevation=setup["canonical_elevation"], canonical_distance=setup["canonical_distance"],
            shape_features_size=setup["features_size"], transform_distance=setup["transform_distance"],
            input_view_noise=setup["input_view_noise"], shape_extractor=setup["shape_extractor"],
            # Read the flag from the setup. Fall back to True (the truthy value this
            # call used to receive unconditionally) when the key is absent.
            screatch_feature_extractor=setup.get("screatch_feature_extractor", True)).cuda()
mvrenderer = MVRenderer(nb_views=setup["nb_views"], image_size=setup["image_size"], pc_rendering=setup["pc_rendering"], object_color=setup["object_color"], background_color=setup["background_color"],
                        faces_per_pixel=setup["faces_per_pixel"], points_radius=setup["points_radius"], points_per_pixel=setup["points_per_pixel"], light_direction=setup["light_direction"], cull_backfaces=setup["cull_backfaces"]).cuda()
mlp_classifier = get_mlp_classifier(setup, num_classes=len(classes), num_parts=max(parts_per_class))
mvlifting = MVLiftingModule(image_size=setup["image_size"], lifting_method=setup["lifting_method"], mlp_classifier=mlp_classifier, balanced_object_loss=setup["balanced_object_loss"], balanced_3d_loss_alpha=setup["balanced_3d_loss_alpha"], lifting_net=lifting_net, use_early_voint_feats=setup["use_early_voint_feats"]).cuda()
print(setup)
criterion = nn.CrossEntropyLoss()
views_criterion = nn.CosineSimilarity()
optimizer = torch.optim.AdamW(
    mvnetwork.parameters(), lr=lr, weight_decay=setup["weight_decay"])
# Optional optimizers: None when the corresponding part is not being learned.
if setup["is_learning_views"]:
    mvtn_optimizer = torch.optim.AdamW(mvtn.parameters(), lr=setup["vs_learning_rate"], weight_decay=setup["vs_weight_decay"])
else:
    mvtn_optimizer = None
if setup["learning_lifting"]:
    mlp_optimizer = torch.optim.AdamW(mvlifting.parameters(), lr=setup["mlp_learning_rate"])
else:
    mlp_optimizer = None

models_bag = {"mvnetwork": mvnetwork,
              "optimizer": optimizer, "mvtn": mvtn, "mvtn_optimizer": mvtn_optimizer, "mvrenderer": mvrenderer, "mvlifting": mvlifting, "mlp_optimizer": mlp_optimizer}



def train(data_loader, models_bag, setup):
    """Run one epoch of multi-view classification training.

    For each batch:
    1. ``models_bag["mvtn"]`` predicts the view angles/distances.
    2. ``models_bag["mvrenderer"]`` renders the views; this step is kept out
       of autocast.
    3. ``models_bag["mvnetwork"]`` classifies the rendered views.
    4. The module-level ``criterion`` loss is back-propagated. The network
       optimizer is stepped, plus the MVTN optimizer when
       ``setup["is_learning_views"]`` is set.

    Args:
        data_loader: yields ``(targets, meshes, points)`` batches.
        models_bag: dict holding the models and optimizers.
        setup: run configuration dict.

    Returns:
        ``(avg_train_acc, avg_loss)``: accuracy in percent over all samples
        and the mean per-batch loss.
    """
    train_size = len(data_loader)
    total = 0.0
    correct = 0.0
    total_loss = 0.0
    n = 0
    learning_views = setup["is_learning_views"]
    use_amp = setup["mixed_precision"]

    for i, (targets, meshes, points) in enumerate(data_loader):
        c_batch_size = targets.shape[0]

        models_bag["optimizer"].zero_grad()
        if learning_views:
            models_bag["mvtn_optimizer"].zero_grad()

        with torch.cuda.amp.autocast(bool(use_amp)):
            # view prediction and rendering run with autocast disabled
            with torch.cuda.amp.autocast(False):
                azim, elev, dist = models_bag["mvtn"](points, c_batch_size=c_batch_size)
                rendered_images, _ = models_bag["mvrenderer"](meshes, points, azim=azim, elev=elev, dist=dist)
            rendered_images = regualarize_rendered_views(rendered_images, setup["view_reg"], setup["augment_training"], setup["crop_ratio"])
            targets = Variable(targets.cuda())
            outputs = models_bag["mvnetwork"](rendered_images)[0]

            loss = criterion(outputs, targets.to(torch.long))
            _, predicted = torch.max(outputs.data, 1)
            total += targets.size(0)

        if use_amp:
            scaler.scale(loss).backward()
            scaler.unscale_(models_bag["optimizer"])
            scaler.step(models_bag["optimizer"])
            if learning_views:
                # the view-selector gradients are zeroed every iteration, so step it too
                scaler.step(models_bag["mvtn_optimizer"])
            # GradScaler.update() must run once per iteration, after all optimizer steps
            scaler.update()
        else:
            loss.backward()
            models_bag["optimizer"].step()
            if learning_views:
                models_bag["mvtn_optimizer"].step()

        if setup["log_metrics"]:
            step = i + setup["c_epoch"] * train_size
            writer.add_scalar('Zoom/loss', loss.item(), step)
            # skip parameters that received no gradient (grad is None)
            grads = [p.grad for p in models_bag["mvnetwork"].parameters() if p.grad is not None]
            writer.add_scalar('Zoom/MVCNN_grads', np.sum(np.array([np.sum(g.cpu().numpy() ** 2) for g in grads])), step)

        if (i + 1) % setup["print_freq"] == 0:
            print("\tIter [%d/%d] Loss: %.4f" % (i + 1, train_size, loss.item()))
        correct += (predicted.cpu() == targets.cpu()).sum().item()
        total_loss += loss.detach().item()
        n += 1
    avg_loss = total_loss / n
    avg_train_acc = 100 * correct / total

    return avg_train_acc, avg_loss

def train_part_seg(data_loader, models_bag, setup):
    """Run one epoch of multi-view part-segmentation training.

    For each batch, the pipeline is:
    1. Render the point cloud from the MVTN views (outside autocast).
    2. Segment the rendered views with ``models_bag["mvnetwork"]`` (2D loss).
    3. Lift the 2D predictions back to the points with
       ``models_bag["mvlifting"]`` (3D loss).
    4. Optimize ``lambda_l2d * loss2d + lambda_l3d * loss3d``.

    The 2D network optimizer is stepped unless ``setup["freeze_2d_net"]`` is
    set. The lifting optimizer is stepped when ``setup["learning_lifting"]`` is
    set.

    Args:
        data_loader: yields ``(points, cls, seg, parts_range, parts_nb, _)`` batches.
        models_bag: dict holding the models and optimizers.
        setup: run configuration dict.

    Returns:
        ``(avg_train_acc, avg_loss)``:
        - ``avg_train_acc`` is the 2D pixel accuracy in percent. Correct
          pixels are counted where the rendering mask is set; the denominator
          is ``pix_to_face_mask.sum()``.
        - ``avg_loss`` is the mean per-batch loss.
    """
    train_size = len(data_loader)
    correct = 0.0
    total_loss = 0.0
    total = 0
    n = 0
    use_amp = setup["mixed_precision"]
    learning_lifting = setup["learning_lifting"]
    # Label 0 is ignored by the loss; labels are shifted by +1 below for this purpose.
    # The criterion is stateless, so one instance serves both the 2D and 3D losses.
    seg_criterion = nn.CrossEntropyLoss(ignore_index=0, reduction="none" if setup["balanced_object_loss"] else "mean")

    for i, (points, cls, seg, parts_range, parts_nb, _) in enumerate(data_loader):
        models_bag["optimizer"].zero_grad()
        if learning_lifting:
            models_bag["mlp_optimizer"].zero_grad()

        c_batch_size = points.shape[0]
        colors = []
        # normals stay None when the dataset provides none. Previously they were undefined
        # and the compute_views_weights call below raised NameError.
        # NOTE(review): confirm compute_views_weights accepts None here.
        normals = None
        if setup["use_normals"]:
            # points carry xyz in channels 0:3 and normals in 3:6; normals become the render colors
            normals = points[:, :, 3:6]
            colors = (normals + 1.0) / 2.0
            colors = colors / torch.norm(colors, dim=-1, p=float(setup["color_normal_p"]))[..., None]
            points = points[:, :, 0:3]
        with torch.cuda.amp.autocast(bool(use_amp)):
            # view prediction and rendering run with autocast disabled
            with torch.cuda.amp.autocast(False):
                azim, elev, dist = models_bag["mvtn"](points, c_batch_size=c_batch_size)
                view_info = torch.cat([azim.unsqueeze(-1), elev.unsqueeze(-1)], dim=-1)
                rendered_images, indxs, distance_weight_maps, _ = models_bag["mvrenderer"](None, points, azim=azim, elev=elev, dist=dist, color=colors)
            rendered_images = regualarize_rendered_views(rendered_images, setup["view_reg"], setup["augment_training"], setup["crop_ratio"])
            cls = cls.cuda()
            cls = Variable(cls)
            seg = seg.cuda()
            points = Variable(points).cuda()
            seg = Variable(seg)
            # shift labels by +1 (0 is reserved / ignored); with parallel heads also make them class-relative
            seg = seg + 1 - parts_range[..., None].cuda().to(torch.int) if setup["parallel_head"] else seg + 1

            labels_2d, pix_to_face_mask = models_bag["mvlifting"].compute_image_segment_label_points(
                points, batch_points_labels=seg, rendered_pix_to_point=indxs, )
            labels_2d = Variable(labels_2d)

            rendered_images = Variable(rendered_images)

            outputs, feats = models_bag["mvnetwork"](rendered_images, cls)
            loss2d = models_bag["mvnetwork"].get_loss(seg_criterion, outputs, labels_2d, cls)
            _, predicted = torch.max(outputs.data, dim=1)
            views_weights = models_bag["mvlifting"].compute_views_weights(azim, elev, rendered_images, normals)
            predictions_3d = models_bag["mvlifting"].lift_2D_to_3D(points, predictions_2d=svctomvc(outputs, nb_views=setup["nb_views"]), rendered_pix_to_point=indxs, views_weights=views_weights, cls=cls, parts_nb=parts_nb, view_info=view_info, early_feats=feats)
            loss3d = models_bag["mvlifting"].get_loss_3d(seg_criterion, predictions_3d, seg, cls)
            loss = setup["lambda_l2d"] * loss2d + setup["lambda_l3d"] * loss3d

        if use_amp:
            scaler.scale(loss).backward()
            # unscale before (optional) gradient clipping below
            scaler.unscale_(models_bag["optimizer"])
            if learning_lifting:
                scaler.unscale_(models_bag["mlp_optimizer"])
        else:
            loss.backward()
        total_loss += loss.detach().item()
        n += 1
        total += pix_to_face_mask.sum().item()
        correct += ((predicted == mvtosv(labels_2d.to(torch.long))) & mvtosv(pix_to_face_mask[:, :, 0, ...])).sum().item()

        if setup["clip_grads"]:
            clip_grads_(models_bag["mvnetwork"].parameters(), setup["clip_grads_value"])
        if not setup["freeze_2d_net"]:
            if use_amp:
                scaler.step(models_bag["optimizer"])
            else:
                models_bag["optimizer"].step()
        if learning_lifting:
            if use_amp:
                scaler.step(models_bag["mlp_optimizer"])
            else:
                models_bag["mlp_optimizer"].step()
        if use_amp:
            # GradScaler.update() must run exactly once per iteration, after every optimizer
            # step. Calling it between steps made the next step unscale the lifting grads twice.
            scaler.update()

        if setup["log_metrics"]:
            step = i + setup["c_epoch"] * train_size
            writer.add_scalar('Zoom/loss', loss.item(), step)
            writer.add_scalar('Zoom/MVCNN_grads', np.sum(np.array([np.sum(x.grad.cpu().numpy() ** 2) for x in filter(lambda y: y.grad is not None, models_bag["mvnetwork"].parameters())])), step)
            writer.add_scalar('Zoom/MVlifting_grads', np.sum(np.array([np.sum(x.grad.cpu().numpy() ** 2) for x in filter(
                lambda y: y.grad is not None, models_bag["mvlifting"].parameters())])), step)

        if (i + 1) % setup["print_freq"] == 0:
            print("\tIter [%d/%d] Loss: %.4f" %
                  (i + 1, train_size, loss.item()))

    avg_loss = total_loss / n
    avg_train_acc = 100 * correct / total

    return avg_train_acc, avg_loss


def train_pls(data_loader, models_bag, setup):
    """Run one training epoch of point-cloud classification via 2D->3D label lifting.

    Every point of a shape is labelled with the shape's class. The 2D network is
    trained on per-pixel class labels and the lifting module on per-point labels.
    The shape prediction is the argmax over classes of the per-point scores
    averaged over points.

    Args:
        data_loader: yields ``(targets, meshes, points)`` batches; ``points`` is
            unpacked as ``(batch, nb_points, channels)``.
        models_bag: dict holding the "mvtn", "mvrenderer", "mvnetwork" and
            "mvlifting" modules and the "optimizer" / "mlp_optimizer" /
            "mvtn_optimizer" optimizers used below.
        setup: experiment configuration dict.

    Returns:
        ``(avg_train_acc, avg_loss)``: shape accuracy in percent and mean batch loss.
    """
    train_size = len(data_loader)
    total = 0.0
    correct = 0.0
    total_loss = 0.0
    n = 0

    for i, (targets, meshes, points) in enumerate(data_loader):
        c_batch_size, nb_points, _ = points.shape
        cls = torch.ones((c_batch_size)).cuda().to(torch.int)
        # Broadcast the shape label to every point: (B,) -> (B, nb_points).
        labels_3d = targets[..., None].repeat(1, nb_points).to(torch.long).cuda()

        models_bag["optimizer"].zero_grad()
        if setup["is_learning_views"]:
            models_bag["mvtn_optimizer"].zero_grad()

        with torch.cuda.amp.autocast(bool(setup["mixed_precision"])):
            # Camera prediction and rendering are kept in full precision.
            with torch.cuda.amp.autocast(False):
                azim, elev, dist = models_bag["mvtn"](
                    points, c_batch_size=c_batch_size)
                rendered_images, indxs, _, _ = models_bag["mvrenderer"](
                    None, points, azim=azim, elev=elev, dist=dist, color=None)
            rendered_images = regualarize_rendered_views(
                rendered_images, setup["view_reg"], setup["augment_training"], setup["crop_ratio"])
            targets = targets.cuda()
            targets = Variable(targets)
            outputs, feats = models_bag["mvnetwork"](rendered_images, cls)
            _, nb_views, _, h, w = rendered_images.shape

            # Every pixel of every view carries the shape label.
            criterion2d = nn.CrossEntropyLoss()
            loss2d = criterion2d(svctomvc(outputs, nb_views=nb_views).transpose(
                1, 2), targets.to(torch.long)[..., None][..., None][..., None].repeat(1, nb_views, h, w))

            views_weights = models_bag["mvlifting"].compute_views_weights(
                azim, elev, rendered_images, points)
            predictions_3d = models_bag["mvlifting"].lift_2D_to_3D(
                points.cuda(), predictions_2d=svctomvc(outputs, nb_views=setup["nb_views"]),
                rendered_pix_to_point=indxs, views_weights=views_weights, cls=cls,
                parts_nb=len(classes) * cls, view_info=None, early_feats=feats)
            criterion3d = nn.CrossEntropyLoss()
            loss3d = criterion3d(predictions_3d, labels_3d)
            loss = setup["lambda_l2d"] * loss2d + setup["lambda_l3d"] * loss3d
            # predictions_3d is (B, C, nb_points) to match labels_3d in the loss above;
            # average over points, then take the best class per shape.
            _, predicted = torch.max(torch.mean(predictions_3d, dim=2), dim=-1)

        if setup["mixed_precision"]:
            scaler.scale(loss).backward()
            # Unscale before clipping so the threshold applies to true gradient values.
            scaler.unscale_(models_bag["optimizer"])
            if setup["learning_lifting"]:
                scaler.unscale_(models_bag["mlp_optimizer"])
        else:
            loss.backward()
        total_loss += loss.detach().item()
        n += 1
        total += targets.size(0)
        correct += (predicted.cpu() == targets.cpu()).sum().item()

        if setup["clip_grads"]:
            clip_grads_(models_bag["mvnetwork"].parameters(),
                        setup["clip_grads_value"])

        if setup["mixed_precision"]:
            # GradScaler contract: step() each optimizer, then update() exactly once.
            # update() resets per-optimizer state; calling it between steps would
            # make the next step() unscale already-unscaled gradients again.
            # update() also has to run when no optimizer is stepped, otherwise the
            # next iteration's unscale_() raises.
            if not setup["freeze_2d_net"]:
                scaler.step(models_bag["optimizer"])
            if setup["learning_lifting"]:
                scaler.step(models_bag["mlp_optimizer"])
            scaler.update()
        else:
            if not setup["freeze_2d_net"]:
                models_bag["optimizer"].step()
            if setup["learning_lifting"]:
                models_bag["mlp_optimizer"].step()

        if setup["log_metrics"]:
            step = i + setup["c_epoch"] * train_size
            writer.add_scalar('Zoom/loss', loss.item(), step)
            writer.add_scalar('Zoom/MVCNN_grads', np.sum(np.array([np.sum(x.grad.cpu().numpy() ** 2) for x in filter(
                lambda y: type(y.grad) != type(None), models_bag["mvnetwork"].parameters())])), step)
            writer.add_scalar('Zoom/MVlifting_grads', np.sum(np.array([np.sum(x.grad.cpu().numpy() ** 2) for x in filter(
                lambda y: type(y.grad) != type(None), models_bag["mvlifting"].parameters())])), step)

        if (i + 1) % setup["print_freq"] == 0:
            print("\tIter [%d/%d] Loss: %.4f" %
                  (i + 1, train_size, loss.item()))

    avg_loss = total_loss / n
    avg_train_acc = 100 * correct / total

    return avg_train_acc, avg_loss


def train_point_seg(data_loader, models_bag, setup):
    """Train a point-based segmentation network for one epoch.

    The network consumes the points channel-first (``points.transpose(1, 2)``)
    and is optimised with a plain per-point cross-entropy on ``seg``. Accuracy
    counts only points flagged in ``real_points_mask`` (presumably padding is
    masked out; TODO confirm against the dataset).

    Returns:
        ``(avg_train_acc, avg_loss)``: masked point accuracy in percent and mean
        batch loss.
    """
    n_iters = len(data_loader)
    optimizer = models_bag["optimizer"]
    network = models_bag["mvnetwork"]
    use_amp = setup["mixed_precision"]

    hits = 0.0
    loss_sum = 0.0
    seen = 0
    batches = 0
    for it, (points, cls, seg, parts_range, parts_nb, real_points_mask) in enumerate(data_loader):
        optimizer.zero_grad()

        with torch.cuda.amp.autocast(bool(use_amp)):
            cls = Variable(cls.cuda())
            seg = Variable(seg.cuda())
            points = Variable(points).cuda()
            real_points_mask = real_points_mask.cuda()

            point_criterion = nn.CrossEntropyLoss()
            logits, feats, _ = network(points.transpose(1, 2))
            _, predicted = torch.max(logits.data, dim=1)
            loss = point_criterion(logits, seg.to(torch.long))

        # Backward pass and parameter update (scaled when AMP is enabled).
        if use_amp:
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        loss_sum += loss.detach().item()
        batches += 1
        seen += real_points_mask.sum().item()
        is_real = real_points_mask.to(torch.bool)
        hits += ((predicted == seg) & is_real).sum().item()

        if setup["log_metrics"]:
            global_step = it + setup["c_epoch"] * n_iters
            writer.add_scalar('Zoom/loss', loss.item(), global_step)
            grad_sq = np.sum(np.array([np.sum(p.grad.cpu().numpy() ** 2)
                                       for p in network.parameters() if p.grad is not None]))
            writer.add_scalar('Zoom/MVCNN_grads', grad_sq, global_step)

        if (it + 1) % setup["print_freq"] == 0:
            print("\tIter [%d/%d] Loss: %.4f" % (it + 1, n_iters, loss.item()))

    avg_loss = loss_sum / batches
    avg_train_acc = 100 * hits / seen
    return avg_train_acc, avg_loss

# Validation and Testing
def evluate(data_loader, models_bag,  setup, is_test=False, retrieval=False):
    """Evaluate a multi-view classifier and, optionally, shape-retrieval mAP.

    Args:
        data_loader: yields ``(targets, meshes, points)`` batches.
        models_bag: dict holding the "mvtn", "mvrenderer" and "mvnetwork" modules.
        setup: experiment configuration dict.
        is_test: if True, first load weights from ``setup["weights_file"]``.
        retrieval: if True, rank the stored training features
            (``setup["feature_file"]`` / ``setup["targets_file"]``, projected with
            the module-level ``lfda``) for every test shape and report mean
            average precision.

    Returns:
        ``(avg_test_acc, avg_loss, views_record, retr_map)``. ``retr_map`` is 0
        when ``retrieval`` is False. ``avg_test_acc`` is a 0-dim tensor because
        ``correct`` accumulates tensor sums.
    """
    if is_test:
        load_checkpoint(setup, models_bag, setup["weights_file"])

    total = 0.0
    correct = 0.0
    total_loss = 0.0
    n = 0
    if retrieval:
        features_training = np.load(setup["feature_file"])
        targets_training = np.load(setup["targets_file"])
        features_training = lfda.transform(features_training)
        kdtree = scipy.spatial.KDTree(features_training)
        # Every query ranks the whole training set.
        # NOTE(review): a top-1000 cutoff for ShapeNetCore used to be computed here
        # but was never applied; confirm full ranking is the intended protocol.
        n_retrieved = len(features_training)
        # Precision denominators for ranks 1..n_retrieved, shared by every query.
        ranks = np.arange(1, n_retrieved + 1)
        all_APs = []

    views_record = ListDict(["azim", "elev", "dist", "label", "view_nb", "exp_id"])
    t = tqdm(enumerate(data_loader), total=len(data_loader))
    for i, (targets, meshes, points) in t:
        c_batch_size = targets.shape[0]
        with torch.no_grad():
            with torch.cuda.amp.autocast(bool(setup["mixed_precision"])):
                # Camera prediction and rendering are kept in full precision.
                with torch.cuda.amp.autocast(False):
                    azim, elev, dist = models_bag["mvtn"](points, c_batch_size=c_batch_size)
                    rendered_images, _ = models_bag["mvrenderer"](meshes, points, azim=azim, elev=elev, dist=dist)
                targets = targets.cuda()
                targets = Variable(targets)
                outputs, feat = models_bag["mvnetwork"](rendered_images)  # return features as well
                if retrieval:
                    feat = lfda.transform(feat.cpu().numpy())
                    _, idx_closest = kdtree.query(feat, k=n_retrieved)
                    query_labels = targets.cpu().numpy()
                    for i_query_batch in range(feat.shape[0]):
                        # details on retrieval-mAP: https://towardsdatascience.com/breaking-down-mean-average-precision-map-ae462f623a52#f9ce
                        positives = targets_training[idx_closest[i_query_batch, :]] == query_labels[i_query_batch]
                        # AP numerator: cumulative count of positives, zeroed at negatives,
                        # so num / ranks is precision@k at each positive rank.
                        num = np.cumsum(positives)
                        num[~positives] = 0
                        # GTP: number of ground-truth positives in the ranking.
                        GTP = np.sum(positives)
                        # A query class absent from the training set would give 0/0 = nan
                        # and poison the mean; count it as AP 0 instead.
                        AP = np.sum(num / ranks) / GTP if GTP > 0 else 0.0
                        all_APs.append(AP)

                loss = criterion(outputs, targets.to(torch.long))
                c_views = ListDict({"azim": azim.cpu().numpy().reshape(-1).tolist(), "elev": elev.cpu().numpy().reshape(-1).tolist(),
                        "dist": dist.cpu().numpy().reshape(-1).tolist(), "label": np.repeat(targets.cpu().numpy(), setup["nb_views"]).tolist(),
                        "view_nb": int(targets.cpu().numpy().shape[0]) * list(range(setup["nb_views"])),
                                    "exp_id": int(targets.cpu().numpy().shape[0]) * int(setup["nb_views"]) * [setup["exp_id"]]  } )
                views_record.extend(c_views)
                total_loss += loss.detach().item()
                n += 1
                _, predicted = torch.max(outputs.data, 1)
                total += targets.size(0)
                correct += (predicted.cpu() == targets.cpu()).sum()

    avg_loss = total_loss / n
    avg_test_acc = 100 * correct / total
    if retrieval:
        retr_map = 100 * sum(all_APs)/len(all_APs)
        print("avg_loss", avg_loss)
        print("avg_test_acc", avg_test_acc)
        print("retr_map", retr_map)
        return avg_test_acc, avg_loss, views_record , retr_map

    return avg_test_acc, avg_loss, views_record , 0


def evluate_part_seg(data_loader, models_bag,  setup, is_test=False):
    """Evaluate multi-view part segmentation lifted from 2D to 3D.

    Args:
        data_loader: yields ``(points, cls, seg, parts_range, parts_nb,
            real_points_mask)`` batches. When ``setup["use_normals"]`` is set,
            channels 3:6 of ``points`` are read as normals.
        models_bag: dict holding "mvtn", "mvrenderer", "mvnetwork" and "mvlifting".
        setup: experiment configuration dict.
        is_test: if True, first load weights from ``setup["weights_file"]``.

    Returns:
        ``(avg_test_acc, mean_cat_iou, mean_inst_iou, avg_loss, point_coverage)``.
        Accuracies and IoUs are in percent. ``point_coverage`` is the percentage
        of real points whose final 3D label is not 0 (0 is the background label).
    """
    if is_test:
        load_checkpoint(setup, models_bag, setup["weights_file"])

    total = 0.0
    correct = 0.0
    total_loss = 0.0
    n = 0
    total_empty = 0
    categ_iou = [[] for _ in range(len(classes))]
    record = ListDict()

    t = tqdm(enumerate(data_loader), total=len(data_loader))
    for i, (points, cls, seg, parts_range, parts_nb, real_points_mask) in t:
        with torch.no_grad():
            c_batch_size = points.shape[0]
            colors = []
            # Must be bound on every path: compute_views_weights below always
            # receives it. Previously this raised NameError when use_normals was off.
            # NOTE(review): None is passed in that case; confirm compute_views_weights
            # accepts it.
            normals = None
            if setup["use_normals"]:
                normals = points[:, :, 3:6]
                # Map normals from [-1, 1] to [0, 1] and normalise them for use as point colours.
                colors = (normals + 1.0) / 2.0
                colors = colors/torch.norm(colors, dim=-1, p=float(setup["color_normal_p"]))[..., None]
                points = points[:, :, 0:3]
            with torch.cuda.amp.autocast(bool(setup["mixed_precision"])):
                # Camera prediction and rendering are kept in full precision.
                with torch.cuda.amp.autocast(False):
                    azim, elev, dist = models_bag["mvtn"](points, c_batch_size=c_batch_size)
                    view_info = torch.cat([azim.unsqueeze(-1), elev.unsqueeze(-1)], dim=-1)
                    rendered_images, indxs, _, _ = models_bag["mvrenderer"](None, points, azim=azim, elev=elev, dist=dist, color=colors)
                cls = cls.cuda()
                cls = Variable(cls)
                points = Variable(points).cuda()
                seg = seg.cuda()
                seg = Variable(seg)
                real_points_mask = real_points_mask.cuda()

                # Shift global part ids to per-object ids starting at 1; label 0 is
                # reserved for background.
                seg = seg + 1 - parts_range[..., None].cuda().to(torch.int)

                labels_2d, pix_to_face_mask = models_bag["mvlifting"].compute_image_segment_label_points(points, batch_points_labels=seg, rendered_pix_to_point=indxs, )

                criterion2d = nn.CrossEntropyLoss(ignore_index=0)
                outputs, feats = models_bag["mvnetwork"](rendered_images, cls)
                loss2d = models_bag["mvnetwork"].get_loss(criterion2d, outputs, labels_2d, cls)

                views_weights = models_bag["mvlifting"].compute_views_weights(azim, elev, rendered_images, normals)

                predictions_3d = models_bag["mvlifting"].lift_2D_to_3D(points, predictions_2d=svctomvc(outputs, nb_views=setup["nb_views"]), rendered_pix_to_point=indxs, views_weights=views_weights, cls=cls, parts_nb=parts_nb, view_info=view_info, early_feats=feats)
                criterion3d = nn.CrossEntropyLoss(ignore_index=0,)
                loss3d = models_bag["mvlifting"].get_loss_3d(criterion3d, predictions_3d, seg, cls)
                loss = setup["lambda_l2d"] * loss2d + setup["lambda_l3d"] * loss3d
                _, predictions_3d = torch.max(predictions_3d, dim=1)
                predictions_3d = post_process_segmentation(points, predictions_3d, iterations=setup["post_process_iters"], K_neighbors=setup["post_process_k"])

            total_loss += loss.detach().item()
            n += 1
            real_mask = real_points_mask.to(torch.bool)
            total += real_points_mask.sum().item()
            total_empty += ((predictions_3d == 0) & real_mask).sum().item()
            correct += ((predictions_3d == seg) & real_mask).sum().item()

            # IoU metrics use 0-based part ids (background becomes -1).
            cur_shape_miou = batch_points_mIOU(seg - 1, predictions_3d - 1, real_mask, parts=parts_nb,)
            if setup["record_extra_metrics"]:
                pixel_perc, point_perc, iou, valid_iou, cls_nb, part_nb = extra_IOU_metrics(
                    seg - 1, predictions_3d - 1, labels_2d-1, pix_to_face_mask, real_mask, cls, parts=parts_nb,)
                c_record = ListDict({"valid_iou": valid_iou, "cls_nb": cls_nb, "part_nb": part_nb, "pixel_perc": pixel_perc, "point_perc": point_perc, "iou": iou})
                record.extend(c_record)
                # The whole accumulated record is rewritten after every batch.
                save_results(setup["views_file"], record)
            for cat in range(len(classes)):
                cat_cur_shape_miou = cur_shape_miou[(cls == cat).view(-1)]
                categ_iou[cat] += cat_cur_shape_miou.cpu().numpy().tolist()

    print("The number of objects per class: ", list(zip(classes, [len(x) for x in categ_iou])))
    all_ious = []
    mean_cat_iou = []
    for cat in range(len(classes)):
        all_ious += categ_iou[cat]
        mean_cat_iou.append(np.mean(np.array(categ_iou[cat]), axis=-1))
    # Instance mIoU averages over shapes; category mIoU averages the per-class means.
    mean_inst_iou = 100 * np.mean(np.array(all_ious))
    print("The mIOU per class average: ", list(zip(classes, mean_cat_iou)))

    mean_cat_iou = 100 * np.mean(np.array(mean_cat_iou))

    avg_loss = total_loss / n
    avg_test_acc = 100 * correct / float(total)
    point_coverage = 100 - 100 * total_empty / float(total)

    return avg_test_acc, mean_cat_iou, mean_inst_iou, avg_loss, point_coverage


def evaluate_pls(data_loader, models_bag, setup):
    """Evaluate point-cloud classification via 2D->3D label lifting.

    Mirrors ``train_pls`` without gradient updates. Every point carries the
    shape label, and the shape prediction is the argmax over classes of the
    per-point scores averaged over points.

    Returns:
        ``(avg_test_acc, avg_loss)``: shape accuracy in percent and mean batch loss.
    """
    total = 0.0
    correct = 0.0
    total_loss = 0.0
    n = 0
    t = tqdm(enumerate(data_loader), total=len(data_loader))

    for i, (targets, meshes, points) in t:
        with torch.no_grad():
            c_batch_size, nb_points, _ = points.shape
            cls = torch.ones((c_batch_size)).cuda().to(torch.int)
            # Broadcast the shape label to every point: (B,) -> (B, nb_points).
            labels_3d = targets[..., None].repeat(
                1, nb_points).to(torch.long).cuda()
            # No optimizer.zero_grad() here: nothing is back-propagated under
            # no_grad, and evaluation should not touch optimizer/gradient state.

            with torch.cuda.amp.autocast(bool(setup["mixed_precision"])):
                # Camera prediction and rendering are kept in full precision.
                with torch.cuda.amp.autocast(False):
                    azim, elev, dist = models_bag["mvtn"](
                        points, c_batch_size=c_batch_size)
                    rendered_images, indxs, _, _ = models_bag["mvrenderer"](
                        None, points, azim=azim, elev=elev, dist=dist, color=None)
                # NOTE(review): this passes setup["augment_training"] exactly as
                # train_pls does; confirm training-time augmentation is intended
                # during evaluation.
                rendered_images = regualarize_rendered_views(
                    rendered_images, setup["view_reg"], setup["augment_training"], setup["crop_ratio"])
                targets = targets.cuda()
                targets = Variable(targets)
                outputs, feats = models_bag["mvnetwork"](rendered_images, cls)
                _, nb_views, _, h, w = rendered_images.shape

                # Every pixel of every view carries the shape label.
                criterion2d = nn.CrossEntropyLoss()
                loss2d = criterion2d(svctomvc(outputs, nb_views=nb_views).transpose(
                    1, 2), targets.to(torch.long)[..., None][..., None][..., None].repeat(1, nb_views, h, w))

                views_weights = models_bag["mvlifting"].compute_views_weights(
                    azim, elev, rendered_images, points)
                predictions_3d = models_bag["mvlifting"].lift_2D_to_3D(points.cuda(), predictions_2d=svctomvc(
                    outputs, nb_views=setup["nb_views"]), rendered_pix_to_point=indxs, views_weights=views_weights, cls=cls, parts_nb=len(classes) * cls, view_info=None, early_feats=feats)
                criterion3d = nn.CrossEntropyLoss()
                loss3d = criterion3d(predictions_3d, labels_3d)
                loss = setup["lambda_l2d"] * loss2d + setup["lambda_l3d"] * loss3d
                # predictions_3d is (B, C, nb_points) to match labels_3d in the loss;
                # average over points, then take the best class per shape.
                _, predicted = torch.max(torch.mean(predictions_3d, dim=2), dim=-1)

            total_loss += loss.detach().item()
            n += 1
            total += targets.size(0)
            correct += (predicted.cpu() == targets.cpu()).sum().item()

    avg_loss = total_loss / n
    avg_test_acc = 100 * correct / total

    return avg_test_acc, avg_loss


def evluate_point_seg(data_loader, models_bag,  setup, is_test=False):
    """Evaluate a point-based part-segmentation network.

    Args:
        data_loader: yields ``(points, cls, seg, parts_range, parts_nb,
            real_points_mask)`` batches.
        models_bag: dict holding the "mvnetwork" module, which receives the
            points channel-first (``points.transpose(1, 2)``).
        setup: experiment configuration dict.
        is_test: if True, first load weights from ``setup["weights_file"]``.

    Returns:
        ``(avg_test_acc, mean_cat_iou, mean_inst_iou, avg_loss)``, with accuracy
        and IoUs in percent and accuracy counted over real points only.
    """
    if is_test:
        load_checkpoint(setup, models_bag, setup["weights_file"])

    total = 0.0
    correct = 0.0
    total_loss = 0.0
    n = 0
    categ_iou = [[] for _ in range(len(classes))]

    t = tqdm(enumerate(data_loader), total=len(data_loader))
    for i, (points, cls, seg, parts_range, parts_nb, real_points_mask) in t:
        with torch.no_grad():
            with torch.cuda.amp.autocast(bool(setup["mixed_precision"])):
                cls = cls.cuda()
                cls = Variable(cls)
                seg = seg.cuda()
                points = Variable(points).cuda()
                seg = Variable(seg)
                real_points_mask = real_points_mask.cuda()

                criterion3d = nn.CrossEntropyLoss()
                outputs, feats, _ = models_bag["mvnetwork"](points.transpose(1, 2))
                # The loss uses the global label space.
                loss = criterion3d(outputs, seg.to(torch.long))
                # Keep only the predictions of each shape's own object category.
                predicted = reduce_part_seg_prediction(outputs, parts_nb, parts_range)
                # Shift ground truth by the object's first part id; assumed to match
                # the ids returned by reduce_part_seg_prediction (TODO confirm).
                seg = seg - parts_range[..., None].cuda().to(torch.int)

            total_loss += loss.detach().item()
            n += 1
            real_mask = real_points_mask.to(torch.bool)
            total += real_points_mask.sum().item()
            correct += ((predicted == seg) & real_mask).sum().item()

            # IoU per shape, using the largest part count of any class for every shape.
            cur_shape_miou = batch_points_mIOU(seg, predicted, real_mask, parts=max(parts_per_class)*torch.ones_like(cls),)
            for cat in range(len(classes)):
                cat_cur_shape_miou = cur_shape_miou[(cls == cat).view(-1)]
                categ_iou[cat] += cat_cur_shape_miou.cpu().numpy().tolist()

    print("The number of objects per class: ", list(
        zip(classes, [len(x) for x in categ_iou])))
    all_ious = []
    mean_cat_iou = []
    for cat in range(len(classes)):
        all_ious += categ_iou[cat]
        mean_cat_iou.append(np.mean(np.array(categ_iou[cat]), axis=-1))
    # Instance mIoU averages over shapes; category mIoU averages the per-class means.
    mean_inst_iou = 100 * np.mean(np.array(all_ious))
    print("The mIOU per class average: ", list(zip(classes, mean_cat_iou)))

    mean_cat_iou = 100 * np.mean(np.array(mean_cat_iou))

    avg_loss = total_loss / n
    avg_test_acc = 100 * correct / float(total)

    return avg_test_acc, mean_cat_iou, mean_inst_iou, avg_loss

def evluate_rotation_robustness(data_loader, models_bag,  setup, max_degs=180.0,):
    """Measure classification accuracy under random rotations about the y axis.

    Each shape in a batch gets its own angle drawn uniformly from
    ``[-max_degs, max_degs]`` (units as expected by ``rotation_matrix``, presumably
    degrees; TODO confirm). The points are rotated by the matching matrix and then
    classified through the mvtn -> mvrenderer -> mvnetwork pipeline.

    Returns:
        ``(avg_test_acc, avg_loss)``: accuracy in percent (a 0-dim tensor, since
        the per-batch hit counts are tensors) and the mean batch loss.
    """
    hits = 0.0
    seen = 0.0
    loss_sum = 0.0
    batches = 0
    for targets, meshes, points in tqdm(data_loader):
        with torch.no_grad():
            batch_size = targets.shape[0]
            y_axis = [0.0, 1.0, 0.0]
            angles = [np.random.rand() * 2.0 * max_degs - max_degs for _ in range(batch_size)]
            rot_mats = np.array([rotation_matrix(y_axis, angle) for angle in angles])

            # NOTE(review): the rotated, textured meshes are built here, but the
            # renderer below receives None instead of them. Confirm whether mesh
            # rendering was intended.
            rotated_verts = [
                torch.mm(torch.from_numpy(rot_mats[k]).to(torch.float),
                         mesh.verts_list()[0].transpose(0, 1)).transpose(0, 1).cuda()
                for k, mesh in enumerate(meshes)
            ]
            mesh_faces = [mesh.faces_list()[0].cuda() for mesh in meshes]
            meshes = Meshes(verts=rotated_verts, faces=mesh_faces, textures=None)
            n_verts_max = meshes.verts_padded().shape[1]
            meshes.textures = Textures(verts_rgb=torch.ones(
                (batch_size, n_verts_max, 3)).cuda())

            # Rotate every point cloud by its own matrix: (B, 3, 3) x (B, 3, N).
            rotations = torch.from_numpy(rot_mats).to(torch.float)
            rotated_points = torch.bmm(rotations, points.transpose(1, 2)).transpose(1, 2)
            azim, elev, dist = models_bag["mvtn"](rotated_points, c_batch_size=batch_size)
            rendered_images, _ = models_bag["mvrenderer"](None, rotated_points, azim=azim, elev=elev, dist=dist)
            targets = Variable(targets.cuda())
            logits = models_bag["mvnetwork"](rendered_images)[0]
            loss = criterion(logits, targets)

            loss_sum += loss.detach().item()
            batches += 1
            _, predicted = torch.max(logits.data, 1)
            seen += targets.size(0)
            hits += (predicted.cpu() == targets.cpu()).sum()

    avg_test_acc = 100 * hits / seen
    avg_loss = loss_sum / batches
    return avg_test_acc, avg_loss

def compute_features(data_loader, models_bag, setup):
    """Run the multi-view pipeline over ``data_loader`` and collect its features.

    For every batch the MVTN predicts camera views and the renderer renders
    them, both with autocast disabled. The multi-view network then returns
    ``(outputs, feat)``, under autocast when ``setup["mixed_precision"]`` is
    truthy. ``feat`` and the labels are gathered on the CPU. The running
    accuracy and loss are only shown in the progress bar.

    Args:
        data_loader: iterable of ``(targets, meshes, points)`` batches.
        models_bag: dict holding the ``"mvtn"``, ``"mvrenderer"`` and
            ``"mvnetwork"`` modules.
        setup: configuration dict; ``"mixed_precision"`` is read.

    Returns:
        ``(features, targets)``: the per-batch ``feat`` arrays and label arrays,
        each concatenated along the first axis.

    Raises:
        ValueError: if ``data_loader`` yields no batches.
    """
    print("compute training metrics and store training features")
    total = 0.0
    correct = 0.0
    total_loss = 0.0
    n = 0
    feat_list = []
    target_list = []
    t = tqdm(enumerate(data_loader), total=len(data_loader))
    for i, (targets, meshes, points) in t:
        c_batch_size = targets.shape[0]
        with torch.no_grad():
            with torch.cuda.amp.autocast(bool(setup["mixed_precision"])):
                # View prediction and rendering always run in full precision.
                with torch.cuda.amp.autocast(False):
                    azim, elev, dist = models_bag["mvtn"](points, c_batch_size=c_batch_size)
                    rendered_images, _ = models_bag["mvrenderer"](meshes, points, azim=azim, elev=elev, dist=dist)
                targets = targets.cuda()
                outputs, feat = models_bag["mvnetwork"](rendered_images)

                feat_list.append(feat.cpu().numpy())
                target_list.append(targets.cpu().numpy())

                # Loss and accuracy only feed the progress-bar description.
                loss = criterion(outputs, targets.to(torch.long))
                total_loss += loss.detach().item()
                n += 1
                _, predicted = torch.max(outputs.data, 1)
                total += targets.size(0)
                correct += (predicted.cpu() == targets.cpu()).sum()
            t.set_description(f"{i} - Acc {100 * correct / total :2.2f} - Loss {total_loss / n:2.6f}")
    if not feat_list:
        raise ValueError("data_loader yielded no batches; no features to concatenate")
    features = np.concatenate(feat_list)
    targets = np.concatenate(target_list)
    return features, targets



   

# Training / Eval loop
# Restore saved weights when resuming training, and for every run mode that
# contains "test" (those modes evaluate previously trained weights).
if setup["resume"] or "test" in setup["run_mode"]:
    load_checkpoint(setup, models_bag, setup["weights_file"])

if "cls" in setup["run_mode"]:
    # Classification does not unproject 2D predictions back to 3D, so the
    # renderer's pixel-to-point mapping is switched off.
    models_bag["mvrenderer"].return_mapping = False

    if setup["log_metrics"]:
        writer = SummaryWriter(setup["logs_dir"])
        writer.add_hparams(setup, {"hparams/best_acc": 0.0})
    setup["best_retr_map"] = 0

    for epoch in range(setup["start_epoch"], n_epochs):
        setup["c_epoch"] = epoch
        print('\n-----------------------------------')
        print('Epoch: [%d/%d]' % (epoch+1, n_epochs))
        start = time.time()
        # Training pass (skipped in test modes, which only evaluate).
        if "train" in setup["run_mode"]:
            models_bag["mvnetwork"].train()
            models_bag["mvtn"].train()
            models_bag["mvrenderer"].train()
            # models_bag["feature_extractor"].train()
            avg_train_acc, avg_train_loss = train(train_loader, models_bag, setup)
            print('Time taken: %.2f sec.' % (time.time() - start))
            print('\ttrain acc: %.2f - train Loss: %.4f' % (avg_train_acc, avg_train_loss))
            if setup["log_metrics"]:
                writer.add_scalar('Loss/train', avg_train_loss, epoch)
                writer.add_scalar('Accuracy/train', avg_train_acc, epoch)

        # Validation pass.
        models_bag["mvnetwork"].eval()
        models_bag["mvtn"].eval()
        models_bag["mvrenderer"].eval()
        # models_bag["feature_extractor"].eval()
        avg_test_acc, avg_loss, views_record, _ = evluate(
            val_loader, models_bag, setup)

        print('\nEvaluation:')
        print('\tVal Acc: %.2f - val Loss: %.4f' % (avg_test_acc, avg_loss))
        print('\tCurrent best val acc: %.2f' % setup["best_acc"])
        if setup["log_metrics"]:
            # See log using: tensorboard --logdir='logs' --port=6006
            writer.add_scalar('Loss/val', avg_loss, epoch)
            writer.add_scalar('Accuracy/val', avg_test_acc, epoch)

        saveables = {'epoch': epoch + 1,
                     'state_dict': models_bag["mvnetwork"].state_dict(),
                     "mvtn": models_bag["mvtn"].state_dict(),
                     # "feature_extractor": models_bag["feature_extractor"].state_dict(),
                     'acc': avg_test_acc.item(),
                     'best_acc': setup["best_acc"],
                     'optimizer': models_bag["optimizer"].state_dict(),
                     'mvtn_optimizer': None if not setup["is_learning_views"] else models_bag["mvtn_optimizer"].state_dict(),
                     # 'fe_optimizer': None if not setup["is_learning_points"] else models_bag["fe_optimizer"].state_dict(),
                     }
        if setup["save_all"]:
            save_checkpoint(saveables, setup, views_record, setup["weights_file"])
        # Keep the best model so far (never overwritten in test modes).
        if avg_test_acc.item() >= setup["best_acc"] and "test" not in setup["run_mode"]:
            print('\tSaving checkpoint - Acc: %.2f' % avg_test_acc)
            saveables["best_acc"] = avg_test_acc.item()
            setup["best_loss"] = avg_loss
            setup["best_acc"] = avg_test_acc.item()
            save_checkpoint(saveables, setup, views_record,
                            setup["weights_file"], ignore_saving_models=setup["ignore_saving_models"])

        # Step decay: rebuild the network optimizer with the decayed learning
        # rate (a fresh AdamW starts with empty state; the MVTN optimizer is
        # left untouched).
        if (epoch + 1) % setup["lr_decay_freq"] == 0:
            lr *= setup["lr_decay"]
            models_bag["optimizer"] = torch.optim.AdamW(
                models_bag["mvnetwork"].parameters(), lr=lr)
            print('Learning rate:', lr)
        # Render and save the predicted views of a few fixed validation samples.
        # The render call must run once per sample (inside the loop), and no
        # autograd graph is needed for visualization.
        if (epoch + 1) % setup["plot_freq"] == 0 or "test" in setup["run_mode"]:
            with torch.no_grad():
                for indx, ii in enumerate(PLOT_SAMPLE_NBS_CLS):
                    (targets, meshes, points) = dset_val[ii]
                    c_batch_size = 1
                    cameras_root_folder = os.path.join(setup["cameras_dir"], str(indx))
                    check_folder(cameras_root_folder)
                    renderings_root_folder = os.path.join(setup["renderings_dir"], str(indx))
                    check_folder(renderings_root_folder)
                    cameras_path = os.path.join(
                        cameras_root_folder, "MV_cameras_{}.jpg".format(str(epoch + 1)))
                    images_path = os.path.join(
                        renderings_root_folder, "MV_renderings_{}.jpg".format(str(epoch + 1)))
                    azim, elev, dist = models_bag["mvtn"](points[None, ...], c_batch_size=c_batch_size)
                    models_bag["mvrenderer"].render_and_save(
                        [meshes], points[None, ...], azim=azim, elev=elev, dist=dist,
                        images_path=images_path, cameras_path=cameras_path,)
    # Optional extra evaluation with `full_eval_views` views after training.
    if setup["run_full_eval"] and "test" not in setup["run_mode"]:
        orig_nb_views = copy.copy(setup["nb_views"])
        setup, models_bag = reinitilize_setup(
            setup, models_bag, new_nb_views=setup["full_eval_views"])
        val_loader = DataLoader(dset_val, batch_size=int(
            2), shuffle=False, num_workers=4, collate_fn=collate_fn)
        models_bag["mvrenderer"].return_mapping = False
        avg_test_acc, avg_loss, views_record, _ = evluate(val_loader, models_bag, setup,)
        setup["best_acc"] = max(avg_test_acc.item(), setup["best_acc"])
        setup["nb_views"] = orig_nb_views
        save_checkpoint(None, setup, None, None, ignore_saving_models=True)

    # if setup["log_metrics"]:
    #     writer.add_hparams(setup, {"hparams/best_acc": setup["best_acc"]})

if "retr" in setup["run_mode"]:
    # Retrieval evaluation: cache training-set features, fit (once) an LFDA
    # projection on them, then run evluate(..., retrieval=True) on the
    # validation set. The renderer's pixel-to-point mapping is not requested.
    models_bag["mvrenderer"].return_mapping = False  # no 2D-to-3D unprojection mapping needed here

    print('\nEvaluation:')
    models_bag["mvnetwork"].eval()
    models_bag["mvtn"].eval()
    models_bag["mvrenderer"].eval()

    # models_bag["feature_extractor"].eval()

    # Extract training features and labels only if either cached file is missing.
    os.makedirs(os.path.dirname(setup["feature_file"]),exist_ok=True)
    if not os.path.exists(setup["feature_file"]) or not os.path.exists(setup["targets_file"]):
        features, targets = compute_features(train_loader,models_bag, setup)
        np.save(setup["feature_file"],features)
        np.save(setup["targets_file"],targets)

    # Fit a 128-component LFDA reduction on the cached training features once,
    # and pickle it into `features_dir` for later runs.
    LFDA_reduction_file = os.path.join(setup["features_dir"], "reduction_LFDA.pkl")
    if not os.path.exists(LFDA_reduction_file):
        from metric_learn import LFDA
        features = np.load(setup["feature_file"])
        targets = np.load(setup["targets_file"])
        lfda = LFDA(n_components=128)
        lfda.fit(features, targets)
        with open(LFDA_reduction_file, "wb") as fobj:
            pkl.dump(lfda, fobj)

    # The pickle is the one written above by this script; never point this path
    # at an untrusted file (unpickling can execute code).
    # NOTE(review): `lfda` is not used again in this block; presumably evluate()
    # reads it as a module-level global when retrieval=True. Confirm.
    with open(LFDA_reduction_file, "rb") as fobj:
        lfda = pkl.load(fobj)

    avg_test_acc, avg_test_loss, _, avg_test_retr_mAP = evluate(val_loader, models_bag, setup, retrieval=True)
    print('\tVal Acc: %.2f - val retr-mAP: %.2f - val Loss: %.4f' %
          (avg_test_acc, avg_test_retr_mAP, avg_test_loss))
    print('\tCurrent best val acc: %.2f' % setup["best_acc"])
    setup["best_retr_map"] = avg_test_retr_mAP
    save_checkpoint(None, setup, None, None, ignore_saving_models=True)
    # if setup["log_metrics"]:
    #     writer.add_hparams(setup, {"hparams/best_retr_map": setup["best_retr_map"]})
    
if "pls" in setup["run_mode"]:
    # Train/evaluate with train_pls / evaluate_pls (includes the "mvlifting"
    # module). Unlike the "cls" branch, the renderer's return_mapping flag is
    # left at whatever value it already has.
    if setup["log_metrics"]:
        writer = SummaryWriter(setup["logs_dir"])
        writer.add_hparams(setup, {"hparams/best_acc": 0.0})
    setup["best_retr_map"] = 0

    for epoch in range(setup["start_epoch"], n_epochs):
        setup["c_epoch"] = epoch
        print('\n-----------------------------------')
        print('Epoch: [%d/%d]' % (epoch+1, n_epochs))
        start = time.time()
        # Training pass (skipped in test modes, which only evaluate).
        if "train" in setup["run_mode"]:
            models_bag["mvnetwork"].train()
            models_bag["mvtn"].train()
            models_bag["mvrenderer"].train()
            models_bag["mvlifting"].train()
            # models_bag["feature_extractor"].train()
            avg_train_acc, avg_train_loss = train_pls(
                train_loader, models_bag, setup)
            print('Time taken: %.2f sec.' % (time.time() - start))
            print('\ttrain acc: %.2f - train Loss: %.4f' %
                  (avg_train_acc, avg_train_loss))
            if setup["log_metrics"]:
                writer.add_scalar('Loss/train', avg_train_loss, epoch)
                writer.add_scalar('Accuracy/train', avg_train_acc, epoch)

        # Validation pass.
        models_bag["mvnetwork"].eval()
        models_bag["mvtn"].eval()
        models_bag["mvrenderer"].eval()
        models_bag["mvlifting"].eval()

        avg_test_acc, avg_loss = evaluate_pls(
            val_loader, models_bag, setup)

        print('\nEvaluation:')
        print('\tVal Acc: %.2f - val Loss: %.4f' % (avg_test_acc, avg_loss))
        print('\tCurrent best val acc: %.2f' % setup["best_acc"])
        if setup["log_metrics"]:
            # See log using: tensorboard --logdir='logs' --port=6006
            writer.add_scalar('Loss/val', avg_loss, epoch)
            writer.add_scalar('Accuracy/val', avg_test_acc, epoch)

        saveables = {'epoch': epoch + 1,
                     'state_dict': models_bag["mvnetwork"].state_dict(),
                     "mvtn": models_bag["mvtn"].state_dict(),
                     "mvlifting": models_bag["mvlifting"].state_dict(),
                     # "feature_extractor": models_bag["feature_extractor"].state_dict(),
                     'acc': avg_test_acc,
                     'best_acc': setup["best_acc"],
                     'optimizer': models_bag["optimizer"].state_dict(),
                     'mvtn_optimizer': None if not setup["is_learning_views"] else models_bag["mvtn_optimizer"].state_dict(),
                     'mlp_optimizer': None if not setup["learning_lifting"] else models_bag["mlp_optimizer"].state_dict(),
                     }
        if setup["save_all"]:
            save_checkpoint(saveables, setup, None,
                            setup["weights_file"])
        # Keep the best model so far (never overwritten in test modes).
        if avg_test_acc >= setup["best_acc"] and "test" not in setup["run_mode"]:
            print('\tSaving checkpoint - Acc: %.2f' % avg_test_acc)
            saveables["best_acc"] = avg_test_acc
            setup["best_loss"] = avg_loss
            setup["best_acc"] = avg_test_acc
            save_checkpoint(saveables, setup, None,
                            setup["weights_file"], ignore_saving_models=setup["ignore_saving_models"])

        # Step decay: rebuild the network optimizer with the decayed learning
        # rate (a fresh AdamW starts with empty state).
        if (epoch + 1) % setup["lr_decay_freq"] == 0:
            lr *= setup["lr_decay"]
            models_bag["optimizer"] = torch.optim.AdamW(
                models_bag["mvnetwork"].parameters(), lr=lr)
            print('Learning rate:', lr)
        # Render and save the predicted views of a few fixed validation samples.
        # The render call must run once per sample (inside the loop), and no
        # autograd graph is needed for visualization.
        if (epoch + 1) % setup["plot_freq"] == 0 or "test" in setup["run_mode"]:
            with torch.no_grad():
                for indx, ii in enumerate(PLOT_SAMPLE_NBS_CLS):
                    (targets, meshes, points) = dset_val[ii]
                    c_batch_size = 1
                    cameras_root_folder = os.path.join(
                        setup["cameras_dir"], str(indx))
                    check_folder(cameras_root_folder)
                    renderings_root_folder = os.path.join(
                        setup["renderings_dir"], str(indx))
                    check_folder(renderings_root_folder)
                    cameras_path = os.path.join(
                        cameras_root_folder, "MV_cameras_{}.jpg".format(str(epoch + 1)))
                    images_path = os.path.join(
                        renderings_root_folder, "MV_renderings_{}.jpg".format(str(epoch + 1)))
                    azim, elev, dist = models_bag["mvtn"](
                        points[None, ...], c_batch_size=c_batch_size)
                    models_bag["mvrenderer"].render_and_save(
                        [meshes], points[None, ...], azim=azim, elev=elev, dist=dist,
                        images_path=images_path, cameras_path=cameras_path,)
    # Optional extra evaluation with `full_eval_views` views after training.
    if setup["run_full_eval"] and "test" not in setup["run_mode"]:
        orig_nb_views = copy.copy(setup["nb_views"])
        setup, models_bag = reinitilize_setup(
            setup, models_bag, new_nb_views=setup["full_eval_views"])
        val_loader = DataLoader(dset_val, batch_size=int(
            2), shuffle=False, num_workers=4, collate_fn=collate_fn)
        models_bag["mvrenderer"].return_mapping = False
        avg_test_acc, avg_loss = evaluate_pls(val_loader, models_bag, setup)
        setup["best_acc"] = max(avg_test_acc, setup["best_acc"])
        setup["nb_views"] = orig_nb_views
        save_checkpoint(None, setup, None, None, ignore_saving_models=True)

if "part" in setup["run_mode"]:
    # Part segmentation: train/evaluate the 2D network together with the
    # "mvlifting" module. The best checkpoint is chosen by category-averaged IoU.
    is_test = False

    if setup["log_metrics"]:
        writer = SummaryWriter(setup["logs_dir"])
        writer.add_hparams(setup, {"hparams/best_acc": 0.0 , "hparams/best_inst_iou": 0.0,"hparams/best_cat_iou": 0.0})
    setup["best_inst_iou"] = 0.0
    setup["best_cat_iou"] = 0.0

    for epoch in range(setup["start_epoch"], n_epochs):
        setup["c_epoch"] = epoch
        print('\n-----------------------------------')
        print('Epoch: [%d/%d]' % (epoch+1, n_epochs))
        start = time.time()
        # Training pass; in test modes only evaluation runs (with is_test=True).
        if "train" in setup["run_mode"]:
            models_bag["mvnetwork"].train()
            models_bag["mvtn"].train()
            models_bag["mvrenderer"].train()
            models_bag["mvlifting"].train()
            # models_bag["feature_extractor"].train()
            avg_train_acc, avg_train_loss = train_part_seg(train_loader, models_bag, setup)
            print('Time taken: %.2f sec.' % (time.time() - start))
            print('\ttrain pixel acc: %.2f - train 2D Loss: %.4f ' %(avg_train_acc, avg_train_loss))
            if setup["log_metrics"]:
                writer.add_scalar('Loss/train', avg_train_loss, epoch)
                writer.add_scalar('Accuracy/train', avg_train_acc, epoch)
        else:
            is_test = True
        models_bag["mvnetwork"].eval()
        models_bag["mvtn"].eval()
        models_bag["mvrenderer"].eval()
        models_bag["mvlifting"].eval()

        # models_bag["feature_extractor"].eval()
        avg_test_acc, mean_cat_iou_test, mean_inst_iou_test, avg_loss, point_coverage = evluate_part_seg(
            val_loader, models_bag, setup, is_test=is_test)
        print('\nEvaluation:')
        print('\tVal point Acc: %.2f - val category-avg iou: %.2f - val instance-avg iou: %.2f - val Coverage: %.2f - val 2D Loss: %.3f ' %
              (avg_test_acc, mean_cat_iou_test, mean_inst_iou_test, point_coverage,avg_loss))
        print('\tCurrent best acc: {:.2f}  - best category_avg miou: {:.2f} - best instance_avg miou: {:.2f}'.format(
            setup["best_acc"], setup["best_cat_iou"], setup["best_inst_iou"]))
        if setup["log_metrics"]:
            writer.add_scalar('Loss/val', avg_loss, epoch)
            writer.add_scalar('Accuracy/val', avg_test_acc, epoch)
            writer.add_scalar('test_mIOU/categ', mean_cat_iou_test, epoch)
            writer.add_scalar('test_mIOU/inst',mean_inst_iou_test, epoch)

        saveables = {'epoch': epoch + 1,
                     'state_dict': models_bag["mvnetwork"].state_dict(),
                     "mvtn": models_bag["mvtn"].state_dict(),
                     "mvlifting": models_bag["mvlifting"].state_dict(),
                     # "feature_extractor": models_bag["feature_extractor"].state_dict(),
                     'acc': avg_test_acc,
                     'best_acc': setup["best_acc"],
                     'best_inst_iou': setup["best_inst_iou"],
                     'best_cat_iou': setup["best_cat_iou"],
                     'optimizer': models_bag["optimizer"].state_dict(),
                     'mlp_optimizer': None if not setup["learning_lifting"] else models_bag["mlp_optimizer"].state_dict(),
                     }
        if setup["save_all"]:
            save_checkpoint(saveables, setup, None,
                            setup["weights_file"])
        # Keep the model with the best category-averaged IoU (never in test modes).
        if mean_cat_iou_test > setup["best_cat_iou"] and "test" not in setup["run_mode"]:
            print('\tSaving checkpoint - Acc: %.2f' % avg_test_acc)
            saveables["best_acc"] = avg_test_acc
            setup["best_loss"] = avg_loss
            setup["best_acc"] = avg_test_acc
            setup["point_coverage"] = point_coverage
            setup["best_cat_iou"] = mean_cat_iou_test
            setup["best_inst_iou"] = mean_inst_iou_test
            save_checkpoint(saveables, setup, None,setup["weights_file"], ignore_saving_models=setup["ignore_saving_models"])
        if setup["ignore_visualizations"] and "test" in setup["run_mode"]:
            print("finshed testing ")
            sys.exit()
        # Step decay: rebuild both optimizers with decayed learning rates (fresh
        # AdamW instances start with empty state).
        if (epoch + 1) % setup["lr_decay_freq"] == 0:
            lr *= setup["lr_decay"]
            lr_mlp *= setup["lr_decay"]
            models_bag["optimizer"] = torch.optim.AdamW(models_bag["mvnetwork"].parameters(), lr=lr)
            if setup["learning_lifting"]:
                models_bag["mlp_optimizer"] = torch.optim.AdamW(models_bag["mvlifting"].parameters(), lr=lr_mlp)
            print('Learning rate:', lr,)
        # Visualize a few fixed validation shapes: GT / 2D / lifted-3D label
        # renderings plus per-shape mIoU. No autograd graph is needed here.
        if (epoch + 1) % setup["plot_freq"] == 0 or "test" in setup["run_mode"]:
            with torch.no_grad():
                for indx, ii in enumerate(PLOT_SAMPLE_NBS_SEG):
                    c_batch_size = 1
                    (points, cls, seg, parts_range,parts_nb, real_points_mask) = dset_val[ii]
                    given_labels = list(range(0, parts_nb),)
                    points = torch.from_numpy(points)[None, ...].cuda()
                    colors = []
                    # `normals` is passed to compute_views_weights below, so it must
                    # be bound even when normals are not used (previously a NameError).
                    # NOTE(review): confirm compute_views_weights accepts normals=None.
                    normals = None
                    if setup["use_normals"]:
                        normals = points[:,:,3:6]
                        colors = (normals + 1.0) / 2.0
                        colors = colors/torch.norm(colors, dim=-1,p=float(setup["color_normal_p"]))[..., None]
                        points = points[:,:,0:3]
                    seg = torch.from_numpy(seg)[None, ...].cuda()
                    real_points_mask = torch.from_numpy(real_points_mask)[
                        None, ...].cuda()
                    parts_range = torch.Tensor([parts_range]).to(torch.int)
                    parts_nb = torch.Tensor([parts_nb]).to(torch.int)
                    cls = torch.Tensor([cls]).cuda()
                    seg = seg + 1 - parts_range[..., None].cuda().to(torch.int)
                    parts_range += 1  # the label 0 is reserved for background

                    renderings_root_folder = os.path.join(setup["renderings_dir"], str(indx))
                    check_folder(renderings_root_folder)

                    azim, elev, dist = models_bag["mvtn"](points, c_batch_size=c_batch_size)
                    view_info = torch.cat([azim.unsqueeze(-1), elev.unsqueeze(-1)],dim=-1)
                    rendered_images, indxs, distance_weight_maps, _ = models_bag["mvrenderer"](None, points,  azim=azim, elev=elev, dist=dist, color=colors)
                    labels_2d, pix_to_face_mask  = models_bag["mvlifting"].compute_image_segment_label_points(points, batch_points_labels=seg, rendered_pix_to_point=indxs, )

                    outputs , feats = models_bag["mvnetwork"](rendered_images,cls)
                    _, predicted = torch.max(outputs, dim=1)

                    views_weights = models_bag["mvlifting"].compute_views_weights(azim, elev, rendered_images, normals)
                    predictions_3d = models_bag["mvlifting"].lift_2D_to_3D(points, predictions_2d=svctomvc(outputs, nb_views=setup["nb_views"]), rendered_pix_to_point=indxs,views_weights=views_weights, cls=cls, parts_nb=parts_nb,view_info=view_info,early_feats=feats)
                    _, predictions_3d = torch.max(predictions_3d, dim=1)
                    predictions_3d = post_process_segmentation(points, predictions_3d, iterations=setup["post_process_iters"], K_neighbors=setup["post_process_k"])
                    predictions_3d_projected ,_= models_bag["mvlifting"].compute_image_segment_label_points(points, batch_points_labels=predictions_3d.to(torch.int), rendered_pix_to_point=indxs, )
                    save_batch_rendered_segmentation_images(labels_2d, renderings_root_folder, "GT_renderings_{}.jpg".format(str(epoch + 1)),)
                    save_batch_rendered_segmentation_images(svtomv(predicted, nb_views=setup["nb_views"])* (~pix_to_face_mask).to(torch.long)[:,:,0,...],
                        renderings_root_folder, "PRED_2D_renderings_{}.jpg".format(str(epoch + 1)), given_labels=given_labels)
                    save_batch_rendered_segmentation_images(predictions_3d_projected, renderings_root_folder, "PRED_3D_renderings_{}.jpg".format(str(epoch + 1)))
                    save_batch_rendered_images(rendered_images[:, :, 0:3, ...], renderings_root_folder, "original_renderings_{}.jpg".format(str(epoch + 1)),)
                    if "test" in setup["run_mode"]:
                        _ = view_ptc_labels(rotation_matrix([1, 0, 0], 90).dot(points[0].cpu().numpy().T).T, seg[0].cpu().numpy(), COLOR_LABEL_VALUES_LIST, size=0.01, save_name=os.path.join(renderings_root_folder, "GT_final.png".format(str(epoch + 1))))
                        _ = view_ptc_labels(rotation_matrix([1, 0, 0], 90).dot(points[0].cpu().numpy().T).T, predictions_3d[0].cpu().numpy(), COLOR_LABEL_VALUES_LIST, size=0.01, save_name=os.path.join(renderings_root_folder, "PRED_final.png".format(str(epoch + 1))))

                    cur_shape_miou = batch_points_mIOU(seg - 1, predictions_3d - 1, real_points_mask.to(torch.bool), parts=parts_nb,)
                    print("object {} of class {} has mIOU: {:.1f}".format(ii,classes[int(cls.item())],100*cur_shape_miou.item()))

            if "test" in setup["run_mode"]:
                print("finshed testing ")
                sys.exit()
    # Optional extra evaluation with `full_eval_views` views after training.
    if setup["run_full_eval"] and "test" not in setup["run_mode"]:
        orig_nb_views = copy.copy(setup["nb_views"])
        setup, models_bag = reinitilize_setup(setup, models_bag,new_nb_views=setup["full_eval_views"])
        val_loader = DataLoader(dset_val, batch_size=int(2),shuffle=False, num_workers=4, collate_fn=collate_fn)
        avg_test_acc, mean_cat_iou_test, mean_inst_iou_test, avg_loss, point_coverage = evluate_part_seg(val_loader, models_bag, setup, is_test=is_test)
        setup["best_acc"] = max(avg_test_acc, setup["best_acc"])
        setup["best_cat_iou"] = max(mean_cat_iou_test, setup["best_cat_iou"])
        setup["best_inst_iou"] = max(mean_inst_iou_test, setup["best_inst_iou"])
        setup["nb_views"] = orig_nb_views
        save_checkpoint(None, setup, None,None, ignore_saving_models=True)

    if setup["log_metrics"]:
        writer.add_hparams(setup, {
                           "hparams/best_acc": setup["best_acc"], "hparams/best_inst_iou": setup["best_inst_iou"], "hparams/best_cat_iou": setup["best_cat_iou"]})


elif setup["run_mode"] == "test_occ":
    # to return a mpaaing function of the renderer used in 2D to #d unprojection
    models_bag["mvrenderer"].return_mapping = False

    models_bag["mvnetwork"].eval()
    models_bag["mvtn"].eval()
    models_bag["mvrenderer"].eval()
    # models_bag["feature_extractor"].eval()
    if "modelnet" not in setup["data_dir"].lower():
        raise Exception('Occlusion is only supported froom ModelNet now ')
    from tqdm import tqdm
    torch.multiprocessing.set_sharing_strategy('file_system')

    print('\Evaluatiing om the cropped data :')

    override = True
    networks_list = ["Voint"]
    factor_list = [-0.75, -0.5, -0.3, -0.2, -0.1, 0, 0.1, 0.2, 0.3, 0.5, 0.75]
    axis_list = [0, 1, 2]

    setup_keys = ["network", "batch", "factor", "axis"]
    setups = ListDict(setup_keys)
    results = ListDict(["prediction", "class"])
    for network in networks_list:
        if network == "PointNet":
            setup["shape_extractor"] = "PointNet"
            point_network = PointNet(40, alignment=True).cuda()
        elif network == "DGCNN":
            setup["shape_extractor"] = "DGCNN"
            point_network = SimpleDGCNN(40).cuda()
        if network in ["DGCNN", "PointNet"]:
            point_network.eval()
            load_point_ckpt(
                point_network,  setup["shape_extractor"],  ckpt_dir='./checkpoint')
        exp_id = "chopping_{}".format(network)
        save_file = os.path.join(setup["results_dir"], exp_id+".csv")
        if not os.path.isfile(save_file) or override:
            t = tqdm(enumerate(val_loader), total=len(val_loader))
            for ii, (targets, meshes, orig_pts) in t:
                c_batch_size = len(meshes)
                with torch.no_grad():
                    azim, elev, dist = models_bag["mvtn"](orig_pts, c_batch_size=c_batch_size)
                    rendered_images, _ = models_bag["mvrenderer"](meshes, orig_pts,  azim=azim, elev=elev, dist=dist)
                    targets = targets.cuda()
                    for factor in factor_list:
                        for axis in axis_list:
                            c_setup = {"network": network,
                                        "batch": ii, "factor": factor, "axis": axis}
                            [setups.append(c_setup)
                                for ii in range(c_batch_size)]
                            chopped_pts = chop_ptc(orig_pts.cpu().numpy(), factor, axis=axis)
                            chopped_pts = torch.from_numpy(chopped_pts)
                            if network not in ["PointNet", "DGCNN"]:
                                azim, elev, dist = models_bag["mvtn"](chopped_pts, c_batch_size=c_batch_size)
                                # save_grid(image_batch=rendered_images[0, ...],save_path=os.path.join(setup["results_dir"],"renderings","{}_{}.jpg".format(network,factor)), nrow=setup["nb_views"])
                                rendered_images, _ = models_bag["mvrenderer"](meshes, chopped_pts,  azim=azim, elev=elev, dist=dist)
                                outputs, _ = models_bag["mvnetwork"](
                                    rendered_images)
                            else:
                                chopped_pts = chopped_pts.transpose(1, 2).cuda()
                                outputs = point_network(chopped_pts)[
                                    0].view(c_batch_size, -1)
                            _, predictions = torch.max(outputs.data, 1)
                            c_result = ListDict({"prediction": predictions.cpu().numpy(
                            ).tolist(), "class": targets.cpu().numpy().tolist()})
                            results.extend(c_result)
                            save_results(save_file, results+setups)
                        # raise Exception("just checking the visualization")

elif setup["run_mode"] == "test_rot":
    # Rotation robustness: the renderer's pixel-to-point mapping (used for 2D to
    # 3D unprojection) is not needed, so it is switched off.
    models_bag["mvrenderer"].return_mapping = False

    # setup["results_file"] = os.path.join(
    #     setup["results_dir"], setup["exp_id"]+"_robustness_{}.csv".format(str(int(setup["max_degs"]))))
    # setup["return_points_saved"] = True
    # assert os.path.isfile(setup["weights_file"]
    #                         ), 'Error: no checkpoint file found!'

    # loaded_info = load_results(os.path.join(
    #     setup["results_dir"], setup["exp_id"]+"_accuracy.csv"))
    # setup["start_epoch"] = loaded_info["start_epoch"][0]
    # setup["nb_views"] = loaded_info["nb_views"][0]
    # setup["views_config"] = loaded_info["views_config"][0]

    # print('\nEvaluating Robustness:')
    # mvtn = MVTN(setup["nb_views"], views_config=setup["views_config"],canonical_elevation=setup["canonical_elevation"], canonical_distance=setup["canonical_distance"],
    #             shape_features_size=setup["features_size"], transform_distance=setup["transform_distance"], input_view_noise=setup["input_view_noise"], shape_extractor=setup["shape_extractor"], screatch_feature_extractor=setup["screatch_feature_extractor"]).cuda()
    # models_bag["mvtn"] = mvtn
    # load_checkpoint_robustness(setup, models_bag, setup["weights_file"])

    models_bag["mvnetwork"].eval()
    models_bag["mvtn"].eval()
    models_bag["mvrenderer"].eval()
    # Run the randomly-rotated evaluation `repeat_exp` times (each run samples new
    # angles in [-max_degs, max_degs]) and report the mean accuracy as best_acc.
    # NOTE(review): with repeat_exp == 0, np.mean([]) yields nan.
    acc_list = []
    for _ in range(setup["repeat_exp"]):
        avg_test_acc, _ = evluate_rotation_robustness(val_loader, models_bag, setup, max_degs=setup["max_degs"])
        acc_list.append(avg_test_acc.item())
    setup["best_acc"] = np.mean(acc_list)
    print("exp: {} \tVal Acc: {:.2f} ".format(
        setup["exp_id"], setup["best_acc"]))
    # NOTE(review): `setup_dict` is unused while the save_results call below is commented out.
    setup_dict = ListDict(list(setup.keys()))
    # save_results(setup["results_file"], setup_dict.append(setup))

