import torch
from torch.autograd import Variable
import numpy as np
import os
import torchvision
import timm
import sys
from util import *
import shutil
from torch import nn
from torch._six import inf
from einops import rearrange, repeat
from einops.layers.torch import Rearrange


from pytorch3d.structures import Meshes, Pointclouds
from pytorch3d.renderer.mesh import Textures
from pytorch3d.renderer import (
    OpenGLPerspectiveCameras, look_at_view_transform, look_at_rotation,
    RasterizationSettings, MeshRenderer, MeshRasterizer, BlendParams,
    SoftSilhouetteShader, HardPhongShader, PointLights, SoftPhongShader,
    HardFlatShader, HardGouraudShader, SoftGouraudShader,
    OpenGLOrthographicCameras,
    PointsRasterizationSettings,
    PointsRenderer,
    PointsRasterizer,
    AlphaCompositor,
    NormWeightedCompositor, DirectionalLights)
# ORTHOGONAL_THRESHOLD = 1e-6
# Retry budget for rotation-matrix correction. In this part of the file it is
# only referenced from commented-out rendering code, where it is passed as
# `nb_trials` to check_and_correct_rotation_matrix.
EXAHSTION_LIMIT = 20
# EPSILON = 0.00001


def initialize_setup(setup):
    """
    Fill the experiment `setup` dict in place with derived paths, counters and flags.

    Reads `features_type`, `exp_id`, `exp_set`, `results_dir`, `cameras_dir`,
    `renderings_dir`, `logs_dir`, `views_config`, `return_points_saved`,
    `return_points_sampled`, `use_mlp_classifier` and `lifting_method` from
    `setup`, rewrites the directory entries so that they live under
    `<results_dir>/<exp_set>/<exp_id>`, and calls `check_folder` on the
    experiment directories (presumably creating them -- `check_folder` comes
    from `util`). Finally every boolean value in `setup` is replaced by its
    integer equivalent. Returns nothing.
    """
    feature_dims = {
        "logits": 40,
        "post_max": 1024,
        "transform_matrix": 64 * 64,
        "pre_linear": 512,
        "post_max_trans": 1024 + 64 * 64,
        "logits_trans": 40 + 64 * 64,
        "pre_linear_trans": 512 + 64 * 64,
    }
    setup["features_size"] = feature_dims[setup["features_type"]]

    # "random" is a placeholder asking for a freshly generated experiment id
    if setup["exp_id"] == "random":
        setup["exp_id"] = random_id()

    check_folder(os.path.join(setup["results_dir"], setup["exp_set"]))
    run_dir = os.path.join(
        setup["results_dir"], setup["exp_set"], setup["exp_id"])

    # sub-directories of the run; the order of assignment is kept because it
    # fixes the key order of `setup`
    setup["results_dir"] = run_dir
    setup["cameras_dir"] = os.path.join(run_dir, setup["cameras_dir"])
    setup["renderings_dir"] = os.path.join(run_dir, setup["renderings_dir"])
    setup["verts_dir"] = os.path.join(run_dir, "verts")
    setup["checkpoint_dir"] = os.path.join(run_dir, "checkpoint")
    setup["features_dir"] = os.path.join(run_dir, "features")
    setup["logs_dir"] = os.path.join(run_dir, setup["logs_dir"])
    setup["feature_file"] = os.path.join(
        setup["features_dir"], "features_training.npy")
    setup["targets_file"] = os.path.join(
        setup["features_dir"], "targets_training.npy")

    for dir_key in ("results_dir", "cameras_dir", "renderings_dir",
                    "logs_dir", "verts_dir", "checkpoint_dir"):
        check_folder(setup[dir_key])

    # training bookkeeping starts from scratch
    setup["best_acc"] = 0.0
    setup["best_loss"] = 0.0
    setup["start_epoch"] = 0

    exp_id = setup["exp_id"]
    setup["results_file"] = os.path.join(run_dir, exp_id + "_accuracy.csv")
    setup["views_file"] = os.path.join(run_dir, exp_id + "_views.csv")
    setup["weights_file"] = os.path.join(
        setup["checkpoint_dir"], exp_id + "_checkpoint.pt")

    learned_view_modes = ["learned_offset", "learned_direct",
                          "learned_spherical", "learned_random", "learned_transfer"]
    setup["is_learning_views"] = setup["views_config"] in learned_view_modes
    setup["is_learning_points"] = setup["is_learning_views"] and (
        setup["return_points_saved"] or setup["return_points_sampled"])
    trainable_lifting = ["mlp", "gcn", "transformer", "gat"]
    setup["learning_lifting"] = setup["use_mlp_classifier"] or setup["lifting_method"] in trainable_lifting

    # store flags as 0/1 instead of False/True
    bool_keys = [key for key, value in setup.items() if isinstance(value, bool)]
    for key in bool_keys:
        setup[key] = int(setup[key])


def model_name_from_setup(setup):
    """
    Assemble the timm model identifier described by `setup`.

    The name is `<vit_variant>_<vit_model_size>_patch<patch_size>[_window<swin_window_size>]_<image_size>[_in21k]`:
    the window part is added only when the variant contains "swin", and the
    "_in21k" suffix only when `pretraining_mode` contains "21k".
    """
    if "21k" in setup["pretraining_mode"]:
        pretrain_suffix = "_in21k"
    else:
        pretrain_suffix = ""

    if "swin" in setup["vit_variant"]:
        window_part = "_window{}".format(setup["swin_window_size"])
    else:
        window_part = ""

    return "{}_{}_patch{}{}_{}{}".format(
        setup["vit_variant"],
        setup["vit_model_size"],
        setup["patch_size"],
        window_part,
        setup["image_size"],
        pretrain_suffix,
    )



# MVTN_regressor = Sequential(MLP([b+2*M, b, b, 5 * M, 2*M], activation="relu", dropout=0.5,
                                # batch_norm=True), MLP([2*M, 2*M], activation=None, dropout=0, batch_norm=False), nn.Tanh())
def applied_transforms(images_batch, crop_ratio=0.3):
    """
    PyTorch augmentations that can be applied batch-wise.

    The batch (unpacked as N, C, H, W) is passed through a
    `RandomHorizontalFlip`, padded on every side by
    `int((1 + crop_ratio) * H) - H` replicated pixels and then passed through a
    `RandomCrop(H)`. Both random transforms are not defined in this file
    (presumably batch-aware versions coming from `util`).
    """
    _, _, height, _ = images_batch.shape
    margin = int((1+crop_ratio)*height)-height
    replicate_pad = torch.nn.ReplicationPad2d(margin)
    flipped = RandomHorizontalFlip()(images_batch)
    cropper = RandomCrop(height)
    return cropper(replicate_pad(flipped))


def super_batched_op(dim, batched_ops, batched_tensor, *args, **kwargs):
    """
    Run a batched PyTorch operation on a tensor that carries one extra dimension.

    `batched_tensor` is first collapsed with `batch_tensor(..., dim=dim, squeeze=True)`,
    `batched_ops` is called on the collapsed tensor (extra positional and keyword
    arguments are forwarded), and the result is expanded again with
    `unbatch_tensor(..., dim=dim, unsqueeze=True)` using the original size of
    dimension 0 as `batch_size`.
    """
    merged = batch_tensor(batched_tensor, dim=dim, squeeze=True)
    processed = batched_ops(merged, *args, **kwargs)
    return unbatch_tensor(
        processed, dim=dim, unsqueeze=True, batch_size=batched_tensor.shape[0])


def check_and_correct_rotation_matrix(R, T, nb_trials, azim, elev, dist):
    """
    Re-sample the camera transform until `check_valid_rotation_matrix(R)` holds.

    While `R` is rejected, elevation is shifted by a random amount in [0, 90)
    and azimuth by a random amount in [0, 180) and `look_at_view_transform` is
    called again. If the matrix is still rejected after more than `nb_trials`
    attempts the process is terminated through `sys.exit`.

    Returns the (possibly regenerated) pair `R, T`.
    """
    attempts = 0
    while not check_valid_rotation_matrix(R):
        attempts += 1
        flat_dist = batch_tensor(dist.T, dim=1, squeeze=True)
        # random draws happen in the order elevation, azimuth
        shifted_elev = elev.T + 90.0 * \
            torch.rand_like(elev.T, device=elev.device)
        flat_elev = batch_tensor(shifted_elev, dim=1, squeeze=True)
        shifted_azim = azim.T + 180.0 * \
            torch.rand_like(azim.T, device=elev.device)
        flat_azim = batch_tensor(shifted_azim, dim=1, squeeze=True)
        R, T = look_at_view_transform(
            dist=flat_dist, elev=flat_elev, azim=flat_azim)
        still_invalid = not check_valid_rotation_matrix(R)
        if still_invalid and attempts > nb_trials:
            sys.exit("Remedy did not work")
    return R, T


class MyPatchEmbed(nn.Module):
    """
    Image-to-patch embedding: a strided convolution followed by an optional norm.

    `img_size` and `patch_size` are normalised with timm's `to_2tuple`; the
    resulting patch grid and patch count are stored as attributes. The forward
    pass does not check the input resolution against `img_size`.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None):
        super().__init__()
        from timm.models.layers import to_2tuple

        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)
        rows = self.img_size[0] // self.patch_size[0]
        cols = self.img_size[1] // self.patch_size[1]
        self.patch_grid = (rows, cols)
        self.num_patches = rows * cols

        # non-overlapping patches: kernel and stride are both the patch size
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size)
        if norm_layer:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = nn.Identity()

    def forward(self, x):
        """Project a 4-D image batch and return the patch tokens with the patch axis at dim 1."""
        # the unpacking only serves as a rank check (exactly four dimensions)
        _batch, _chans, _height, _width = x.shape
        tokens = self.proj(x)
        tokens = tokens.flatten(2).transpose(1, 2)
        return self.norm(tokens)
class MyViT(nn.Module):
    """
    Wrapper around a timm vision transformer with a selectable token aggregation.

    `vit_agr_type` chooses how the normalised token sequence is reduced to one
    feature vector: "cls" takes token 0, "max" / "mean" reduce over all tokens
    from index 1 onwards.
    """

    def __init__(self, model_name, pretrained=True, vit_agr_type="cls"):
        super(MyViT, self).__init__()
        self.vit_agr_type = vit_agr_type
        self.vit = timm.create_model(model_name, pretrained=pretrained)

    # Set your own forward pass
    def forward_features(self, x):
        """
        Run the transformer trunk and aggregate its tokens.

        Returns a tuple `(aggregated, token_1)` where `token_1` is the token at
        sequence index 1 (the distillation token when the wrapped model has one).

        Raises:
            ValueError: if `self.vit_agr_type` is not "cls", "max" or "mean".
        """
        x = self.vit.patch_embed(x)
        # stole cls_tokens impl from Phil Wang, thanks
        cls_token = self.vit.cls_token.expand(x.shape[0], -1, -1)
        if self.vit.dist_token is None:
            x = torch.cat((cls_token, x), dim=1)
        else:
            x = torch.cat((cls_token, self.vit.dist_token.expand(
                x.shape[0], -1, -1), x), dim=1)
        x = self.vit.pos_drop(x + self.vit.pos_embed)
        x = self.vit.blocks(x)
        x = self.vit.norm(x)

        if self.vit_agr_type == "cls":
            return x[:, 0], x[:, 1]
        if self.vit_agr_type == "max":
            return torch.max(x[:, 1::, :], dim=1)[0], x[:, 1]
        if self.vit_agr_type == "mean":
            return torch.mean(x[:, 1::, :], dim=1), x[:, 1]
        # fail loudly instead of returning None and crashing later in forward()
        raise ValueError(
            "unsupported vit_agr_type {!r}; expected 'cls', 'max' or 'mean'".format(
                self.vit_agr_type))

    def forward(self, x):
        """
        Classify `x` with the wrapped model's head(s).

        With a distillation head, both predictions are returned as a tuple while
        the wrapped model is in training mode (and not being scripted) and their
        average is returned otherwise. Without one, only the main head is used.
        """
        x = self.forward_features(x)
        if self.vit.head_dist is not None:
            x, x_dist = self.vit.head(x[0]), self.vit.head_dist(
                x[1])  # x must be a tuple
            if self.vit.training and not torch.jit.is_scripting():
                # during training, return both classifier predictions
                return x, x_dist
            else:
                # during inference, return the average of both classifier predictions
                return (x + x_dist) / 2
        else:
            x = self.vit.head(x[0])
        return x
# def render_meshes(meshes, color, azim, elev, dist, lights, setup, background_color=(1.0, 1.0, 1.0), ):
#     c_batch_size = len(meshes)
#     verts = [msh.verts_list()[0].cuda() for msh in meshes]
#     faces = [msh.faces_list()[0].cuda() for msh in meshes]
#     # faces = [torch.cat((fs, torch.flip(fs, dims=[1])),dim=0) for fs in faces]
#     new_meshes = Meshes(
#         verts=verts,
#         faces=faces,
#         textures=None)
#     max_vert = new_meshes.verts_padded().shape[1]

#     # print(len(new_meshes.faces_list()[0]))
#     new_meshes.textures = Textures(
#         verts_rgb=color.cuda()*torch.ones((c_batch_size, max_vert, 3)).cuda())
#     # Create a Meshes object for the teapot. Here we have only one mesh in the batch.
#     R, T = look_at_view_transform(dist=batch_tensor(dist.T, dim=1, squeeze=True), elev=batch_tensor(
#         elev.T, dim=1, squeeze=True), azim=batch_tensor(azim.T, dim=1, squeeze=True))
#     R, T = check_and_correct_rotation_matrix(R, T, EXAHSTION_LIMIT, azim, elev, dist)

#     cameras = OpenGLPerspectiveCameras(device="cuda:{}".format(torch.cuda.current_device()), R=R, T=T)
#     camera = OpenGLPerspectiveCameras(device="cuda:{}".format(torch.cuda.current_device()), R=R[None, 0, ...],
#                                       T=T[None, 0, ...])

#     # camera2 = OpenGLPerspectiveCameras(device=device, R=R[None, 2, ...],T=T[None, 2, ...])
#     # print(camera2.get_camera_center())
#     raster_settings = RasterizationSettings(
#         image_size=setup["image_size"],
#         blur_radius=0.0,
#         faces_per_pixel=1,
#         # bin_size=None, #int
#         # max_faces_per_bin=None,  # int
#         # perspective_correct=False,
#         # clip_barycentric_coords=None, #bool
#         cull_backfaces=setup["cull_backfaces"],
#     )
#     renderer = MeshRenderer(
#         rasterizer=MeshRasterizer(
#             cameras=camera, raster_settings=raster_settings),
#         shader=HardPhongShader(blend_params=BlendParams(background_color=background_color
#                                                         ), device=lights.device, cameras=camera, lights=lights)
#     )
#     new_meshes = new_meshes.extend(setup["nb_views"])

#     # compute output
#     # print("after rendering .. ", rendered_images.shape)

#     rendered_images = renderer(new_meshes, cameras=cameras, lights=lights)

#     rendered_images = unbatch_tensor(
#         rendered_images, batch_size=setup["nb_views"], dim=1, unsqueeze=True).transpose(0, 1)
#     # print(rendered_images[:, 100, 100, 0])

#     rendered_images = rendered_images[..., 0:3].transpose(2, 4).transpose(3, 4)
#     return rendered_images, cameras


# def render_points(points, color, azim, elev, dist, setup, background_color=(0.0, 0.0, 0.0), ):
#     c_batch_size = azim.shape[0]

#     point_cloud = Pointclouds(points=points.to(torch.float), features=color *
#                               torch.ones_like(points, dtype=torch.float)).cuda()

#     # print(len(new_meshes.faces_list()[0]))
#     # Create a Meshes object for the teapot. Here we have only one mesh in the batch.
#     R, T = look_at_view_transform(dist=batch_tensor(dist.T, dim=1, squeeze=True), elev=batch_tensor(
#         elev.T, dim=1, squeeze=True), azim=batch_tensor(azim.T, dim=1, squeeze=True))
#     R, T = check_and_correct_rotation_matrix(R, T, EXAHSTION_LIMIT, azim, elev, dist)

#     cameras = OpenGLOrthographicCameras(device="cuda:{}".format(torch.cuda.current_device()), R=R, T=T, znear=0.01)
#     raster_settings = PointsRasterizationSettings(
#         image_size=setup["image_size"],
#         radius=setup["points_radius"],
#         points_per_pixel=setup["points_per_pixel"]
#     )

#     renderer = PointsRenderer(
#         rasterizer=PointsRasterizer(
#             cameras=cameras, raster_settings=raster_settings),
#         compositor=NormWeightedCompositor()
#     )
#     point_cloud = point_cloud.extend(setup["nb_views"])
#     point_cloud.scale_(batch_tensor(1.0/dist.T, dim=1,squeeze=True)[..., None][..., None])

#     rendered_images = renderer(point_cloud)
#     rendered_images = unbatch_tensor(
#         rendered_images, batch_size=setup["nb_views"], dim=1, unsqueeze=True).transpose(0, 1)

#     rendered_images = rendered_images[..., 0:3].transpose(2, 4).transpose(3, 4)
#     return rendered_images, cameras


# def auto_render_meshes(targets, meshes, points, models_bag, setup, ):
#     # inputs = inputs.cuda(device)
#     # inputs = Variable(inputs)
#     c_batch_size = len(targets)
#     # if the model in test phase use white color
#     if setup["object_color"] == "random" and not models_bag["mvtn"].training:
#         color = torch_color("white")
#     else:
#         color = torch_color(setup["object_color"],max_lightness=True, epsilon=EPSILON)
#     background_color = torch_color(setup["background_color"], max_lightness=True, epsilon=EPSILON).cuda()
#     # shape_features = models_bag["feature_extractor"](
#     #     points, c_batch_size=c_batch_size).cuda()
#     azim, elev, dist = models_bag["mvtn"](points, c_batch_size=c_batch_size)

#     # lights = PointLights(
#     #     device=None, location=((0, 0, 0),))
#     if not setup["pc_rendering"]:
#         lights = DirectionalLights(
#             device=background_color.device, direction=models_bag["mvtn"].light_direction(azim, elev, dist))

#         rendered_images, cameras = render_meshes(
#             meshes=meshes, color=color, azim=azim, elev=elev, dist=dist, lights=lights, setup=setup, background_color=background_color)
#     else:
#         rendered_images, cameras = render_points(
#             points=points, color=color, azim=azim, elev=elev, dist=dist, setup=setup, background_color=background_color)
#     return rendered_images, cameras, azim, elev, dist




# def auto_render_meshes_custom_views(targets, meshes, points, models_bag, setup, ):
#     c_batch_size = len(targets)
#     if setup["object_color"] == "random" and not models_bag["mvtn"].training:
#         color = torch_color("white")
#     else:
#         color = torch_color(setup["object_color"],max_lightness=True)

#     # shape_features = models_bag["feature_extractor"](
#     #     points, c_batch_size=c_batch_size).cuda()
#     azim, elev, dist = models_bag["mvtn"](points, batch_size=c_batch_size)

#     for i, target in enumerate(targets.numpy().tolist()):
#         azim[i] = torch.from_numpy(np.array(models_bag["azim_dict"][target]))
#         elev[i] = torch.from_numpy(np.array(models_bag["elev_dict"][target]))

#     # lights = PointLights(
#     #     device=device, location=((0, 0, 0),))
#     if not setup["pc_rendering"]:
#         lights = DirectionalLights(
#             targets.device, direction=models_bag["mvtn"].light_direction(azim, elev, dist))

#         rendered_images, cameras = render_meshes(
#             meshes=meshes, color=color, azim=azim, elev=elev, dist=dist, lights=lights, setup=setup,)
#     else:
#         rendered_images, cameras = render_points(
#             points=points, azim=azim, elev=elev, dist=dist, setup=setup,)
#     rendered_images = nn.functional.dropout2d(
#         rendered_images, p=setup["view_reg"], training=models_bag["mvtn"].training)

#     return rendered_images, cameras, azim, elev, dist


# def auto_render_and_save_images_and_cameras(targets, meshes, points, images_path, cameras_path, models_bag, setup, ):
#     # inputs = np.stack(inputs, axis=0)
#     # inputs = torch.from_numpy(inputs)
#     with torch.no_grad():
#         if not setup["return_points_saved"] and not setup["return_points_sampled"]:
#             points = torch.from_numpy(points)
#         targets = torch.tensor(targets)[None]
#         # correction_factor = torch.tensor(correction_factor)
#         rendered_images, cameras, _, _, _ = auto_render_meshes(
#             targets, [meshes], points[None, ...], models_bag, setup,)
#     # print("before saving .. ",rendered_images.shape)
#     save_grid(image_batch=rendered_images[0, ...],
#               save_path=images_path, nrow=setup["nb_views"])
#     save_cameras(cameras, save_path=cameras_path, scale=0.22, dpi=200)


# def auto_render_and_analyze_images(targets, meshes, points, images_path, models_bag, setup, ):
#     # inputs = np.stack(inputs, axis=0)
#     # inputs = torch.from_numpy(inputs)
#     with torch.no_grad():
#         if not setup["return_points_saved"] and not setup["return_points_sampled"]:
#             points = torch.from_numpy(points)
#         targets = torch.tensor(targets)
#         # correction_factor = torch.tensor(correction_factor)
#         rendered_images, _, _, _, _ = auto_render_meshes(
#             targets, [meshes], points[None, ...], models_bag, setup,)
#     # print("before saving .. ",rendered_images.shape)
#     save_grid(image_batch=rendered_images[0, ...],
#               save_path=images_path, nrow=setup["nb_views"])
#     mask = rendered_images != 1
#     img_avg = (rendered_images*mask).sum(dim=(0, 1, 2, 3, 4)) / \
#         mask.sum(dim=(0, 1, 2, 3, 4))
#     # print("Original  ", rendered_images.mean(dim=(0,1,2,3,4)), "\n IMG avg : ", img_avg)
#     # print(img_avg.cpu().numpy())

#     return float(img_avg.cpu().numpy())


def regualarize_rendered_views(rendered_images, dropout_p=0, augment_training=False, crop_ratio=0.3):
    """
    Regularise rendered views with channel dropout and optional augmentation.

    `dropout2d` is always applied with `training=True`, i.e. with probability
    `dropout_p` regardless of any model mode. When `augment_training` is set,
    `applied_transforms` is additionally run through `super_batched_op` along
    dimension 1 with the given `crop_ratio`.
    """
    dropped = nn.functional.dropout2d(
        rendered_images, p=dropout_p, training=True)
    if not augment_training:
        return dropped
    return super_batched_op(
        1, applied_transforms, dropped, crop_ratio=crop_ratio)

def clip_grads_(parameters, max_norm, norm_type=2):
    r"""Clips gradient norm of an iterable of parameters and zero them if nan.

    The norm is computed over all gradients together, as if they were
    concatenated into a single vector. Gradients are modified in-place.
    Parameters whose ``.grad`` is ``None`` are ignored.

    Arguments:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized
        max_norm (float or int): max norm of the gradients
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.

    Returns:
        Total norm of the gradients (viewed as a single vector), measured after
        NaN cleaning and before clipping. A zero tensor is returned when no
        parameter carries a gradient.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    if not parameters:
        # nothing to clip; avoids torch.stack([]) / max() failing on empty input
        return torch.tensor(0.0)
    for p in parameters:
        grad = p.grad.detach()
        # copy the cleaned values back into the gradient storage so the result
        # is kept whether zero_nans works in place or returns a new tensor
        grad.copy_(zero_nans(grad))
    if norm_type == float("inf"):
        total_norm = max(p.grad.detach().abs().max() for p in parameters)
    else:
        total_norm = torch.norm(torch.stack(
            [torch.norm(p.grad.detach(), norm_type) for p in parameters]), norm_type)
    # the epsilon guards against division by a zero norm
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        for p in parameters:
            p.grad.detach().mul_(clip_coef)
    return total_norm








def test_point_network(model, criterion, data_loader):
    """
    Evaluate a point-cloud classifier over `data_loader` on the GPU.

    Each batch is unpacked as `(targets, _, points, _)`; the points are
    transposed on dims 1 and 2 before being fed to `model`, which must return
    three values with the logits first.

    Returns:
        (accuracy in percent, loss averaged over the number of batches).
    """
    from tqdm import tqdm

    seen = 0.0
    hits = 0.0
    loss_sum = 0.0
    batches = 0
    for targets, _, points, _ in tqdm(data_loader):
        with torch.no_grad():
            points = points.transpose(1, 2).cuda()
            targets = Variable(targets.cuda())
            logits, _shape_features, _trans = model(points)

            loss_sum += criterion(logits, targets)
            batches += 1
            predicted = torch.max(logits.data, 1)[1]
            seen += targets.size(0)
            hits += (predicted.cpu() == targets.cpu()).sum()

    avg_test_acc = 100 * hits / seen
    avg_loss = loss_sum / batches
    return avg_test_acc, avg_loss




def save_checkpoint(state, setup, views_record, weights_file, ignore_saving_models=False):
    """
    Persist the model state and the experiment records.

    `state` is written to `weights_file` with `torch.save` unless
    `ignore_saving_models` is set. The `setup` dict is appended to a fresh
    `ListDict` keyed by the setup's keys and passed to `save_results` for
    `setup["results_file"]`; `views_record`, when not None, is passed to
    `save_results` for `setup["views_file"]`.
    """
    if not ignore_saving_models:
        torch.save(state, weights_file)

    results_path = setup["results_file"]
    setup_log = ListDict(list(setup.keys()))
    save_results(results_path, setup_log.append(setup))

    if views_record is None:
        return
    save_results(setup["views_file"], views_record)


def load_checkpoint(setup, models_bag, weights_file, ignore_optimizer=False):
    """
    Restore training state from the checkpoint stored at `weights_file`.

    Writes `best_acc` and `start_epoch` into `setup` and loads the
    "mvnetwork" weights; the "mvtn" / "mvlifting" weights are loaded when
    `setup["is_learning_views"]` / `setup["learning_lifting"]` are truthy.

    Args:
        setup: experiment dict; updated in place.
        models_bag: dict holding the models and optimizers to restore.
        weights_file: path of the checkpoint written with `torch.save`.
        ignore_optimizer: when True no optimizer state is restored at all
            (main, MVTN and lifting optimizers are all skipped), so only model
            weights and counters are loaded.

    Raises:
        AssertionError: if `weights_file` does not exist.
    """
    # Load checkpoint.
    print('\n==> Loading checkpoint..')
    assert os.path.isfile(weights_file
                          ), 'Error: no checkpoint file found!'

    checkpoint = torch.load(weights_file)
    setup["best_acc"] = checkpoint['best_acc']
    setup["start_epoch"] = checkpoint['epoch']
    models_bag["mvnetwork"].load_state_dict(checkpoint['state_dict'])
    if setup["is_learning_views"]:
        models_bag["mvtn"].load_state_dict(checkpoint['mvtn'])
        if not ignore_optimizer:
            models_bag["mvtn_optimizer"].load_state_dict(checkpoint['mvtn_optimizer'])
    if setup["learning_lifting"]:
        models_bag["mvlifting"].load_state_dict(checkpoint['mvlifting'])
        if not ignore_optimizer:
            models_bag["mlp_optimizer"].load_state_dict(checkpoint['mlp_optimizer'])

    # if setup["is_learning_points"]:
    #     models_bag["feature_extractor"].load_state_dict(
    #         checkpoint['feature_extractor'])
    #     models_bag["fe_optimizer"].load_state_dict(checkpoint['fe_optimizer'])
    # if "late_fusion_mode" in setup and setup["late_fusion_mode"]:
    #     models_bag["classifier"].load_state_dict(checkpoint['classifier'])
    #     models_bag["cls_optimizer"].load_state_dict(checkpoint['cls_optimizer'])
    #     models_bag["point_network"].load_state_dict(
    #         checkpoint['point_network'])
    #     models_bag["fe_optimizer"].load_state_dict(checkpoint['fe_optimizer'])

    # honour ignore_optimizer for the main optimizer as well, consistently with
    # the MVTN and lifting optimizers above
    if not ignore_optimizer:
        models_bag["optimizer"].load_state_dict(checkpoint['optimizer'])





def load_checkpoint_robustness(setup, models_bag, weights_file):
    """
    Load only model weights (no counters, no optimizers) from `weights_file`.

    The "mvnetwork" entry of `models_bag` always receives the checkpoint's
    'state_dict'; the "mvtn" entry receives 'mvtn' when
    `setup["is_learning_views"]` is truthy. An AssertionError is raised when the
    file does not exist.
    """
    print('\n==> Loading checkpoint..')
    file_found = os.path.isfile(weights_file)
    assert file_found, 'Error: no checkpoint file found!'

    saved = torch.load(weights_file)
    network = models_bag["mvnetwork"]
    network.load_state_dict(saved['state_dict'])
    if not setup["is_learning_views"]:
        return
    models_bag["mvtn"].load_state_dict(saved['mvtn'])
    # if setup["is_learning_points"]:
    #     models_bag["feature_extractor"].load_state_dict(
    #         checkpoint['feature_extractor'])





def mvtosv(x):
    """Merge the batch and view axes: 'b m h w' becomes '(b m) h w'."""
    return rearrange(x, 'b m h w -> (b m) h w ')


def mvctosvc(x):
    """Merge the batch and view axes of a tensor with channels: 'b m c h w' becomes '(b m) c h w'."""
    return rearrange(x, 'b m c h w -> (b m) c h w ')


def svtomv(x, nb_views=1):
    """Split the leading axis back into batch and `nb_views` views: '(b m) h w' becomes 'b m h w'."""
    return rearrange(x, '(b m) h w -> b m h w', m=nb_views)


def svctomvc(x, nb_views=1):
    """Split the leading axis back into batch and `nb_views` views, keeping channels: '(b m) c h w' becomes 'b m c h w'."""
    return rearrange(x, '(b m) c h w -> b m c h w', m=nb_views)




def extra_IOU_metrics(points_GT, points_predictions, pixels_GT, pixel_mask, points_mask, object_class, parts,):
    """
    Compute per-part IoUs for a batch of point clouds plus pixel/point statistics.

    For every part label `cl` in `range(max(parts))` and every shape in the
    batch, one entry is appended to each returned list (part-major order).

    Args:
        points_GT: 2-D tensor (batch, points) of ground-truth part labels.
        points_predictions: predicted part labels, compared element-wise with `points_GT`.
        pixels_GT: 4-D tensor of per-pixel labels whose first dimension is the batch.
        pixel_mask: mask whose total sum normalises the pixel percentages.
        points_mask: boolean mask of valid points, same shape as `points_GT`.
        object_class: class id per shape; any shape that squeezes to 0-D or 1-D.
        parts: tensor of part labels; only its maximum is used.

    Returns:
        pixel_perc: percentage of masked pixels (whole batch) carrying the label, per shape.
        point_perc: percentage of masked points (whole batch) carrying the label, per shape.
        cur_shape_ious: IoU in percent.
        cur_parts_valid: 1 when the union is non-empty, else 0.
        cls_nb: object class of the shape.
        part_nb: the part label.
    """
    # unpacking doubles as a rank check on the inputs
    bs, _ = points_GT.shape
    _, _, _, _ = pixels_GT.shape

    # squeeze() turns a batch of one into a 0-d tensor whose tolist() is a bare
    # scalar; restore the batch axis so it can still be used with extend()
    class_ids = object_class.squeeze()
    if class_ids.dim() == 0:
        class_ids = class_ids[None]
    class_ids = class_ids.cpu().numpy().tolist()

    # loop-invariant normalisers
    nb_masked_pixels = pixel_mask.sum().item()
    nb_masked_points = points_mask.sum().item()

    cur_shape_ious = []
    cur_parts_valid = []
    part_nb = []
    cls_nb = []
    pixel_perc = []
    point_perc = []
    # NOTE(review): range() stops before max(parts), so the highest label is
    # never scored -- confirm this is the intended label convention.
    for cl in range(torch.max(parts).item()):
        cur_gt_mask = (points_GT == cl) & points_mask
        cur_pred_mask = (points_predictions == cl) & points_mask

        I = (cur_pred_mask & cur_gt_mask).sum(dim=-1)
        U = (cur_pred_mask | cur_gt_mask).sum(dim=-1)

        # the epsilon keeps empty unions from dividing by zero (IoU becomes 0)
        cur_shape_ious.extend((100.0 * I / (U + 1e-7)).cpu().numpy().tolist())
        cur_parts_valid.extend((U > 0).to(torch.int32).cpu().numpy().tolist())
        cls_nb.extend(class_ids)
        part_nb.extend(bs * [cl])

        pixels_with_label = (pixels_GT == cl).sum(dim=-1).sum(dim=-1).sum(dim=-1)
        pixel_perc.extend((100.0 * pixels_with_label.to(
            torch.float).cpu().numpy() / nb_masked_pixels).tolist())
        point_perc.extend((100.0 * cur_gt_mask.sum(dim=-1).to(
            torch.float).cpu().numpy() / nb_masked_points).tolist())

    return pixel_perc, point_perc, cur_shape_ious, cur_parts_valid, cls_nb, part_nb


def batched_index_select_(x, idx):
    """
    Gather per-vertex feature vectors according to a batched index tensor.

    :param x: features, torch.Size([batch_size, num_dims, num_vertices, 1])
    :param idx: indices into the vertex axis, torch.Size([batch_size, num_queries, k])
    :return: torch.Size([batch_size, num_dims, num_queries, k]) where
             out[b, :, q, j] holds the features of vertex idx[b, q, j] of sample b
    """
    n_batch, n_feat, n_src = x.shape[:3]
    _, n_query, n_nbrs = idx.shape

    # shift every sample's indices so they address rows of the flattened table
    sample_offsets = torch.arange(
        0, n_batch, device=idx.device).view(-1, 1, 1) * n_src
    flat_idx = (idx + sample_offsets).view(-1)

    # one row per (sample, vertex), one column per feature
    table = x.transpose(2, 1).contiguous().view(n_batch * n_src, -1)
    picked = table[flat_idx, :]
    picked = picked.view(n_batch, n_query, n_nbrs, n_feat)
    return picked.permute(0, 3, 1, 2)


def batched_index_select_parts(x, idx):
    """
    Fetch per-pixel point features for every view from a per-point feature tensor.

    :param x: torch.Size([batch_size, num_dims, num_vertices])
    :param idx: point indices per pixel, torch.Size([batch_size, num_views, points_per_pixel, H, W])
    :return: torch.Size([batch_size, num_views, num_dims, points_per_pixel, H, W])
    """
    _, n_views, _, height, width = idx.shape[:5]
    _, n_feat, _ = x.shape

    # flatten views and pixels into a single query axis for the gather
    flat_queries = rearrange(idx, 'b m p h w -> b (m h w) p')
    gathered = batched_index_select_(x[..., None], flat_queries)
    return rearrange(gathered, 'b d (m h w) p -> b m d p h w',
                     m=n_views, h=height, w=width, d=n_feat)


def knn(x, k):
    """
    Find the k nearest points of every point within the same cloud.

    `x` holds point features of shape [B, C, N, 1]. Neighbours are ranked by the
    negated squared Euclidean distance, so the closest points come first.
    Returns the neighbour indices with shape [B, N, k]; no gradients are tracked.
    """
    with torch.no_grad():
        pts = x.squeeze(-1)
        cross = -2 * torch.matmul(pts.transpose(2, 1), pts)
        sq_norm = torch.sum(pts ** 2, dim=1, keepdim=True)
        # -(|a|^2 - 2 a.b + |b|^2): larger means closer
        neg_sq_dist = -sq_norm - cross - sq_norm.transpose(2, 1)

        _, idx = neg_sq_dist.topk(k=k, dim=-1)  # (batch_size, num_points, k)
    return idx


def knnq2ref(q, ref, k):
    """
    Find, for every query point, its k nearest points in a reference cloud.

    Args:
        q: query point features of shape [B, C, N, 1].
        ref: reference point features of shape [B, C, M, 1].
        k: number of neighbours to return (int, at most M).

    Returns:
        Index tensor of shape [B, N, k] into the M reference points, ordered
        from the closest to the farthest (Euclidean distance over C).
        No gradients are tracked.
    """
    with torch.no_grad():
        # [B, C, N, 1] - [B, C, 1, M] broadcasts to all query/reference pairs
        dist = torch.norm(q - ref.transpose(2, 3), dim=1, p=2)
        # distances are positive, so the nearest neighbours are the smallest
        idx = dist.topk(k=k, dim=-1, largest=False)[1]  # (batch_size, num_points, k)
    return idx


def post_process_segmentation(point_set, predictions_3d, iterations=1, K_neighbors=1):
    """
    Iteratively fill unlabeled points (label 0) of `predictions_3d` from their neighbours.

    In round `r` the `r * K_neighbors + 2` nearest neighbours of every point in
    `point_set` are looked up with `knn`; the most frequent label among all but
    the first neighbour column (presumably the point itself) is written to the
    points whose label is 0. `predictions_3d` is modified in place and returned.
    """
    for round_idx in range(iterations):
        unlabeled = predictions_3d == 0
        nbr_count = round_idx*K_neighbors + 2
        coords = point_set.transpose(1, 2)[..., None]
        nbr_indx = knn(coords, nbr_count)
        # labels reshaped to a single-channel feature tensor for the gather
        label_feats = predictions_3d[..., None].transpose(1, 2)[..., None]
        nbr_labels = batched_index_select_(label_feats, nbr_indx)
        # majority vote over the neighbours, skipping the first column
        voted = torch.mode(nbr_labels[:, 0, :, 1::], dim=-1)[0]
        predictions_3d[unlabeled] = voted[unlabeled]
    return predictions_3d


def compute_metrics(labels_3d, face_labels, label_range):
    """
    Evaluate predicted face labels of one mesh against the ground truth.

    Args:
        labels_3d: (face_num, ), predicted face labels. -1 means invalid.
        face_labels: (face_num, ), ground-truth face labels.
        label_range: (2, ), start and end label. End label is included.
    Returns:
        face_cov: fraction of faces with a valid (non -1) prediction.
        face_acc: fraction of all faces predicted correctly.
        cov_face_acc: fraction of covered faces predicted correctly.
        IoU: per-class IoU averaged over every label in `label_range`.
    """
    face_num = labels_3d.shape[0]
    covered = labels_3d != -1

    face_cov = np.sum(covered) / face_num
    face_acc = np.sum(labels_3d == face_labels) / face_num
    covered_hits = labels_3d[covered] == face_labels[covered]
    cov_face_acc = np.sum(covered_hits) / np.sum(covered)

    # the small epsilon keeps classes absent from both label sets at IoU 1
    eps = np.finfo(np.float32).eps
    first_label, last_label = label_range[0], label_range[1]
    iou_total = 0.0
    for label in range(first_label, last_label + 1):
        in_gt = face_labels == label
        in_pred = labels_3d == label
        intersection = np.sum(np.logical_and(in_gt, in_pred)) + eps
        union = np.sum(np.logical_or(in_gt, in_pred)) + eps
        iou_total += intersection / union

    IoU = iou_total / (last_label - first_label + 1)
    return face_cov, face_acc, cov_face_acc, IoU


from models.multi_view import *
from models.voint import *
from models.pointnet import *
# from models.segformer import segformer
# from models.swinsegment import swinsegment
def _build_mlp_lifting_net(feat_dim):
    """Build the point-wise MLP lifting net mapping feat_dim -> feat_dim over (B, M, C) inputs."""
    return Seq(*[Rearrange('B M C -> B C M'),
                 Conv1dLayer([feat_dim, 512, 512, 512, feat_dim],
                             act='relu', norm=True, bias=True),
                 Rearrange('B C M-> B M C')])


def get_mvnetwork(setup, num_classes, num_parts=1):
    """
    Build the multi-view (or point) network selected by `setup`.

    The network is chosen from setup["mvnetwork"] and setup["run_mode"]:
      * 'resnet' with "cls" in the run mode: a torchvision ResNet of depth
        setup["depth"] with its `fc` removed, wrapped in `MVAgregate`.
      * any name containing "vit": one of 'vit', 'mvit', 'wvit', 'nvit',
        'uvit', each wrapped in its own multi-view module.
      * otherwise: a 2D segmentor ('fcn' / 'deeplab') wrapped in
        `MVPartSegmentation`, the "pls" variant of it, or a point network
        ('PointNet' / 'DGCNN') when "point" is in the run mode.

    Args:
        setup: configuration dictionary; the keys read depend on the branch.
        num_classes: number of classes passed to the wrapping module.
        num_parts: number of parts, used by the 2D segmentation wrapper.
    Returns:
        the constructed network. If setup["pretraining_mode"] == "fsl", its
        weights are loaded from "./checkpoint/<mvnetwork>_model.pt".
    Raises:
        ValueError: if the requested ViT variant, 2D segmentor or point
            network is not one of the supported names.
        AssertionError: if the requested ResNet depth is not available.
    """
    TOTAL_NB_PARTS = 50

    if setup["mvnetwork"] == 'resnet' and "cls" in setup["run_mode"]:
        depth2featdim = {18: 512, 34: 512, 50: 2048, 101: 2048, 152: 2048}
        assert setup["depth"] in list(
            depth2featdim.keys()), "the requested resnt depth not available"
        feat_dim = depth2featdim[setup["depth"]]
        mvnetwork = torchvision.models.__dict__[
            "resnet{}".format(setup["depth"])](setup["pretraining_mode"] != "scratch")
        # drop the classification layer so the backbone outputs features
        mvnetwork.fc = nn.Sequential()
        if setup["lifting_method"] == "mlp":
            lifting_net = _build_mlp_lifting_net(feat_dim)
        else:
            lifting_net = nn.Sequential()
        mvnetwork = MVAgregate(mvnetwork, agr_type=setup["mv_agr_type"],
                               feat_dim=feat_dim, num_classes=num_classes, lifting_net=lifting_net)
        print('Using ' + setup["mvnetwork"] + str(setup["depth"]))

    elif "vit" in setup["mvnetwork"]:
        # Reject unknown variants up front; otherwise no branch below would
        # assign `mvnetwork` and the function would fail at `return`.
        if setup["mvnetwork"] not in ("vit", "mvit", "wvit", "nvit", "uvit"):
            raise ValueError("please pick a correct vit variant")
        size2featdim = {"base": 768, "large": 1024, "huge": 1280}
        # The timm backbone is built for every variant; 'nvit' and 'uvit'
        # replace it with their own model below.
        vit = timm.create_model(model_name_from_setup(setup), pretrained=setup["pretraining_mode"] != "scratch", embed_layer=MyPatchEmbed)
        vit.head = nn.Sequential()
        vit.pre_logits = torch.nn.Identity()
        if setup["lifting_method"] == "mlp":
            # Use the ViT feature width, i.e. the same feat_dim handed to
            # MVAgregate below (the ResNet depth table does not exist here).
            lifting_net = _build_mlp_lifting_net(
                size2featdim[setup["vit_model_size"]])
        else:
            lifting_net = nn.Sequential()

        if setup["mvnetwork"] == "vit":
            mvnetwork = MVAgregate(
                vit, agr_type=setup["mv_agr_type"], feat_dim=size2featdim[setup["vit_model_size"]], num_classes=num_classes, lifting_net=lifting_net)
        elif setup["mvnetwork"] == "mvit":
            mvnetwork = FullCrossViewAttention(
                vit, patch_size=setup["patch_size"], num_views=setup["nb_views"], feat_dim=setup["feat_dim"], num_classes=num_classes)
        elif setup["mvnetwork"] == "wvit":
            mvnetwork = WindowCrossViewAttention(vit, patch_size=setup["patch_size"], num_views=setup["nb_views"], num_windows=setup["nb_windows"], feat_dim=setup["feat_dim"], num_classes=num_classes, agr_type=setup["mv_agr_type"])
        elif setup["mvnetwork"] == "nvit":
            vit = ViT(image_size=setup["image_size"], patch_size=setup["patch_size"], num_classes=num_classes, dim=setup["feat_dim"], depth=setup["depth"], heads=setup["mvit_heads"], mlp_dim=setup["mlp_dim"],
                      pool='mean', channels=3, dim_head=64, dropout=setup["mvit_dropout"], emb_dropout=setup["emb_dropout"])
            vit.mlp_head = nn.Sequential()
            mvnetwork = MVAgregate(vit, agr_type=setup["mv_agr_type"], feat_dim=size2featdim[setup["vit_model_size"]], num_classes=num_classes, lifting_net=lifting_net)
        elif setup["mvnetwork"] == "uvit":
            vit = MyViT(model_name_from_setup(setup), pretrained=setup["pretraining_mode"] != "scratch", vit_agr_type=setup["vit_agr_type"])
            vit.vit.head = nn.Sequential()
            vit.vit.pre_logits = torch.nn.Identity()
            mvnetwork = MVAgregate(
                vit, agr_type=setup["mv_agr_type"], feat_dim=size2featdim[setup["vit_model_size"]], num_classes=num_classes, lifting_net=lifting_net)

        print('Using ' + setup["mvnetwork"] + str(setup["depth"]))
    else:
        if not "pls" in setup["run_mode"] and not "point" in setup["run_mode"]:

            if setup["mvnetwork"] == 'fcn':
                mvnetwork = torchvision.models.segmentation.fcn_resnet50(
                    pretrained=setup["pretraining_mode"] != "scratch", num_classes=21)
            elif setup["mvnetwork"] == 'deeplab':
                mvnetwork = torchvision.models.segmentation.deeplabv3_resnet101(
                    pretrained=setup["pretraining_mode"] != "scratch", num_classes=21)
            else:
                raise ValueError("please pick a correct 2d segmentor")
            mvnetwork = MVPartSegmentation(mvnetwork, num_classes=num_classes, num_parts=num_parts, parallel_head=setup["parallel_head"], balanced_object_loss=setup[
                                           "balanced_object_loss"], balanced_2d_loss_alpha=setup["balanced_2d_loss_alpha"], depth=setup["depth"], total_nb_parts=TOTAL_NB_PARTS)
        elif "pls" in setup["run_mode"]:
            mvnetwork = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=setup["pretraining_mode"] != "scratch", num_classes=21)
            mvnetwork = MVPartSegmentation(mvnetwork, num_classes=1, num_parts=num_classes-1,
                                           parallel_head=False, balanced_object_loss=False, balanced_2d_loss_alpha=0.0, depth=2)
        elif "point" in setup["run_mode"]:
            # xyz only, or xyz + normals when setup["use_normals"] is set
            point_in_size = int(3 + 3 * int(setup["use_normals"]))
            if setup["mvnetwork"] == "PointNet":
                mvnetwork = PointNet(
                    num_classes=TOTAL_NB_PARTS, in_size=point_in_size, segmentation=True).cuda()
            elif setup["mvnetwork"] == "DGCNN":
                mvnetwork = SimpleDGCNN(num_classes=TOTAL_NB_PARTS, in_size=point_in_size, segmentation=True, max_feat=512, k=9).cuda()
            else:
                raise ValueError("please pick a correct point network")
    if setup["pretraining_mode"] == "fsl":
        weights = torch.load("./checkpoint/"+setup["mvnetwork"]+"_model.pt", map_location='cpu')
        mvnetwork.load_state_dict(weights["state_dict"])
    return mvnetwork


def get_liftingnet(setup, num_classes, num_parts=1):
    """
    Build the Voint lifting network named by setup["lifting_method"].

    Supported methods are "mlp" (VointMLP), "gcn" (VointGCN), "gat"
    (VointGAT) and "transformer" (VointFormer). Any other value yields None.

    Args:
        setup: configuration dictionary.
        num_classes: used as the input size when "pls" is in the run mode.
        num_parts: number of parts; the logits have size num_parts + 1.
    Returns:
        the lifting network, or None for an unrecognised lifting method.
    """
    viewembedder = ViewEmbedding(
        view_embeddgin_type=setup["view_embeddgin_type"],
        use_view_info=setup["use_view_info"],
        embed_dim=setup["view_embedding_dim"])

    # the early features have size 21 , logits have size num_parts+1
    if "pls" in setup["run_mode"]:
        in_size = num_classes
    elif setup["use_early_voint_feats"]:
        in_size = 21
    else:
        in_size = num_parts + 1
    out_size = setup["voint_out_size"]

    method = setup["lifting_method"]
    if method not in ("mlp", "gcn", "gat", "transformer"):
        return None

    # keyword arguments accepted by every Voint network
    shared = dict(
        feat_dim=setup["feat_dim"],
        aggr=setup["voint_aggr"],
        use_cls_voint=setup["use_cls_voint"],
        viewembedder=viewembedder,
        use_xyz=setup["use_voint_xyz"],
        voint_depth=setup["voint_depth"],
    )
    if method == "mlp":
        return VointMLP(in_size, out_size, **shared)

    # the remaining networks additionally take the cls-token switch
    shared["leanred_cls_token"] = setup["leanred_cls_token"]
    if method == "gcn":
        return VointGCN(in_size, out_size, **shared)
    if method == "gat":
        return VointGAT(in_size, out_size, **shared)
    return VointFormer(in_size, out_size, **shared)


def get_mlp_classifier(setup, num_classes, num_parts=1):
    """
    Build the point-wise classifier applied on top of the lifted features.

    Args:
        setup: configuration dictionary. When "pls" is in the run mode,
            setup["parallel_head"] is overwritten with 0.
        num_classes: number of heads; also the output size in "pls" mode.
        num_parts: number of parts; the output size is num_parts + 1
            outside "pls" mode.
    Returns:
        a PointMLPClassifier whose `extra_net` is chosen by
        setup["extra_net"] ("PointNet", "DGCNN", or anything else for an
        empty nn.Sequential).
    """
    in_size = setup["voint_out_size"]
    if "pls" in setup["run_mode"]:
        out_size = num_classes
        setup["parallel_head"] = 0
    else:
        out_size = num_parts + 1

    extra_kind = setup["extra_net"]
    if extra_kind == "PointNet":
        layer_widths = [out_size, 512, 1024, 1024, 512, out_size]
        extra_net = Conv1dLayer(layer_widths, act='relu', norm=True, bias=True)
    elif extra_kind == "DGCNN":
        # consecutive channel pairs define the four dynamic edge convolutions
        channels = [out_size, 64, 128, 512, out_size]
        edge_convs = [DynEdgeConv2d(c_in, c_out, k=5)
                      for c_in, c_out in zip(channels[:-1], channels[1:])]
        extra_net = Seq(Rearrange('B C N -> B C N 1'),
                        *edge_convs,
                        Rearrange('B C N 1-> B C N'))
    else:
        extra_net = nn.Sequential()

    return PointMLPClassifier(
        in_size,
        out_size,
        use_xyz=setup["use_xyz"],
        use_global=setup["use_global_feats"],
        skip=not setup["use_mlp_classifier"],
        nb_heads=num_classes,
        parallel_head=setup["parallel_head"],
        feat_dim=setup["feat_dim"],
        extra_net=extra_net)
