import os
from absl import app, flags
from matplotlib import cm, colors
from tqdm import tqdm
import seaborn as sns
import numpy as np

# Default to physical GPU 1, but respect a CUDA_VISIBLE_DEVICES value the caller
# already exported. This must happen before CUDA is initialised (torch is
# imported below this point).
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "1")
print("CUDA visible device ID(s):", os.environ["CUDA_VISIBLE_DEVICES"])

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from data import AffordanceDataset
import model as model_utils
from util import ari, idx2afford, onehot2afford, visualize_voxels, visualize_voxels_cuboids, merge_test_image, hungarian_loss_meanIoU

# Root passed to AffordanceDataset in main(). With the empty default, every path
# below is relative to the current working directory.
workspace_dir = ""
workdir = os.path.join(workspace_dir, "sa/torch")

FLAGS = flags.FLAGS

# --- checkpoint selection -------------------------------------------------
# main() loads epoch_<epoch>.pth from model_dir and writes its evaluation
# images under model_dir/eval_<epoch>/.
flags.DEFINE_string(
    name="model_dir",
    default=os.path.join(workdir, "./"),
    help="Where to save the checkpoints.",
)
flags.DEFINE_integer(name="seed", default=2333, help="Random seed (prime preferred).")
flags.DEFINE_integer(name="epoch", default=3000, help="Which epoch to be evaluated.")
flags.DEFINE_integer(
    name="batch_size",
    default=1,
    help="Batch size for the model on a single GPU.",
)

# --- model architecture (must match the checkpoint) -----------------------
flags.DEFINE_bool(
    name="with_cuboid",
    default=True,
    help="If using the cuboid prediction branch.",
)
flags.DEFINE_bool(
    name="with_afford",
    default=True,
    help="If using the affordance prediction branch.",
)
flags.DEFINE_integer(name="num_slots", default=4, help="Number of slots in Slot Attention.")
flags.DEFINE_integer(name="num_iterations", default=3, help="Number of attention iterations.")

def _binarize(geo, device):
    """Return a copy of the voxel grid `geo` on `device` with every positive value set to 1.

    `copy=True` guarantees a new tensor. A plain `geo.to(device)` returns `geo`
    itself when it already lives on `device` (always the case on a CPU-only
    run). Binarizing that result in place would also overwrite the caller's
    labelled grid, which `main` still needs for the ground-truth part masks,
    the affordance accuracy and the saved `.npz` files.
    """
    shape = geo.to(device, copy=True)
    shape[shape > 0] = 1
    return shape


def _binary_masks(labels, values, num_masks):
    """Return one 0/1 mask of `labels` per entry of `values`, zero-padded to `num_masks`.

    Each mask has the shape, dtype and device of `labels`. If `values` has more
    than `num_masks` entries, no padding is added and the list is longer than
    `num_masks`.
    """
    masks = [(labels == value).to(labels.dtype) for value in values]
    masks.extend(torch.zeros_like(labels) for _ in range(num_masks - len(masks)))
    return masks


def _cuboid_vertices(center, scale, rotate):
    """Return the 8 corners of every predicted cuboid as a [num_slots, 8, 3] tensor.

    Args:
        center: [num_slots, 3] translation added to every corner.
        scale: [num_slots, 3] per-axis factors applied to the unit cube's +/-1
            corners, i.e. the cuboid half-extents, before rotation.
        rotate: [num_slots, 3, 3] matrices applied to each scaled corner
            (corner' = rotate @ corner).
    """
    unit_cube = torch.tensor(
        [[-1, -1, -1], [-1, -1, 1], [-1, 1, 1], [-1, 1, -1],
         [1, -1, -1], [1, -1, 1], [1, 1, 1], [1, 1, -1]],
        dtype=torch.float32, device=scale.device)
    corners = unit_cube.unsqueeze(0) * scale.unsqueeze(1)       # [num_slots, 8, 3]
    corners = torch.einsum('sij,svj->svi', rotate, corners)     # rotate every corner
    return corners + center.unsqueeze(1)                        # then translate


def _slot_colors(num_colors):
    """Return `num_colors` hex colour strings cycling through the even entries of tab20."""
    import matplotlib
    try:
        cmap = matplotlib.colormaps['tab20']
    except AttributeError:
        # matplotlib < 3.5 has no colormap registry (cm.get_cmap is gone in 3.9).
        cmap = cm.get_cmap('tab20')
    return [colors.rgb2hex(cmap(2 * k % 20)) for k in range(num_colors)]


def _save_histogram(values, min_value, out_file):
    """Save a seaborn histogram of the entries of `values` greater than `min_value`."""
    fig = sns.histplot(values[values > min_value].flatten()).get_figure()
    fig.savefig(out_file)
    fig.clf()  # histplot draws on the current axes; clear them for the next plot


def main(argv):
    """Evaluate a SlotCuboidVox checkpoint on the test split and render examples.

    The dataset is split 70% / 10% / remainder into train / val / test with a
    generator seeded by --seed. Only the test part is used. The script loads
    ``epoch_<epoch>.pth`` from --model_dir and prints three metrics:
    - the mean Hungarian-matched IoU between ground-truth parts and slots;
    - the voxel MSE;
    - with --with_afford, the fraction of shapes whose predicted affordance set
      equals the ground-truth set.
    It then writes visualisations of the first 20 test shapes to
    ``<model_dir>/eval_<epoch>/<index>/``.
    """
    import itertools

    del argv  # absl passes leftover command-line arguments; none are used
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f'Using device: {device}')
    print("Let's use", torch.cuda.device_count(), "GPU(s)!")

    # hyperparameters of the model
    model_dir = FLAGS.model_dir
    epoch = FLAGS.epoch
    # device_count() is 0 without CUDA; keep the batch size at least FLAGS.batch_size.
    # NOTE: the metric loop below assumes a batch size of one.
    batch_size = FLAGS.batch_size * max(1, torch.cuda.device_count())
    with_cuboid = FLAGS.with_cuboid
    with_afford = FLAGS.with_afford
    num_slots = FLAGS.num_slots
    num_iterations = FLAGS.num_iterations
    torch.manual_seed(FLAGS.seed)
    resolution = (32, 32, 32)
    hist_thld = .1       # only probabilities above this enter the histograms
    combined_thld = .5   # occupancy threshold for the combined reconstruction
    slotwise_thld = .5   # occupancy threshold for each slot's own reconstruction

    # load the affordance dataset and keep only the seeded test split
    dataset = AffordanceDataset(workspace_dir, data_type='voxel', num_slots=num_slots, cross="_sittable")
    train_size = int(len(dataset) * .7)
    val_size = int(len(dataset) * .1)
    test_size = len(dataset) - train_size - val_size
    print("Number of test shapes:", test_size)
    _, _, test_set = torch.utils.data.random_split(
        dataset, [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(FLAGS.seed))
    test_dataloader = DataLoader(test_set, batch_size, shuffle=False)
    vis_dataloader = DataLoader(test_set, 1, shuffle=False)
    print('Dataset loaded.')

    # load model
    model = model_utils.SlotCuboidVox(resolution, num_slots, num_iterations, with_cuboid=with_cuboid, with_afford=with_afford)
    model = nn.DataParallel(model).to(device)
    # map_location lets a checkpoint saved from a GPU run be evaluated on CPU.
    checkpoint_path = os.path.join(model_dir, 'epoch_%d.pth' % epoch)
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    # Put any layers whose behaviour depends on training mode into inference mode.
    model.eval()

    # metric calculation
    with torch.no_grad():
        mse = nn.MSELoss()
        mses = []
        mean_ious = []
        accs = []
        for test_batch in tqdm(test_dataloader):
            geo = test_batch['geo']         # labelled voxel grid, left untouched on the CPU
            shape = _binarize(geo, device)  # occupancy grid fed to the model
            pred_dict = model(shape)
            recon_combined = pred_dict['recon_combined']
            recons = pred_dict['recons']
            masks = pred_dict['masks']
            probs = torch.sigmoid(recon_combined)

            # Assign each voxel to the slot with the largest masked reconstruction.
            # Voxels the combined reconstruction leaves empty get -1 (no slot).
            slot_contrib_idx = torch.argmax(recons * masks, dim=1)[..., 0]
            slot_contrib_idx[probs < combined_thld] = -1
            mses.append(mse(probs, shape).item())

            # Ground-truth parts: one mask per label value except the smallest
            # (torch.unique sorts; presumably the empty label 0 — confirm).
            # batch size must be ONE here.
            unique_afford = torch.unique(geo)[1:]
            gt_onehots = _binary_masks(geo, unique_afford, num_slots)

            if with_afford:
                one_hots = pred_dict['one_hots']
                # 1 if the non-zero per-slot argmax affordances equal the set of
                # non-zero labels present in the ground truth, else 0.
                gt_labels = torch.unique(geo)
                gt_labels = gt_labels[gt_labels.nonzero()].int()
                pred_labels = torch.unique(torch.argmax(one_hots, dim=-1)[0]).cpu()
                pred_labels = pred_labels[pred_labels.nonzero()].int()
                accs.append(1 if torch.equal(gt_labels, pred_labels) else 0)

            recon_onehots = _binary_masks(slot_contrib_idx, range(num_slots), num_slots)
            mean_iou = hungarian_loss_meanIoU(
                torch.stack(gt_onehots, dim=1).view(1, num_slots, -1).to(device),
                torch.stack(recon_onehots, dim=1).view(1, num_slots, -1).to(device))
            mean_ious.append(mean_iou.item())
        print("meanIoU: %.4f" % np.mean(mean_ious))
        print('mse', np.mean(mses))
        if with_afford:
            print('ap', np.mean(accs))

    vis = True
    if not vis:
        return
    # example visualization of test examples [start_idx, end_idx)
    start_idx = 0
    end_idx = 20
    # At least 10 colours: recon_both.png always received 10.
    color_map = _slot_colors(max(10, num_slots))
    recon_bce = nn.BCEWithLogitsLoss()
    with torch.no_grad():
        # islice stops after end_idx instead of loading the whole test set first.
        for batch_idx, test_batch in itertools.islice(enumerate(vis_dataloader), start_idx, end_idx):
            print('\nTest example %d with anno_id: %s' % (batch_idx, test_batch['anno_id'][0]))
            geo = test_batch['geo']
            shape = _binarize(geo, device)
            pred_dict = model(shape)
            recon_combined = pred_dict['recon_combined']
            recons = pred_dict['recons']
            masks = pred_dict['masks']
            one_hots = pred_dict['one_hots'] if with_afford else None

            # reconstruction loss of the combined shape
            print("vox_recon_loss: %.3f" % recon_bce(recon_combined, shape).item())

            # visualize the ground truth and the reconstructed shape
            eval_dir = os.path.join(model_dir, "eval_%d/%04d" % (epoch, batch_idx))
            os.makedirs(eval_dir, exist_ok=True)
            visualize_voxels(shape[0].cpu().numpy(), out_file=os.path.join(eval_dir, "gt.png"))
            vox_truncated = torch.sigmoid(recon_combined)[0].cpu().numpy()
            _save_histogram(vox_truncated, hist_thld, os.path.join(eval_dir, "prob_hist.png"))
            vox_truncated[vox_truncated < combined_thld] = 0
            visualize_voxels(vox_truncated, out_file=os.path.join(eval_dir, "recon.png"))

            # Per-voxel winning slot, [32, 32, 32]; -1 where the reconstruction is empty.
            slot_contrib_idx = torch.argmax(recons * masks, dim=1)[0, :, :, :, 0].cpu().numpy()
            slot_contrib_idx[vox_truncated < combined_thld] = -1

            if with_afford:
                print("affordance:", one_hots[0])
                print(test_batch['affordance'][0])
            afford_ids = np.argmax(one_hots[0].cpu().numpy(), axis=1) if with_afford else None

            if not with_cuboid:
                # voxels for each slot (recon-only model)
                for i in range(num_slots):
                    slot_voxels = (slot_contrib_idx == i).astype(slot_contrib_idx.dtype)
                    recon_i = torch.sigmoid(recons[0][i] * masks[0][i]).squeeze().cpu().numpy()
                    _save_histogram(recon_i, hist_thld, os.path.join(eval_dir, "recon_%d_hist.png" % i))
                    recon_i[recon_i < slotwise_thld] = 0
                    afford_label = onehot2afford(one_hots[0][i].cpu().numpy()) if with_afford else None
                    visualize_voxels(slot_voxels, afford_label=afford_label,
                                     out_file=os.path.join(eval_dir, "recon_%d.png" % i))
                    visualize_voxels(recon_i, afford_label=afford_label,
                                     out_file=os.path.join(eval_dir, "recon_%d_slotwise.png" % i))
                merge_test_image(eval_dir, num_slots=num_slots)
                np.savez(os.path.join(eval_dir, "%04d.npz" % batch_idx),
                         gt=geo[0].numpy(),
                         recon=slot_contrib_idx,
                         afford=afford_ids)
                continue

            # cuboid vertices of the first (only) batch element
            center = pred_dict['center']
            scale = pred_dict['scale']
            rotate = pred_dict['rotate']
            print('center', center)
            print('scale', scale)
            cuboid_vertices = _cuboid_vertices(center[0], scale[0], rotate[0]).cpu().numpy()

            # all cuboids over the full reconstruction
            xyz_range = visualize_voxels_cuboids(vox_truncated, cuboid_vertices, cuboid_color=color_map,
                                                 out_file=os.path.join(eval_dir, "recon_both.png"))

            # cuboid and voxels for each slot, reusing the axis range of the full plot
            for i in range(num_slots):
                slot_voxels = (slot_contrib_idx == i).astype(slot_contrib_idx.dtype)
                recon_i = torch.sigmoid(recons[0][i] * masks[0][i]).squeeze().cpu().numpy()
                recon_i[recon_i < slotwise_thld] = 0
                afford_label = onehot2afford(one_hots[0][i].cpu().numpy()) if with_afford else None
                visualize_voxels(recon_i, afford_label=afford_label,
                                 out_file=os.path.join(eval_dir, "recon_%d_slotwise.png" % i))
                xyz_range = visualize_voxels_cuboids(slot_voxels, cuboid_vertices[i][np.newaxis], afford_label=afford_label,
                                                     cuboid_color=[color_map[i]], xyz_range=xyz_range,
                                                     out_file=os.path.join(eval_dir, "recon_%d.png" % i))

            np.savez(os.path.join(eval_dir, "%04d.npz" % batch_idx),
                     gt=geo[0].numpy(),
                     recon=slot_contrib_idx,
                     afford=afford_ids,
                     cuboid=cuboid_vertices)
            merge_test_image(eval_dir, num_slots=num_slots)

if __name__ == "__main__":
    # absl parses the command line into FLAGS, then invokes main(argv).
    app.run(main=main)
