import os.path

import torch.nn as nn
import torch.nn.functional as F
import time
import MinkowskiEngine as ME
from tqdm import tqdm
from prepare_data.utils import pc_utils
from prepare_data.utils import scannet_utils
import numpy as np
import torch
import evaluate
import random
from examples.common import seed_all
import torch.distributed as dist

def visual_prediction(coords_scannet, labels_scannet, predictions, i, object_visual, config, sample_points=100000):
    """Export two coloured PLY visualisations of one class for a random subset of points.

    Up to ``sample_points`` points are drawn at random (``torch.randperm``). If none of
    the sampled points has ScanNet class ``object_visual`` (column 1 of
    ``labels_scannet``), nothing is written. Otherwise two files are written into
    ``visual_result_scannet_<object_visual>/``:

    * ``<device>_<i>_<object_visual>_output.ply``: points coloured by ``predictions``
      through OpenCV's JET colormap.
    * ``<device>_<i>_<object_visual>.ply``: the ground-truth mask. Points of the class
      get a per-class palette colour; every other point gets palette colour 1.

    Args:
        coords_scannet: per-point coordinates. Column 0 is dropped when writing
            (presumably the batch index); columns 1: are written as xyz.
        labels_scannet: per-point labels; column 1 holds the ScanNet class id.
        predictions: per-point scores, assumed to lie in [0, 1]. Values are clipped
            to that range before colour mapping.
        i: tag inserted into the output file names.
        object_visual: class name; must appear in ``scannet_utils.CLASS_LABELS`` and
            in ``config.model_list``.
        config: provides ``device`` (used in the file names) and ``model_list``.
        sample_points: maximum number of points to export.
    """
    import cv2

    keep = torch.randperm(coords_scannet.size()[0])[:sample_points]
    predictions = predictions[keep]
    coords_scannet = coords_scannet[keep]
    labels_scannet = labels_scannet[keep]

    # Build the mask on the CPU so it can index the CPU colour tensor created below.
    index = (labels_scannet[:, 1].long() == scannet_utils.CLASS_LABELS.index(object_visual)).cpu()
    if index.sum() == 0:
        return

    # Prediction heat map. applyColorMap needs uint8, so clip first (255 * p > 255 would
    # wrap around). The colormap returns BGR, so flip the channels to RGB for the PLY
    # writer. torch.from_numpy rejects negative strides, hence ascontiguousarray.
    scores = predictions.detach().cpu().numpy()
    scores = scores.reshape(scores.shape[0], -1)
    heatmap = cv2.applyColorMap(np.uint8(np.clip(255 * scores, 0, 255)), cv2.COLORMAP_JET)
    color_output = torch.from_numpy(np.ascontiguousarray(heatmap[:, 0, ::-1]))

    # Ground-truth mask: palette colour 1 for every point, overwritten with a per-class
    # colour for points of the target class.
    color_template = scannet_utils.create_color_palette()
    color_index = [5, 7, 4, 34, 36, 8, 16, 14, 10, 6, 33]
    class_slot = config.model_list.index(object_visual) % len(color_index)
    color = torch.zeros(labels_scannet.size()[0], 3).long()
    color[:] = torch.tensor(color_template[1]).long()
    color[index] = torch.tensor(color_template[color_index[class_slot]]).long()

    out_dir = 'visual_result_scannet_%s' % object_visual
    os.makedirs(out_dir, exist_ok=True)
    stem = '%s/%s_%s_%s' % (out_dir, config.device, i, object_visual)
    pc_utils.write_ply_rgb(coords_scannet[:, 1:], color_output, stem + '_output.ply', text=True)
    pc_utils.write_ply_rgb(coords_scannet[:, 1:], color, stem + '.ply', text=True)

def output_examples(coords):
    """Dump the first four samples of a batched coordinate tensor to PLY files.

    Column 0 of ``coords`` is treated as the sample index. Rows whose index is 0, 1, 2
    and 3 are written, without that column, to ``1_modelnet_rotation.ply`` through
    ``4_modelnet_rotation.ply`` in the current working directory.
    """
    file_names = (
        "1_modelnet_rotation.ply",
        "2_modelnet_rotation.ply",
        "3_modelnet_rotation.ply",
        "4_modelnet_rotation.ply",
    )
    samples = [coords[coords[:, 0] == sample_id][:, 1:] for sample_id in range(len(file_names))]
    for points, file_name in zip(samples, file_names):
        pc_utils.write_ply(points, file_name, text=True)


def output_instance(batch, output_path="/userhome/cs/crnsmile/project/unsupervised_segmentation/instances"):
    """Write every annotated instance of one scene to its own coloured PLY file.

    Args:
        batch: dict with ``'coords'``, ``'feats'`` (written as the colours),
            ``'labels'`` and ``'scene_name'`` (str). Column 0 of ``labels`` is the
            instance id. Column 1 is presumably the class id; its minimum over the
            instance is used in the file name.
        output_path: directory the files are written to.

    Files are named ``<scene_name>_<min label[:, 1]>_<instance id>.ply``.
    """
    coords, feats, labels, scene_name = batch['coords'], batch['feats'], batch['labels'], batch['scene_name']
    if len(labels) == 0:
        return
    instance_ids = labels[:, 0]
    # Inclusive upper bound: range(max) would silently skip the instance with the largest id.
    for i in range(int(instance_ids.max()) + 1):
        index = instance_ids == i
        if index.sum() <= 0:
            continue
        label = labels[index]
        output_name = os.path.join(output_path, scene_name + '_%s_%s.ply' % (int(label[:, 1].min()), i))
        pc_utils.write_ply_rgb(coords[index], feats[index], output_name, text=True)

def get_in_field(coords, feats, config):
    """Wrap coordinates and features in a float MinkowskiEngine ``TensorField``.

    The field lives on ``config.device``. It uses unweighted-average quantization and
    the speed-optimized Minkowski algorithm.
    """
    field = ME.TensorField(
        features=feats,
        coordinates=coords,
        quantization_mode=ME.SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE,
        minkowski_algorithm=ME.MinkowskiAlgorithm.SPEED_OPTIMIZED,
        device=config.device,
    )
    return field.float()


def place_model_in_scan(model, scan, label, p_label, mask, mask_model, mode):
    """Translate an object point cloud to a random position inside a scan and append it.

    The model is shifted so that its lowest z matches the scan's lowest z. Its x/y shift
    is an integer drawn from the range that keeps its bounding box inside the scan's.

    Args:
        model: (M, 3) xyz points of the object. Not modified.
        scan: (S, 3) xyz points of the scene.
        label: (S,) per-point labels for ``scan``.
        p_label: label assigned to every appended model point.
        mask: (S,) per-point keep-flags for ``scan``. Not modified.
        mask_model: (M,) keep-flags for the model points.
        mode: ``'non_overlap'`` zeroes the mask of scan points whose x and y fall inside
            the placed model's x/y bounding box. Any other value leaves the scan mask as is.

    Returns:
        ``(scan_placement, label, mask)`` with the model rows appended after the scan
        rows. If either point cloud is empty, the inputs are returned unchanged.
    """
    # Nothing to place into, or nothing to place: keep the scene as it is.
    if scan.size(0) == 0 or model.size(0) == 0:
        return scan, label, mask

    min_posi_scan, max_posi_scan = scan.min(0).values, scan.max(0).values
    min_posi_model, max_posi_model = model.min(0).values, model.max(0).values
    lower = min_posi_scan - min_posi_model
    upper = max_posi_scan - max_posi_model

    # The upper bound is at least lower + 1, so a model as wide as (or wider than) the
    # scan still gets a small random jitter instead of making randint fail.
    x = random.randint(int(lower[0]), max(int(upper[0]), int(lower[0]) + 1))
    y = random.randint(int(lower[1]), max(int(upper[1]), int(lower[1]) + 1))
    # z is fixed so the model's lowest point rests at the scan's lowest point.
    z = lower[2].item()
    offset = torch.tensor([x, y, z])

    # Out-of-place, so the caller's tensor is left untouched.
    model = model + offset

    if mode == 'non_overlap':
        min_posi_eliminate, max_posi_eliminate = model.min(0).values, model.max(0).values
        inside = (scan[:, 0] >= min_posi_eliminate[0]) & \
                 (scan[:, 1] >= min_posi_eliminate[1]) & \
                 (scan[:, 0] <= max_posi_eliminate[0]) & \
                 (scan[:, 1] <= max_posi_eliminate[1])
        mask = mask.clone()
        mask[inside] = 0

    scan_placement = torch.cat((scan, model), dim=0)
    label = torch.cat((label, torch.ones(model.size()[0]) * p_label), dim=0)
    mask = torch.cat((mask, mask_model), dim=0)
    return scan_placement, label, mask
# NOTE: The triple-quoted string below is NOT executed; it is a bare string expression.
# It preserves a disabled earlier variant, `placement_zeroshot`. Instead of pasting
# models into scenes, that variant emitted each rotated positive/negative model as its
# own batch entry next to the raw scan. It returns five values, whereas the live
# `placement` returns six. Kept for reference only.
'''
def placement_zeroshot(modelnet_data, scannet_data, config, idx):
    #input: modelnet data (Bm*N, 4), scannet data (Bs*N, 4),
    #output: scannet data with placement (Bs*N, 4), labels (Bs*N,)
    #extended output: scannet data with placement (Bs*N, 4), labels (Bs*N,), placed model (2048, 4)

    coords_scannet = scannet_data['coords']
    labels_scannet = scannet_data['labels']

    model_list = config.model_list
    # chairs = modelnet_data['chair']
    # beds = modelnet_data['bed']
    # sofas = modelnet_data['sofa']
    # tables = modelnet_data['table']

    nega_models = modelnet_data['nega_data']

    scan_placements = []
    labels = []

    for i in range(config.batch_size_scannet):
        # get scan and model
        index = coords_scannet[:, 0] == i
        scan = coords_scannet[index][:, 1:]
        # -100 for unknown label
        label = torch.ones(scan.size()[0]) * -100
        scan = torch.cat((torch.ones(scan.size()[0], 1) * len(config.model_list) * 2 * config.batch_size_scannet + i, scan), dim=1)

        scan_placements.append(scan)
        labels.append(label)

        # print(len(config.model_list) * 2 * config.batch_size_scannet + i)

        # 'chair', 'desk', 'sofa', 'table'
        for ii, item in enumerate(config.model_list):
            posi_models = modelnet_data[item]

            # place posi_model
            model = posi_models[:, 1:]
            model = pc_utils.random_rotation(model)
            model = torch.cat((torch.ones(model.size()[0], 1) * i * len(config.model_list) * 2 + ii, model), dim=1)
            label = torch.ones(model.size()[0]) * ii
            scan_placements.append(model)
            labels.append(label)

            # print(i * len(config.model_list) * 2 + ii)

            # place negative model
            model = nega_models[:, 1:]
            model = pc_utils.random_rotation(model)
            model = torch.cat((torch.ones(model.size()[0], 1) * i * len(config.model_list) * 2 + ii + len(config.model_list), model), dim=1)
            label = torch.ones(model.size()[0]) * len(model_list)
            # print(i * len(config.model_list) * 2 + ii + len(config.model_list))
            # batched_coordinates
            scan_placements.append(model)
            labels.append(label)

    # import pdb
    # pdb.set_trace()
    # import pdb
    # pdb.set_trace()


    if len(scan_placements) == 0:
        return None, None, None, None, None

    scan_placements = torch.cat(scan_placements, dim=0)
    labels = torch.cat(labels, dim=0)

    feats = torch.ones(scan_placements.size()[0], 1)
    # import pdb
    # pdb.set_trace()

    coords_model = []
    feats_model = []

    return scan_placements, feats, labels, coords_model, feats_model

        # pc_utils.write_ply(model, str(i)+"model.ply", text=True)
        # pc_utils.write_ply(scan, str(i)+"scan.ply", text=True)
        # pc_utils.write_ply(scan_placement, str(i)+"scan_placement.ply", text=True)

        # scan = torch.cat((torch.ones(scan.size()[0], 1) * i, scan), dim=1)
        # model = torch.cat((torch.ones(model.size()[0], 1) * i, model), dim=1)
    # import pdb
    # pdb.set_trace()

    # pass
'''
def _place_random_model(model_o, scan, label, mask, p_label, non_overlap):
    """Randomly rotate ``model_o`` and drop it into ``scan`` via ``place_model_in_scan``.

    If ``non_overlap`` is true, the rotated model goes through
    ``pc_utils.random_partition``, which returns the points plus a per-point keep mask.
    It is then placed in ``'non_overlap'`` mode. Otherwise the whole rotated model is
    placed with an all-ones mask in ``'overlap'`` mode.
    """
    if non_overlap:
        model, mask_model = pc_utils.random_partition(pc_utils.random_rotation(model_o))
        mode = 'non_overlap'
    else:
        model, mask_model = pc_utils.random_rotation(model_o), torch.ones(model_o.size()[0])
        mode = 'overlap'
    return place_model_in_scan(model, scan, label, p_label, mask, mask_model, mode)


def placement(modelnet_data, scannet_data, config, idx):
    """Build a training batch by pasting ModelNet objects into ScanNet scenes.

    For every scene ``i`` in ``range(config.batch_size_scannet)``:

    1. The scene's points pass through ``pc_utils.random_rotation``. All scene points
       get label -100 (unknown) and mask 1.
    2. One model of every class in ``config.model_list`` plus one negative model
       (``modelnet_data['nega_data']``) are chosen via a fresh random permutation and
       placed in random order.
    3. With probability ~0.5 per scene, all models are partitioned and placed in
       ``'non_overlap'`` mode; otherwise they are placed whole in ``'overlap'`` mode.
    4. Points whose mask is not 1 are dropped.

    Args:
        modelnet_data: dict mapping each name in ``config.model_list`` and
            ``'nega_data'`` to a tensor. Column 0 is the model index within the batch;
            the remaining columns are xyz (assumed (Bm*N, 4)).
        scannet_data: dict whose ``'coords'`` has the scene index in column 0, then xyz.
        config: uses ``model_list``, ``model_list_zero_shot``, ``batch_size_scannet``
            and ``batch_size_modelnet``.
        idx: unused; kept for call-site compatibility.

    Returns:
        ``(coords, feats, labels, masks, coords_model, feats_model)``, where:

        * ``coords`` comes from ``ME.utils.batched_coordinates``.
        * ``feats`` is all ones, shape (P, 1).
        * ``labels`` holds class indices for pasted points and -100 for scene points.
        * ``coords_model`` and ``feats_model`` are empty lists.

        If no scene is produced, six ``None`` values are returned.
    """
    coords_scannet = scannet_data['coords']
    nega_models = modelnet_data['nega_data']
    # NOTE(review): positives are labelled by their index in ``config.model_list``, but
    # negatives use ``len(config.model_list_zero_shot)``. Confirm this is intended when
    # the two lists differ.
    nega_label = len(config.model_list_zero_shot)

    scan_placements, labels, masks = [], [], []
    for i in range(config.batch_size_scannet):
        scan_o = coords_scannet[coords_scannet[:, 0] == i][:, 1:]
        n_points = scan_o.size()[0]
        scan = pc_utils.random_rotation(scan_o)
        label = torch.ones(n_points) * -100  # -100 marks scene (unknown) points
        mask = torch.ones(n_points)

        # A fresh random permutation of model indices per category, for this scene.
        re_order = {}
        for item in config.model_list:
            re_order[item] = list(range(config.batch_size_modelnet))
            random.shuffle(re_order[item])
        re_order['nega'] = list(range(config.batch_size_modelnet))
        random.shuffle(re_order['nega'])

        all_list = config.model_list + ['nega']
        random.shuffle(all_list)

        non_overlap = random.random() >= 0.5
        for item in all_list:
            # Modulo keeps this valid when batch_size_scannet > batch_size_modelnet.
            model_id = re_order[item][i % config.batch_size_modelnet]
            if item == 'nega':
                source, p_label = nega_models, nega_label
            else:
                source, p_label = modelnet_data[item], config.model_list.index(item)
            model_o = source[source[:, 0] == model_id][:, 1:]
            if model_o.size()[0] == 0:
                continue
            scan, label, mask = _place_random_model(model_o, scan, label, mask, p_label, non_overlap)

        keep = mask == 1
        scan_placements.append(scan[keep])
        labels.append(label[keep])
        # NOTE(review): the unfiltered mask is stored, so ``masks`` is longer than
        # ``labels`` whenever points were dropped. Kept as-is for backward compatibility.
        masks.append(mask)

    if not scan_placements:
        return None, None, None, None, None, None

    coords = ME.utils.batched_coordinates(scan_placements, dtype=torch.float32)
    labels = torch.cat(labels, dim=0)
    masks = torch.cat(masks, dim=0)
    feats = torch.ones(coords.size()[0], 1)
    return coords, feats, labels, masks, [], []

# NOTE: The triple-quoted string below is NOT executed; it is a bare string expression.
# It preserves two disabled earlier implementations for reference:
#   * `placement_b`: calls `place_model_in_scan` with four arguments and unpacks two
#     results. That no longer matches the live seven-argument, three-result signature.
#   * `get_instances`: predecessor of `get_instances_adaver`. It used class indices as
#     instance labels and fed the raw `predictions` rows (not softmax scores) as features.
# Either one must be updated before it can be restored.
'''
def placement_b(modelnet_data, scannet_data, config, idx):
    #input: modelnet data (Bm*N, 4), scannet data (Bs*N, 4),
    #output: scannet data with placement (Bs*N, 4), labels (Bs*N,)
    #extended output: scannet data with placement (Bs*N, 4), labels (Bs*N,), placed model (2048, 4)

    coords_scannet = scannet_data['coords']
    labels_scannet = scannet_data['labels']

    model_list = config.model_list
    # chairs = modelnet_data['chair']
    # beds = modelnet_data['bed']
    # sofas = modelnet_data['sofa']
    # tables = modelnet_data['table']

    nega_models = modelnet_data['nega_data']

    scan_placements = []
    labels = []

    for i in range(config.batch_size_scannet):
        # get scan and model
        index = coords_scannet[:, 0] == i
        scan = coords_scannet[index][:, 1:]
        # -100 for unknown label
        label = torch.ones(scan.size()[0]) * -100

        # filter those scenes without 'chair'
        check_type = 'bookshelf'
        labels_tem = labels_scannet[index][:, 1].long()
        labels_trans = torch.ones(labels_tem.size()).long()
        labels_trans[labels_tem == scannet_utils.CLASS_LABELS.index(check_type)] = 0
        if labels_trans.size()[0] == labels_trans.sum():
            continue
        positive = scan[labels_tem == scannet_utils.CLASS_LABELS.index(check_type)]

        # 'chair', 'desk', 'sofa', 'table'
        posi_models = modelnet_data[model_list[i % len(model_list)]]


        # place posi_model model into scan1
        index = random.randint(0, config.batch_size_modelnet - 1)
        index = posi_models[:, 0] == index
        model = posi_models[index][:, 1:]
        model = pc_utils.random_rotation(model)
        scan, label = place_model_in_scan(model, scan, label, i % len(model_list))
        if i == model_list.index(check_type):
            pc_utils.write_ply(model, "%s_model_rotation.ply" % idx, text=True)
            pc_utils.write_ply(scan, "%s_scan_rotation.ply" % idx, text=True)
            pc_utils.write_ply(positive, "%s_positive_rotation.ply" % idx, text=True)

        # place negative model into scan
        index = random.randint(0, config.batch_size_modelnet - 1)
        index = nega_models[:, 0] == index
        model = nega_models[index][:, 1:]
        model = pc_utils.random_rotation(model)
        scan, label = place_model_in_scan(model, scan, label, len(model_list))

        # batched_coordinates
        scan = torch.cat((torch.ones(scan.size()[0], 1) * i, scan), dim=1)

        scan_placements.append(scan)
        labels.append(label)

    # import pdb
    # pdb.set_trace()

    if len(scan_placements) == 0:
        return None, None, None, None, None

    scan_placements = torch.cat(scan_placements, dim=0)
    labels = torch.cat(labels, dim=0)

    feats = torch.ones(scan_placements.size()[0], 1)
    # import pdb
    # pdb.set_trace()

    coords_model = []
    feats_model = []

    return scan_placements, feats, labels, coords_model, feats_model

        # pc_utils.write_ply(model, str(i)+"model.ply", text=True)
        # pc_utils.write_ply(scan, str(i)+"scan.ply", text=True)
        # pc_utils.write_ply(scan_placement, str(i)+"scan_placement.ply", text=True)

        # scan = torch.cat((torch.ones(scan.size()[0], 1) * i, scan), dim=1)
        # model = torch.cat((torch.ones(model.size()[0], 1) * i, model), dim=1)
    # import pdb
    # pdb.set_trace()
    # pass

def get_instances(predictions, labels_scannet, coords_scannet, ofield_scannet, scannet_data, config):
    # return instances_coords (N, 3+1)
    # instances_feats (N, 96)
    # instances_labels (N)

    instances_coords = []
    instances_feats = []
    instances_labels = []
    semantics_labels = []

    instances_model_coords = []
    instances_model_feats = []
    instances_model_labels = []
    semantics_model_labels = []

    # croping models from replaced scenes
    origin_index = coords_scannet[:, 0] < config.batch_size_scannet
    origin_feats = predictions[origin_index]
    # origin_feats = ofield_scannet[origin_index]
    origin_coords = coords_scannet[origin_index]
    origin_labels = labels_scannet[origin_index]
    scans_ids = list(np.unique(origin_coords[:, 0].numpy()))

    for id in scans_ids:
        index = origin_coords[:, 0] == id
        coords = origin_coords[index]
        feats = origin_feats[index]
        label = origin_labels[index]
        for label_id in range(len(config.model_list) + 1):
            index = label == label_id
            inst_feats = feats[index]
            inst_labels = label_id
            inst_coords = coords[index]
            seman_label = torch.ones(inst_coords.size()[0]) * label_id
            if inst_coords.size()[0] == 0: continue

            instances_model_coords.append(inst_coords[:, 1:])
            instances_model_feats.append(inst_feats)
            instances_model_labels.append(inst_labels)
            semantics_model_labels.append(seman_label)

    # croping instances from unlabeled scenes
    origin_index = coords_scannet[:, 0] >= config.batch_size_scannet
    origin_feats = predictions[origin_index]
    # origin_feats = ofield_scannet[origin_index]
    origin_coords = coords_scannet[origin_index]
    origin_labels = scannet_data['labels']

    # sample scans
    scans_ids = list(np.unique(origin_coords[:, 0].numpy()))

    for id in scans_ids:
        index = origin_coords[:, 0] == id
        coords = origin_coords[index]
        feats = origin_feats[index]
        label = origin_labels[index]

        # sample positive instances
        posi_index = label[:, 1] == scannet_utils.CLASS_LABELS.index(config.model_list[0])
        for i, item in enumerate(config.model_list):
            index = label[:, 1] == scannet_utils.CLASS_LABELS.index(item)
            posi_index = posi_index | index
            if index.sum() == 0: continue
            tem_feats = feats[index]
            tem_labels = label[index]
            tem_coords = coords[index]

            instances_ids = list(np.unique(tem_labels[:, 0].numpy()))
            for inst_id in instances_ids:
                index = tem_labels[:, 0] == inst_id
                inst_feats = tem_feats[index]
                inst_labels = i
                inst_coords = tem_coords[index]
                sema_label = torch.ones(inst_coords.size()[0]) * inst_labels
                # pc_utils.write_ply(instance_coords[:, 1:], "%s_%s.ply" % (item, inst_id), text=True)
                if inst_coords.size()[0] == 0: continue

                instances_coords.append(inst_coords[:, 1:])
                instances_feats.append(inst_feats)
                instances_labels.append(inst_labels)
                semantics_labels.append(sema_label)

        # sample positive instances
        nega_index = ~posi_index
        tem_feats = feats[nega_index]
        tem_labels = label[nega_index]
        tem_coords = coords[nega_index]

        instances_ids = list(np.unique(tem_labels[:, 0].numpy()))
        instances_ids = random.sample(instances_ids, min(5, len(instances_ids)))
        for inst_id in instances_ids:
            index = tem_labels[:, 0] == inst_id
            inst_feats = tem_feats[index]
            inst_labels = len(config.model_list)
            inst_coords = tem_coords[index]
            sema_label = torch.ones(inst_coords.size()[0]) * inst_labels
            # pc_utils.write_ply(instance_coords[:, 1:], "%s_%s.ply" % (item, inst_id), text=True)
            if inst_coords.size()[0] == 0: continue

            instances_coords.append(inst_coords[:, 1:])
            instances_feats.append(inst_feats)
            instances_labels.append(inst_labels)
            semantics_labels.append(sema_label)

    # random sampling
    sample_index = list(random.sample(range(len(instances_coords)), min(20, len(instances_coords))))
    instances_coords = [instances_coords[id] for id in sample_index]
    instances_feats = [instances_feats[id] for id in sample_index]
    instances_labels = [instances_labels[id] for id in sample_index]
    semantics_labels = [semantics_labels[id] for id in sample_index]

    # model_list = config.model_list + ['nega']
    # tem_coords = instances_coords + instances_model_coords
    # tem_labels = instances_labels + instances_model_labels
    # for i, coord in enumerate(tem_coords):
    #     print(tem_labels[i])
    #     pc_utils.write_ply(coord, "%s_%s_model.ply" %(i, model_list[tem_labels[i]]), text=True)

    # package
    coords = ME.utils.batched_coordinates(instances_coords + instances_model_coords, dtype=torch.float32)
    feats = torch.cat(instances_feats + instances_model_feats, dim=0)
    instances_labels = torch.tensor(instances_labels + instances_model_labels)
    semantics_labels = torch.cat(semantics_labels + semantics_model_labels, dim=0)

    return coords, feats, instances_labels, semantics_labels
'''

def get_instances_adaver(predictions, labels_scannet, coords_scannet, ofield_scannet, scannet_data, config):
    """Collect per-instance point sets and scores for the adversarial discriminator.

    Two sources are gathered.

    Pasted models come from rows of ``coords_scannet`` whose batch index is
    ``< config.batch_size_scannet``. For every label id in ``0..len(config.model_list)``
    present in such a scene, the matching points form one instance. Each instance has
    domain label 1 and a per-point feature of 1.

    Real scene instances come from rows whose batch index is
    ``>= config.batch_size_scannet``. They are grouped by ``scannet_data['labels'][:, 0]``
    (instance id) and ``[:, 1]`` (ScanNet class id):

    * Every instance of a class in ``config.model_list`` is kept. Its feature is the
      softmax score of that class.
    * Up to 5 random other instances per scene are kept. Their feature is the softmax
      score in column ``len(config.model_list)`` (the negative column).
    * Both kinds get domain label 0. At most 20 real instances are kept overall.

    Args:
        predictions: per-point logits, softmaxed over dim 1 for the real scenes.
        labels_scannet: per-point labels for the pasted scenes (compared to class ids).
        coords_scannet: per-point coordinates; column 0 is the batch index.
        ofield_scannet: unused; kept for call-site compatibility.
        scannet_data: dict whose ``'labels'`` is indexed with masks built over the
            real-scene rows of ``coords_scannet``.
        config: uses ``model_list``, ``batch_size_scannet`` and ``device``.

    Returns:
        ``(coords, feats, instances_labels, semantics_labels)``, where:

        * ``coords`` comes from ``ME.utils.batched_coordinates``, with real instances
          first and then pasted models.
        * ``feats`` is (P, 1).
        * ``instances_labels`` holds one domain label per instance.
        * ``semantics_labels`` holds one class index per point, with
          ``len(config.model_list)`` for negatives.
    """
    num_classes = len(config.model_list)

    instances_coords, instances_feats, instances_labels, semantics_labels = [], [], [], []
    instances_model_coords, instances_model_feats = [], []
    instances_model_labels, semantics_model_labels = [], []

    # Pasted models: scenes produced by placement (batch index < batch_size_scannet).
    placed = coords_scannet[:, 0] < config.batch_size_scannet
    placed_coords = coords_scannet[placed]
    placed_labels = labels_scannet[placed]
    for scene_id in np.unique(placed_coords[:, 0].numpy()):
        in_scene = placed_coords[:, 0] == scene_id
        scene_coords = placed_coords[in_scene]
        scene_labels = placed_labels[in_scene]
        for label_id in range(num_classes + 1):
            inst_coords = scene_coords[scene_labels == label_id]
            n_points = inst_coords.size()[0]
            if n_points == 0:
                continue
            instances_model_coords.append(inst_coords[:, 1:])
            instances_model_feats.append(torch.ones(n_points).to(config.device))
            instances_model_labels.append(1)
            semantics_model_labels.append(torch.ones(n_points) * label_id)

    # Real instances: the unlabeled scenes (batch index >= batch_size_scannet).
    real = coords_scannet[:, 0] >= config.batch_size_scannet
    real_probs = torch.softmax(predictions[real], dim=1)
    real_coords = coords_scannet[real]
    # NOTE(review): the masks below are built over coords_scannet[real], so
    # scannet_data['labels'] is assumed to be row-aligned with those rows only. Confirm.
    real_labels = scannet_data['labels']
    for scene_id in np.unique(real_coords[:, 0].numpy()):
        in_scene = real_coords[:, 0] == scene_id
        scene_coords = real_coords[in_scene]
        scene_probs = real_probs[in_scene]
        scene_labels = real_labels[in_scene]

        # Positive instances: every instance of every class in model_list.
        posi_index = torch.zeros_like(scene_labels[:, 1], dtype=torch.bool)
        for cls, item in enumerate(config.model_list):
            is_cls = scene_labels[:, 1] == scannet_utils.CLASS_LABELS.index(item)
            posi_index = posi_index | is_cls
            if is_cls.sum() == 0:
                continue
            cls_probs, cls_labels, cls_coords = scene_probs[is_cls], scene_labels[is_cls], scene_coords[is_cls]
            for inst_id in np.unique(cls_labels[:, 0].numpy()):
                in_inst = cls_labels[:, 0] == inst_id
                inst_coords = cls_coords[in_inst]
                n_points = inst_coords.size()[0]
                if n_points == 0:
                    continue
                instances_coords.append(inst_coords[:, 1:])
                instances_feats.append(cls_probs[in_inst][:, cls])
                instances_labels.append(0)
                # Semantic label = class index, matching the pasted-model convention.
                semantics_labels.append(torch.ones(n_points) * cls)

        # Negative instances: up to 5 random instances outside model_list, scored by
        # the extra "negative" output column.
        nega_index = ~posi_index
        neg_probs, neg_labels, neg_coords = scene_probs[nega_index], scene_labels[nega_index], scene_coords[nega_index]
        neg_ids = list(np.unique(neg_labels[:, 0].numpy()))
        for inst_id in random.sample(neg_ids, min(5, len(neg_ids))):
            in_inst = neg_labels[:, 0] == inst_id
            inst_coords = neg_coords[in_inst]
            n_points = inst_coords.size()[0]
            if n_points == 0:
                continue
            instances_coords.append(inst_coords[:, 1:])
            instances_feats.append(neg_probs[in_inst][:, num_classes])
            instances_labels.append(0)
            semantics_labels.append(torch.ones(n_points) * num_classes)

    # Keep at most 20 real instances.
    keep = random.sample(range(len(instances_coords)), min(20, len(instances_coords)))
    instances_coords = [instances_coords[k] for k in keep]
    instances_feats = [instances_feats[k] for k in keep]
    instances_labels = [instances_labels[k] for k in keep]
    semantics_labels = [semantics_labels[k] for k in keep]

    coords = ME.utils.batched_coordinates(instances_coords + instances_model_coords, dtype=torch.float32)
    feats = torch.cat(instances_feats + instances_model_feats, dim=0).unsqueeze(1)
    instances_labels = torch.tensor(instances_labels + instances_model_labels)
    semantics_labels = torch.cat(semantics_labels + semantics_model_labels, dim=0)
    return coords, feats, instances_labels, semantics_labels

def generative_adversarial(netD, optimizerD, real_models, fake_models, labels_scannet, epoch, config):
    """Run one discriminator update, then back-propagate the adversarial generator loss.

    Step (1) trains ``netD`` with binary cross-entropy so that ``real_models`` score 1
    and a detached copy of ``fake_models`` scores 0, then steps ``optimizerD``.
    Step (2) back-propagates BCE(netD(fake_models), 1). That pushes gradients into
    whatever produced ``fake_models`` and also into ``netD``'s parameters; the
    ``optimizerD.zero_grad()`` call at the start of the next invocation clears the
    latter. No generator optimizer is stepped here.

    Args:
        netD: discriminator. Its flattened output goes straight into
            ``F.binary_cross_entropy``, so presumably it ends in a sigmoid (confirm).
        optimizerD: optimizer that is zeroed and stepped for the discriminator update.
        real_models: features that should be classified as real.
        fake_models: features that should be classified as fake.
        labels_scannet: unused.
        epoch: statistics are printed whenever ``epoch % 5 == 0``.
        config: ``config.device`` is where the BCE targets are placed.
    """
    target_real, target_fake = 1, 0

    def _bce_against(scores, target):
        # Flatten the scores and compare them with a constant float32 target on config.device.
        flat = scores.view(-1)
        goal = torch.full((flat.size(0),), float(target), dtype=torch.float32, device=config.device)
        return F.binary_cross_entropy(flat, goal)

    optimizerD.zero_grad()

    # (1a) The real batch should be scored as real.
    real_scores = netD(real_models).view(-1)
    loss_real = _bce_against(real_scores, target_real)
    loss_real.backward(retain_graph=True)
    d_real_mean = real_scores.mean().item()

    # (1b) The fake batch should be scored as fake. It is detached so no gradient reaches its producer.
    fake_scores = netD(fake_models.detach()).view(-1)
    loss_fake = _bce_against(fake_scores, target_fake)
    loss_fake.backward(retain_graph=True)  # accumulates onto the real-batch gradients
    d_fake_mean_before = fake_scores.mean().item()

    loss_d = loss_real + loss_fake
    optimizerD.step()

    # (2) Generator objective: re-score the (non-detached) fakes with the updated D
    # and ask for "real" labels.
    gen_scores = netD(fake_models).view(-1)
    loss_g = _bce_against(gen_scores, target_real)
    loss_g.backward(retain_graph=True)
    d_fake_mean_after = gen_scores.mean().item()

    torch.cuda.empty_cache()
    if epoch % 5 == 0:
        print('[%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
              % (epoch, loss_d.item(), loss_g.item(), d_real_mean, d_fake_mean_before, d_fake_mean_after))

def train(dataloaders, models, optimizer, optimizer_G, criterion, criterion_generator, config):
    """Train ``model_scannet`` on ScanNet scenes with ModelNet objects placed into them.

    ``val_iou`` runs once before training and again after every epoch.

    Args:
        dataloaders: ``(scannet_dataloader, modelnet_dataloader)``. ``scannet_dataloader``
            is indexed with ``'train'`` (and ``'val'`` by ``val_iou``).
        models: tuple whose second element is ``model_scannet``, the only model trained here.
        optimizer: optimizer zeroed/stepped around ``model_scannet``'s loss.
        optimizer_G, criterion_generator: unused; kept for interface compatibility.
        criterion: forwarded to ``val_iou``.
        config: reads ``word2vec``, ``model_list_zero_shot``, ``epochs``,
            ``batch_size_scannet``, ``world_size`` and ``device``.
    """
    val_iou(dataloaders, models, optimizer, criterion, config)
    model_scannet = models[1]
    scannet_dataloader, modelnet_dataloader = dataloaders
    modelnet_iter = iter(modelnet_dataloader)

    # One embedding row per class, in config.model_list_zero_shot order.
    word2vector = np.load(config.word2vec, allow_pickle=True).item()
    word_embeddings = torch.cat(
        [torch.from_numpy(word2vector[item]).to(config.device).unsqueeze(0)
         for item in config.model_list_zero_shot],
        dim=0,
    )

    # Scenes with at least this many points are truncated to their last `upper_bound`
    # points (presumably to bound memory; confirm).
    upper_bound = 1200000
    samples_per_step = config.batch_size_scannet * config.world_size

    for epoch in range(config.epochs):
        # val_iou() switches the model to eval mode. Re-enable training mode every
        # epoch; otherwise every epoch after the first would train in eval mode.
        model_scannet.train()
        start = time.time()
        train_loss = 0
        num_steps = 0
        pbar = tqdm(total=len(scannet_dataloader['train']) * samples_per_step)
        for i, scannet_data in enumerate(scannet_dataloader['train']):
            optimizer.zero_grad()
            # Cycle through the ModelNet loader, restarting it when it is exhausted.
            # Only StopIteration is caught, so real data-loading errors still surface.
            try:
                modelnet_data = next(modelnet_iter)
            except StopIteration:
                modelnet_iter = iter(modelnet_dataloader)
                modelnet_data = next(modelnet_iter)

            coords_scannet, feats_scannet, labels_scannet, _masks, _coords_model, _feats_model = placement(
                modelnet_data, scannet_data, config, i)
            if coords_scannet is None:
                continue

            if coords_scannet.size()[0] >= upper_bound:
                coords_scannet = coords_scannet[-upper_bound:]
                feats_scannet = feats_scannet[-upper_bound:]
                labels_scannet = labels_scannet[-upper_bound:]

            # labels: instance_id, label_id
            in_field_scannet = get_in_field(coords_scannet, feats_scannet, config)
            labels = labels_scannet.long().to(config.device)
            _, loss = model_scannet(coords_scannet, in_field_scannet, labels, word_embeddings, 'train')

            train_loss += loss.item()
            num_steps += 1
            loss.backward()
            optimizer.step()
            pbar.update(samples_per_step)
            torch.cuda.empty_cache()

        pbar.close()
        # Average over the training steps actually run (skipped batches are excluded).
        print('epoch： %s Train loss: %s time: %s' % (epoch,
              train_loss / max(num_steps, 1), time.time() - start))
        val_iou(dataloaders, models, optimizer, criterion, config)
        torch.cuda.empty_cache()

# Best per-class and average AP seen so far; updated in place by val() on rank 0.
best_accuracy = {'average': 0, 'chair': 0, 'sofa': 0, 'table': 0, 'bed': 0, 'bookshelf': 0, 'desk': 0}


def val(dataloaders, models, optimizer, criterion, config):
    """Evaluate per-class average precision of ``model_scannet`` on the ScanNet val split.

    For each scan, and for each class in ``config.model_list_zero_shot`` present in
    that scan, the softmax score of the class's output channel is scored against the
    binary ground-truth mask with ``evaluate.average_precision``. APs are summed
    per class, reduced across processes with ``dist.all_reduce``, averaged over the
    scans that contained the class, and printed on ``config.device == 0``. The
    module-level ``best_accuracy`` dict is updated at the same time.

    Args:
        dataloaders: ``(scannet_dataloader, modelnet_dataloader)``. Only
            ``scannet_dataloader['val']`` is read.
        models: tuple whose second element is ``model_scannet``. Both 3-tuples and
            the 4-tuples used by ``train``/``val_iou`` are accepted.
        optimizer, criterion: unused; kept for interface compatibility.
        config: reads ``device``, ``world_size``, ``word2vec`` and ``model_list_zero_shot``.
    """
    model_scannet = models[1]
    scannet_dataloader = dataloaders[0]

    # Remember the caller's mode so it can be restored after evaluation.
    was_training = model_scannet.training
    model_scannet.eval()

    start = time.time()
    val_loss = 0  # no validation loss is computed; it is reported as 0
    pbar = tqdm(total=len(scannet_dataloader['val']) * config.world_size)

    class_names = config.model_list_zero_shot
    APs = torch.zeros(len(class_names)).float().to(config.device)
    counts = torch.zeros(len(class_names)).float().to(config.device)

    word2vector = np.load(config.word2vec, allow_pickle=True).item()
    word_embeddings = torch.cat(
        [torch.from_numpy(word2vector[item]).to(config.device).unsqueeze(0) for item in class_names],
        dim=0,
    )

    with torch.no_grad():
        for scannet_data in scannet_dataloader['val']:
            coords_scannet, feats_scannet, labels_scannet = (
                scannet_data['coords'], scannet_data['feats'], scannet_data['labels'])
            # labels: instance_id, label_id
            in_field_scannet = get_in_field(coords_scannet, feats_scannet, config)
            labels = labels_scannet[:, 1].long().to(config.device)

            predictions, _ = model_scannet(coords_scannet, in_field_scannet, labels, word_embeddings, 'val')
            predictions = F.softmax(predictions, dim=1)

            for ii, item in enumerate(class_names):
                index = labels == scannet_utils.CLASS_LABELS.index(item)
                if index.sum() == 0:
                    continue
                # Build the binary target from the mask itself so it is on the same
                # device. Indexing a CPU tensor with a device mask raises on GPU.
                labels_trans = index.long()
                AP = evaluate.average_precision(predictions[:, ii].unsqueeze(0).cpu().numpy(),
                                                labels_trans.unsqueeze(0).cpu().numpy())
                APs[ii] += AP[0]
                counts[ii] += 1

            torch.cuda.empty_cache()
            pbar.update(config.world_size)

    model_scannet.train(was_training)
    pbar.close()
    torch.cuda.empty_cache()

    # Every rank must take part in the reductions before non-zero ranks return.
    dist.all_reduce(APs)
    dist.all_reduce(counts)
    counts[counts == 0] = 1  # classes never present keep AP 0 instead of NaN
    APs = (APs / counts).detach().cpu().numpy()
    mAP = APs.mean()

    if config.device != 0:
        return

    if mAP > best_accuracy['average']:
        best_accuracy['average'] = mAP

    num_scans = max(len(scannet_dataloader['val']) * config.world_size, 1)
    print('val loss: %s average accuracy: %s Best accuracy: %s time: %s' % (val_loss / num_scans,
          mAP, best_accuracy['average'], time.time() - start))
    for ii, item in enumerate(class_names):
        best_accuracy[item] = max(best_accuracy.get(item, 0), APs[ii])
        print('type: %s, mAP: %s, best: %s' % (item, APs[ii], best_accuracy[item]))

# Best mean IoU over unseen / seen classes so far; updated by val_iou() on rank 0.
best_iou_unseen = 0
best_iou_seen = 0


def val_iou(dataloaders, models, optimizer, criterion, config):
    """Evaluate per-class IoU of ``model_scannet`` on the ScanNet val split.

    Predictions and ground truth are collected for all points with a non-negative
    label. Ground truth is remapped from ``scannet_utils.CLASS_LABELS`` indices to
    ``config.model_list_zero_shot`` indices. Per-class statistics come from
    ``evaluate.iou`` and are summed across processes with ``dist.all_reduce``. On
    ``config.device == 0``, per-class IoUs are printed for seen classes (those in
    ``config.model_list``) and for unseen classes, each group with its mean, and the
    module-level ``best_iou_seen`` / ``best_iou_unseen`` are updated.

    Args:
        dataloaders: ``(scannet_dataloader, modelnet_dataloader)``. Only
            ``scannet_dataloader['val']`` is read.
        models: tuple whose second element is ``model_scannet``.
        optimizer, criterion: unused; kept for interface compatibility.
        config: reads ``device``, ``world_size``, ``word2vec``,
            ``model_list_zero_shot`` and ``model_list``.
    """
    global best_iou_seen, best_iou_unseen

    model_scannet = models[1]
    scannet_dataloader = dataloaders[0]

    # Remember the caller's mode so it can be restored after evaluation.
    was_training = model_scannet.training
    model_scannet.eval()

    class_names = config.model_list_zero_shot
    word2vector = np.load(config.word2vec, allow_pickle=True).item()
    word_embeddings = torch.cat(
        [torch.from_numpy(word2vector[item]).to(config.device).unsqueeze(0) for item in class_names],
        dim=0,
    )

    pbar = tqdm(total=len(scannet_dataloader['val']) * config.world_size)
    preds = []
    gts = []

    with torch.no_grad():
        for scannet_data in scannet_dataloader['val']:
            coords_scannet, feats_scannet, labels_scannet = (
                scannet_data['coords'], scannet_data['feats'], scannet_data['labels'])
            # labels: instance_id, label_id
            in_field_scannet = get_in_field(coords_scannet, feats_scannet, config)
            labels = labels_scannet[:, 1].long().to(config.device)

            predictions, _ = model_scannet(coords_scannet, in_field_scannet, labels, word_embeddings, 'val')

            # For points whose top-scoring channel is 20, channels 0-10 are pushed to a
            # large negative value. Only channels 0-19 are then kept for the argmax.
            # NOTE(review): 20 and 11 look like hard-coded class counts (all / seen);
            # confirm before changing the class lists.
            index = torch.argmax(predictions, dim=1) == 20
            predictions[index, :11] = -100000
            predictions = predictions[:, :20]

            mask = labels >= 0  # negative labels are excluded from scoring
            pred = torch.argmax(predictions, dim=1)
            preds.append(pred[mask].view(-1).cpu())
            gts.append(labels[mask].view(-1).cpu())
            pbar.update(config.world_size)

    model_scannet.train(was_training)

    gts = torch.cat(gts)
    # Remap ground truth from CLASS_LABELS indices to model_list_zero_shot indices
    # (presumably the channel order of the predictions, matching word_embeddings).
    gts_remapping = gts.clone()
    for label_id, item in enumerate(scannet_utils.CLASS_LABELS):
        gts_remapping[gts == label_id] = class_names.index(item)

    # Presumably rows are classes and columns (intersection, union); evaluate.iou is
    # defined elsewhere, so confirm.
    class_ious = evaluate.iou(torch.cat(preds), gts_remapping).to(config.device)
    dist.all_reduce(class_ious)
    class_ious = class_ious.cpu().numpy()
    pbar.close()

    if config.device != 0:
        return

    seen_ids = [i for i, item in enumerate(class_names) if item in config.model_list]
    unseen_ids = [i for i, item in enumerate(class_names) if item not in config.model_list]

    def _report(class_ids):
        # Print one IoU row per class in class_ids and return their mean (0 if empty).
        mean_iou = 0
        for i in class_ids:
            iou = class_ious[i, 0] / class_ious[i, 1]
            mean_iou += iou / len(class_ids)
            print('{0:<14s}: {1:>5.3f}   ({2:>6f}/{3:<6f})'.format(
                class_names[i], iou, class_ious[i, 0], class_ious[i, 1]))
        return mean_iou

    print('classes          IoU')
    print('----------------------------')

    # The denominators are the number of classes actually averaged in each group.
    mean_iou = _report(seen_ids)
    print('mean IOU of seen classes', mean_iou)
    best_iou_seen = max(best_iou_seen, mean_iou)
    print(best_iou_seen)

    mean_iou = _report(unseen_ids)
    print('mean IOU', mean_iou)
    best_iou_unseen = max(best_iou_unseen, mean_iou)
    print(best_iou_unseen)
