import os, time, datetime
import numpy as np
import seaborn as sns
from absl import app, flags
from itertools import cycle

# Select (and order) the GPUs this process may see. The variable is set before
# `import torch` below -- presumably so that it is already in place when CUDA
# is first initialized; keep this ordering when editing the imports.
os.environ["CUDA_VISIBLE_DEVICES"] = "3,2,1,0"
print("CUDA visible device ID(s):", os.environ["CUDA_VISIBLE_DEVICES"])

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from data import AffordanceDataset
import model as model_utils
from util import hungarian_loss, optimizer_to_device, visualize_voxels

# Workspace root. With the empty default, every path derived from it is
# relative to the current working directory.
workspace_dir = ""
workdir = os.path.join(workspace_dir, "sa/torch")

# Command-line flags; main() reads their parsed values through FLAGS.
FLAGS = flags.FLAGS

# --- checkpoint locations ---
flags.DEFINE_string(
    name="model_dir",
    default=os.path.join(workdir, "tmp/"),
    help="Where to save the checkpoints.",
)
flags.DEFINE_string(
    name="ckpt_pth",
    default="",
    help="Where to restore checkpoints.",
)

# --- reproducibility and batching ---
flags.DEFINE_integer(
    name="seed",
    default=2333,
    help="Random seed (prime preferred).",
)
flags.DEFINE_integer(
    name="batch_size",
    default=128,
    help="Batch size for the model on a single GPU.",
)

# --- optional prediction branches ---
flags.DEFINE_bool(
    name="with_cuboid",
    default=False,
    help="If using the cuboid prediction branch.",
)
flags.DEFINE_bool(
    name="with_afford",
    default=True,
    help="If using the affordance prediction branch.",
)

# --- Slot Attention settings ---
flags.DEFINE_integer(
    name="num_slots",
    default=3,
    help="Number of slots in Slot Attention.",
)
flags.DEFINE_integer(
    name="num_iterations",
    default=3,
    help="Number of attention iterations.",
)

# --- optimization schedule ---
flags.DEFINE_float(
    name="learning_rate",
    default=0.0004 * 0.05,
    help="Learning rate.",
)
flags.DEFINE_integer(
    name="num_train_steps",
    default=500000,
    help="Number of training steps.",
)

def forward_step(step_idx, batch, model, device, with_cuboid=True, with_afford=True):
    """Run the model on one batch and compute every loss term.

    Args:
        step_idx: Index of the current step. Not used by the computation; kept
            so the signature stays compatible with existing callers.
        batch: Mapping holding a 'geo' tensor and, when ``with_afford`` is
            set, an 'affordance' tensor. The tensors in ``batch`` are left
            unmodified.
        model: Callable mapping the binarized geometry to a dict of
            predictions ('recon_combined' always; 'one_hots' for the
            affordance term; 'recons', 'masks', 'mask_surface', 'scale',
            'vox_rotated' and 'vox_proj' for the cuboid terms).
        device: Device on which the inputs and the loss terms are placed.
        with_cuboid: If True, add the cuboid-fitting and scale-regularization
            terms to the total loss.
        with_afford: If True, add the affordance term (``hungarian_loss``
            scaled by 0.5) to the total loss.

    Returns:
        Tuple ``(total_loss, vox_recon_loss, cuboid_loss, scale_reg_loss,
        affordance_loss)``. Every disabled term is a one-element zero tensor
        on ``device``.
    """
    # [batchsize, 32, 32, 32] (shape as noted by the original author)
    gt_geo = batch['geo'].to(device)
    # Binarize out of place: `.to(device)` hands back the very same tensor
    # when it already lives on `device`, so an in-place write would silently
    # alter the caller's batch. Non-positive entries are kept as they are.
    gt_geo = torch.where(gt_geo > 0, torch.ones_like(gt_geo), gt_geo)
    pred_dict = model(gt_geo)

    # bce(output, target) in torch
    bce = nn.BCEWithLogitsLoss()
    vox_recon_loss = bce(pred_dict['recon_combined'], gt_geo)

    zero_loss = torch.zeros(1, device=device)
    if with_afford:
        affordance_loss = hungarian_loss(
            pred_dict['one_hots'], batch['affordance'].to(device)) * .5
    else:
        affordance_loss = zero_loss

    if not with_cuboid:
        total_loss = vox_recon_loss + affordance_loss
        return total_loss, vox_recon_loss, zero_loss, zero_loss, affordance_loss

    recons = pred_dict['recons']
    masks = pred_dict['masks']
    mask_surface = pred_dict['mask_surface']
    scale = pred_dict['scale']      # [batch_size, num_slots, 3]
    vox_rotated = pred_dict['vox_rotated']
    vox_proj = pred_dict['vox_proj']

    # vox_proj has shape: [B: batch_size, N: num_cuboids, n: num_points 32x32x32, 6 faces, 3 xyz]
    # calculate the distance between rotated voxel and projected voxel (on the cuboid face w/ min dist)
    B, N, n, _, _ = vox_proj.shape
    # squared distance from each point/voxel in a cuboid to the six faces of that cuboid. [B, N, n, 6]
    dist = torch.sum((vox_rotated - vox_proj) ** 2, dim=-1)
    # for each point/voxel, keep the distance to its closest face. [B, N, n]
    min_dist, _ = torch.min(dist, dim=-1)

    # Per-voxel weight: sigmoid of (mask * reconstruction), gated by the
    # surface mask. The surface mask is detached so that no gradient flows
    # through it.
    weight = torch.sigmoid(masks.view(B, N, n) * recons.view(B, N, n))
    weight = mask_surface.view(B, N, n).detach() * weight

    weighted_min_dist = torch.sum(weight * min_dist, dim=-1)
    cuboid_loss = torch.mean(weighted_min_dist, dim=(0, 1)) * .05

    # L1 penalty pulling the predicted cuboid scales towards zero.
    scale_reg = nn.L1Loss()
    scale_reg_loss = scale_reg(scale, torch.zeros(B, N, 3, device=device)) * 1e-2

    total_loss = vox_recon_loss + cuboid_loss + scale_reg_loss + affordance_loss

    return total_loss, vox_recon_loss, cuboid_loss, scale_reg_loss, affordance_loss

def _cycle_batches(dataloader):
    """Yield batches from ``dataloader`` endlessly, re-iterating it after every pass.

    Unlike ``itertools.cycle``, nothing is cached: each pass starts a fresh
    iteration of ``dataloader``, so a shuffling loader draws a new order every
    epoch and batches from earlier passes are not kept alive.

    Raises:
        ValueError: If a pass over ``dataloader`` produces no batch at all.
    """
    while True:
        produced_any = False
        for batch in dataloader:
            produced_any = True
            yield batch
        if not produced_any:
            # Without this guard an empty loader would spin here forever.
            raise ValueError("The training dataloader produced no batches.")


def main(argv):
    """Train the slot/cuboid voxel model, logging and checkpointing as it goes.

    All settings come from FLAGS. Every 100 steps the training losses are sent
    to TensorBoard, stdout and ``<model_dir>/train.log``. At every 50th epoch
    (epoch 0 included) the model and optimizer state dicts are saved into
    ``model_dir`` and the losses are averaged over the validation split.

    Args:
        argv: Positional command-line arguments handed over by ``app.run``;
            unused.

    Raises:
        FileExistsError: If ``FLAGS.model_dir`` already exists.
        ValueError: If the training dataloader produces no batches.
    """
    del argv
    # set training device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f'Using device: {device}')
    print("Let's use", torch.cuda.device_count(), "GPU(s)!")

    # hyperparameters of the model
    # The per-GPU batch size is scaled by the number of GPUs, but by at least
    # one: device_count() is 0 on a CPU-only host, which would otherwise give
    # a batch size of 0 and make the DataLoader construction fail.
    batch_size = FLAGS.batch_size * max(1, torch.cuda.device_count())
    with_cuboid = FLAGS.with_cuboid
    with_afford = FLAGS.with_afford
    num_slots = FLAGS.num_slots
    num_iterations = FLAGS.num_iterations
    base_learning_rate = FLAGS.learning_rate
    num_train_steps = FLAGS.num_train_steps
    torch.manual_seed(FLAGS.seed)
    resolution = (32, 32, 32)

    # load the affordance dataset and split it 70% / 10% / remainder
    dataset = AffordanceDataset(workspace_dir, data_type='voxel', num_slots=num_slots, rm=True, cross="_openable_sapien_simp")
    train_size = int(len(dataset) * .7)
    val_size = int(len(dataset) * .1)
    print("Number of training shapes:", train_size)
    test_size = len(dataset) - train_size - val_size
    # The split generator is seeded so the partition is the same across runs.
    train_set, val_set, test_set = torch.utils.data.random_split(
        dataset, [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(FLAGS.seed))
    train_dataloader = DataLoader(train_set, batch_size, shuffle=True)
    # Re-iterate the loader for every pass instead of itertools.cycle, which
    # would replay (and keep in memory) the batches of the first pass forever.
    train_batches = _cycle_batches(train_dataloader)
    val_dataloader = DataLoader(val_set, batch_size, shuffle=True)
    print('Dataset loaded.')

    # build model
    if not with_cuboid:
        print("Cuboid branch is not used.")
    model = model_utils.SlotCuboidVox(resolution, num_slots, num_iterations, with_cuboid=with_cuboid, with_afford=with_afford)
    model = nn.DataParallel(model).to(device)

    # initialize optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=base_learning_rate)
    optimizer_to_device(optimizer, device)

    if FLAGS.ckpt_pth:
        # map_location lets a checkpoint written on one device be restored on
        # whichever device this run uses. strict=False tolerates missing or
        # unexpected keys in the checkpoint.
        model.load_state_dict(torch.load(FLAGS.ckpt_pth, map_location=device), strict=False)

    # create directories for this run (raises if the directory already exists)
    os.makedirs(FLAGS.model_dir)
    log_path = os.path.join(FLAGS.model_dir, 'train.log')
    # file log
    summary_writer = SummaryWriter()
    try:
        with open(log_path, 'w') as flog:
            flog.write("batch size: %d\n" % batch_size)
        # start training
        print("Starting training ...... ")
        print("batch size:", batch_size)
        start = time.time()
        old_epoch_idx = -1

        for step_idx in range(num_train_steps):
            train_batch = next(train_batches)

            # forward pass
            optimizer.zero_grad()
            total_loss, vox_recon_loss, cuboid_loss, scale_reg_loss, affordance_loss = forward_step(
                step_idx, train_batch, model, device,
                with_cuboid=with_cuboid, with_afford=with_afford)

            # optimize one step
            total_loss.backward()
            optimizer.step()

            if step_idx % 100 == 0:
                summary_writer.add_scalar("total_loss", total_loss.item(), step_idx)
                summary_writer.add_scalar("recon_loss", vox_recon_loss.item(), step_idx)
                summary_writer.add_scalar("cuboid_loss", cuboid_loss.item(), step_idx)
                summary_writer.add_scalar("scale_reg_loss", scale_reg_loss.item(), step_idx)
                summary_writer.add_scalar("afford_loss", affordance_loss.item(), step_idx)
                # Format once so stdout and the log file carry the same line.
                step_msg = ("Step: %s, Loss_voxrecon: %.6f, Loss_cuboid: %.6f, Loss_scale_reg: %.6f, Loss_afford: %.6f, Time: %s" %
                            (step_idx, vox_recon_loss.item(), cuboid_loss.item(), scale_reg_loss.item(), affordance_loss.item(),
                             datetime.timedelta(seconds=time.time() - start)))
                print(step_msg)
                with open(log_path, 'a') as flog:
                    flog.write(step_msg + "\n")

            # Epoch index estimated from the number of samples consumed so far.
            epoch_idx = step_idx * batch_size // train_size
            if epoch_idx <= old_epoch_idx:
                continue
            old_epoch_idx = epoch_idx
            if epoch_idx % 50 != 0:
                continue

            print("Saving checkpoint at epoch %d ...... " % epoch_idx)
            with open(log_path, 'a') as flog:
                flog.write("Saving checkpoint at epoch %d ...... \n" % epoch_idx)
            torch.save(model.state_dict(), os.path.join(FLAGS.model_dir, 'epoch_' + str(epoch_idx) + '.pth'))
            torch.save(optimizer.state_dict(), os.path.join(FLAGS.model_dir, 'epoch_' + str(epoch_idx) + '_opt.pth'))

            # validation
            # NOTE(review): the model is left in training mode here; confirm
            # whether model.eval() / model.train() should bracket this pass.
            avg_total = []; avg_recon = []; avg_cuboid = []; avg_reg = []; avg_afford = []
            with torch.no_grad():
                for val_batch in val_dataloader:
                    val_total, val_recon, val_cuboid, val_reg, val_afford = forward_step(
                        step_idx, val_batch, model, device,
                        with_cuboid=with_cuboid, with_afford=with_afford)
                    avg_total.append(val_total.item())
                    avg_recon.append(val_recon.item())
                    avg_cuboid.append(val_cuboid.item())
                    avg_reg.append(val_reg.item())
                    avg_afford.append(val_afford.item())
            with open(log_path, 'a') as flog:
                flog.write("Validation: Loss_voxrecon: %.6f, Loss_cuboid: %.6f, Loss_scale_reg: %.6f, Loss_afford: %.6f, Time: %s\n" %
                           (np.mean(avg_recon), np.mean(avg_cuboid), np.mean(avg_reg), np.mean(avg_afford),
                            datetime.timedelta(seconds=time.time() - start)))
            summary_writer.add_scalar("val_total_loss", np.mean(avg_total), step_idx)
            summary_writer.add_scalar("val_recon_loss", np.mean(avg_recon), step_idx)
            summary_writer.add_scalar("val_cuboid_loss", np.mean(avg_cuboid), step_idx)
            summary_writer.add_scalar("val_scale_reg_loss", np.mean(avg_reg), step_idx)
            summary_writer.add_scalar("val_afford_loss", np.mean(avg_afford), step_idx)
    finally:
        # Flush and close the TensorBoard writer even if training is interrupted.
        summary_writer.close()

# Script entry point: hand main to absl's app.run, which invokes it when this
# file is executed directly (not when it is imported).
if __name__ == "__main__":
    app.run(main)
