import argparse
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
from dataloaders.MultiMeshDataset import MultiMeshDataset
from pl_models.OrienterTrainerModel import OrienterTrainerModel
from ml_models.orienter_model.DGCNNOrienter import DGCNNOrienter
from utils.helpers import get_timestamp
from utils.inference_helpers import voting_scheme
from utils.losses import abs_cos_loss
import json5

# Module-level setup, executed on import (before any dataset or model is built).
# Querying the current device goes through torch's lazy CUDA initialization, so
# this line raises right away if no usable GPU is present.
torch.cuda.current_device()
# NOTE(review): this overwrites a private torch.cuda flag; presumably a workaround
# for a CUDA (re-)initialization problem — confirm it is still required.
torch.cuda._initialized = True
torch.multiprocessing.set_sharing_strategy('file_system') # to avoid "too many open files" error

def octahedral_invariant_loss(up_predicted, front_predicted, target_rotation_matrices):
    """Return the best-case abs-cos losses over every column ordering of the targets.

    The last axis of ``target_rotation_matrices`` is re-indexed with each of the
    six orderings of ``(0, 1, 2)``; ``abs_cos_loss`` is evaluated once per
    ordering and the smallest value along the ordering axis is kept.

    The minimum is taken separately for the up loss and for the front loss, so
    the two returned values may come from different orderings.

    Args:
        up_predicted: predicted up vectors, forwarded unchanged to ``abs_cos_loss``.
        front_predicted: predicted front vectors, forwarded unchanged to ``abs_cos_loss``.
        target_rotation_matrices: target rotations whose last axis has at least
            three entries (it is indexed with 0, 1 and 2).

    Returns:
        ``(up_loss, front_loss)``: the minima over the six orderings.
        Presumably each has shape (B,) — assumes ``abs_cos_loss`` returns
        per-sample tensors; TODO confirm.
    """
    column_orders = [(0,1,2), (0,2,1), (1,0,2), (1,2,0), (2,0,1), (2,1,0)]

    # One (up, front) pair of losses per column ordering of the targets.
    per_order = [
        abs_cos_loss(up_predicted, front_predicted, target_rotation_matrices[..., order])
        for order in column_orders
    ]

    # Put the six candidates on a new trailing axis, then keep the smallest.
    up_candidates = torch.stack([up for up, _ in per_order], dim=-1)
    front_candidates = torch.stack([front for _, front in per_order], dim=-1)

    up_loss = up_candidates.min(dim=-1).values
    front_loss = front_candidates.min(dim=-1).values
    return up_loss, front_loss

def compute_losses(batch, orienter_model):
    """Run the voting scheme on one batch and score it with the octahedral-invariant loss.

    Mimics the validation step from our orienter-3d.

    Args:
        batch: a 4-tuple ``(data_indices, xyzs_rotated, target_rotation_matrices,
            normals_rotated)``; the data indices are not used here.
        orienter_model: model handed to ``voting_scheme``.

    Returns:
        ``(up_octahedral_loss, front_octahedral_loss, xyzs_rotated,
        target_rotation_matrices)`` where the point cloud is the squeezed CUDA
        copy fed to the voting scheme and the rotation matrices are the CUDA copy
        of the batch targets.

    Side effects:
        Prints the mean of each loss for this batch.
    """
    _, points, rotations, normals = batch

    # squeeze() removes every size-1 axis (the loader is expected to use a batch
    # size of 1 — TODO confirm), then everything is moved onto the GPU.
    points = points.squeeze().cuda()
    normals = normals.squeeze().cuda()
    rotations = rotations.cuda()

    # The voting scheme yields a single (up, front) winner for the cloud.
    up_vote, front_vote = voting_scheme(points, normals, orienter_model, num_candidates=50)

    # Re-introduce a leading axis so the winners line up with the batched targets.
    up_vote = up_vote.unsqueeze(0)
    front_vote = front_vote.unsqueeze(0)

    up_loss, front_loss = octahedral_invariant_loss(up_vote, front_vote, rotations)

    # Report the per-batch means.
    up_mean = up_loss.mean()
    front_mean = front_loss.mean()
    print("Mean up octahedral loss:", up_mean)
    print("Mean front octahedral loss:", front_mean)

    return up_loss, front_loss, points, rotations

def main():
    """Benchmark a checkpointed orienter on the validation index file.

    Steps:
      1. Parse the command line and load ``specs.json5`` from ``--exp_dir``.
      2. Build the DGCNN core model and a batch-size-1 validation dataloader.
      3. Restore the Lightning trainer module from ``--ckpt_path`` and take its
         ``model`` attribute in eval mode.
      4. Run ``compute_losses`` on every validation batch without gradients.
      5. Print mean/std of the up and front losses and save losses, point
         clouds and target rotation matrices as ``.npy`` files under
         ``benchmark_results/orienter``.

    Requires a CUDA device (models and batches are moved with ``.cuda()``).
    """
    # parse and load specs
    parser = argparse.ArgumentParser()
    parser.add_argument("--exp_dir", "-e", default="/home/ubuntu/orienter-3d/config/default", help="Path to the experiment directory containing specs.json5")
    parser.add_argument("--slurm_id", "-s", default=get_timestamp(), help="Identifier for this run, e.g. the SLURM job id (defaults to a timestamp)")
    parser.add_argument("--train_index_file", "-ti", default="data/shapenet_index_files/all_point_clouds/train.txt", help="Path to train index file")
    parser.add_argument("--val_index_file", "-vi", default="data/shapenet_index_files/all_point_clouds/val.txt", help="Path to val index file")
    parser.add_argument("--inference_index_file", "-ii", default="data/shapenet_index_files/all_point_clouds/inference.txt", help="Path to inference index file")
    parser.add_argument("--all_index_file", "-ai", default="data/shapenet_index_files/all_point_clouds/all.txt", help="Path to index file for all point clouds")
    parser.add_argument("--preload", "-p", action='store_true', help="Preload meshes into memory at initialization")
    parser.add_argument("--ckpt_path", "-ck", default="pretrained_ckpts/orienter/trained_on_prod_389_epochs.ckpt", help="Path of checkpoint storing the model")
    args = parser.parse_args()
    slurm_id = args.slurm_id
    exp_dir = args.exp_dir.rstrip(" /")
    train_index_file_path = args.train_index_file
    val_index_file_path = args.val_index_file
    inference_index_file_path = args.inference_index_file
    all_index_file_path = args.all_index_file
    preload = args.preload
    ckpt_path = args.ckpt_path
    print(f"{exp_dir=}")
    print(f"{slurm_id=}")
    print(f"{train_index_file_path=}")
    print(f"{val_index_file_path=}")
    print(f"{inference_index_file_path=}")
    print(f"{all_index_file_path=}")
    print(f"{preload=}")
    print(f"{ckpt_path=}")
    with open(os.path.join(exp_dir, "specs.json5"), "r") as file:
        specs = json5.load(file)
    specs["exp_dir"] = exp_dir

    # Build the core network that the checkpointed trainer module wraps
    dgcnn_args = argparse.Namespace()
    dgcnn_args.k = 20
    dgcnn_args.emb_dims = 1024
    dgcnn_args.dropout = 0.5
    core_model = DGCNNOrienter(dgcnn_args, rotation_representation="procrustes").cuda()

    # Validation data, one cloud per batch. --preload is honoured here; it was
    # previously parsed and printed but never passed to the dataset.
    val_dataloader = DataLoader(MultiMeshDataset(index_file_path = val_index_file_path, sample_size = 2000, preload=preload),
                                                batch_size = 1, # max we can handle
                                                shuffle = False,
                                                num_workers = 1,
                                                persistent_workers = True # else there's overhead on switch
                                                )

    # Load PL module from checkpoint. Only its `model` attribute is used below,
    # so the trainer-side preload stays off.
    trainer_module = OrienterTrainerModel.load_from_checkpoint(ckpt_path,
                                                               specs = specs,
                                                               core_model = core_model,
                                                               train_loss_fn = "octahedral_invariant",
                                                               rotation_representation = "procrustes",
                                                               train_index_file_path = train_index_file_path,
                                                               val_index_file_path = val_index_file_path,
                                                               inference_index_file_path = inference_index_file_path,
                                                               preload = False,
                                                               num_points_per_cloud = 2000,
                                                               train_batch_size = 48,
                                                               val_batch_size = 48,
                                                               unlock_every_k_epochs = 10,
                                                               start_lr = 1e-4
                                                               )
    orienter_model = trainer_module.model.cuda()
    orienter_model.eval()

    # record losses
    up_octahedral_losses = []
    front_octahedral_losses = []

    # record point clouds and target rotation matrices
    xyzs_rotated_list = []
    target_rotation_matrices_list = []

    with torch.no_grad():
        for batch in tqdm(val_dataloader):
            up_octahedral_loss, front_octahedral_loss, xyzs_rotated, target_rotation_matrices = compute_losses(batch, orienter_model)
            # Move results off the GPU right away so device memory does not grow
            # with the size of the validation set.
            up_octahedral_losses.append(up_octahedral_loss.cpu())
            front_octahedral_losses.append(front_octahedral_loss.cpu())
            xyzs_rotated_list.append(xyzs_rotated.cpu())
            target_rotation_matrices_list.append(target_rotation_matrices.cpu())

    up_octahedral_losses = torch.cat(up_octahedral_losses, dim=0)
    front_octahedral_losses = torch.cat(front_octahedral_losses, dim=0)
    # NOTE(review): compute_losses returns each point cloud with its size-1 axes
    # squeezed away, so this concatenation presumably joins all clouds into one
    # long list of points rather than one entry per sample — confirm that
    # consumers of xyzs_rotated_all.npy expect that layout.
    xyzs_rotated_all = torch.cat(xyzs_rotated_list, dim=0)
    target_rotation_matrices_all = torch.cat(target_rotation_matrices_list, dim=0)

    # Compute mean and std of losses
    up_octahedral_losses_mean = up_octahedral_losses.mean()
    up_octahedral_losses_std = up_octahedral_losses.std()
    front_octahedral_losses_mean = front_octahedral_losses.mean()
    front_octahedral_losses_std = front_octahedral_losses.std()

    print("Mean up octahedral loss:", up_octahedral_losses_mean)
    print("Std up octahedral loss:", up_octahedral_losses_std)
    print("Mean front octahedral loss:", front_octahedral_losses_mean)
    print("Std front octahedral loss:", front_octahedral_losses_std)

    # Save losses, point clouds, rotation matrices
    results_dir = "benchmark_results/orienter"
    os.makedirs(results_dir, exist_ok=True)
    np.save(os.path.join(results_dir, "up_octahedral_losses.npy"), up_octahedral_losses.numpy())
    np.save(os.path.join(results_dir, "front_octahedral_losses.npy"), front_octahedral_losses.numpy())
    np.save(os.path.join(results_dir, "xyzs_rotated_all.npy"), xyzs_rotated_all.numpy())
    np.save(os.path.join(results_dir, "target_rotation_matrices_all.npy"), target_rotation_matrices_all.numpy())

# Run the benchmark only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()