import argparse
import torch
import numpy as np
from utils import *
from dataset import *
from network import *
from networkLit import *
from torch_geometric.loader import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
# Seed RNGs through Lightning for reproducibility; workers=True also configures
# seeding of dataloader worker processes.
pl.seed_everything(0, workers=True)
from pathlib import Path
# Directory containing this script; used below to locate bundled assets
# (e.g. assets/panda.obj).
# NOTE(review): ``os`` is not imported explicitly in this file; presumably it
# arrives through one of the wildcard imports above — TODO import it directly.
ABS_PATH = os.path.dirname(os.path.abspath(__file__))


# Re-seeds torch's global RNG; likely redundant after pl.seed_everything above.
torch.manual_seed(0)
# NOTE(review): the return value is discarded, so this call does not influence
# which device is used later (that is decided by args.device).
torch.cuda.is_available()
# Allow faster, lower-precision float32 matmul kernels (e.g. TF32) where the
# hardware supports them.
torch.set_float32_matmul_precision('high')

def _str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` evaluates ``bool(text)``, which is ``True``
    for every non-empty string (including ``"False"``), so the text is parsed
    explicitly instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args(argv=None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, matching the previous behaviour.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description='Train GNN')

    parser.add_argument('--scene_path', type=str,
                        help='Path to 3d scene',
                        required=True)

    parser.add_argument('--node_features', type=str,
                        help='Node features',
                        default="dim-pose")

    parser.add_argument('--edge_types', type=str,
                        help='Edge types',
                        default="IK-GO")

    parser.add_argument('--edge_features', type=str,
                        help='Edge features',
                        default="type-IK-GO")

    parser.add_argument('--IK_GO_mode', type=str,
                        help='IK GO mode',
                        default="pred")

    parser.add_argument('--augmentations', type=str,
                        help='Augmentations',
                        default="dimswitch_all")

    parser.add_argument('--device', type=str,
                        help='Device',
                        default="cuda")

    parser.add_argument('--n_epochs', type=int,
                        help='Number of epochs',
                        default=100)

    parser.add_argument('--batch_size', type=int,
                        help='Batch size',
                        default=2048)

    parser.add_argument('--lr', type=float,
                        help='Learning rate',
                        default=0.0001)

    parser.add_argument('--weight_decay', type=float,
                        help='Weight decay',
                        default=0.0)

    parser.add_argument('--num_workers', type=int,
                        help='Number of workers',
                        default=8)

    parser.add_argument('--num_node_features', type=int,
                        help='Number of node features',
                        default=7)

    parser.add_argument('--num_edge_features', type=int,
                        help='Number of edge features',
                        default=7)

    parser.add_argument('--hidden_size', type=int,
                        help='Hidden size',
                        default=256)

    parser.add_argument('--num_heads', type=int,
                        help='Number of heads',
                        default=4)

    parser.add_argument('--n_message_passing', type=int,
                        help='Number of message passing',
                        default=1)

    parser.add_argument('--dropout', type=float,
                        help='Dropout',
                        default=0.0)

    # argparse does not run ``type`` on non-string defaults, so the default is
    # written as a float to keep the attribute's type consistent.
    parser.add_argument('--pos_weight', type=float,
                        help='Positive weight',
                        default=1.0)

    # Accepts ``--debug``, ``--debug True`` and ``--debug False``; the old
    # ``type=bool`` treated any non-empty value (even "False") as True.
    parser.add_argument('--debug', type=_str2bool,
                        nargs='?',
                        const=True,
                        help='Enable debug mode',
                        default=False)

    parser.add_argument('--gnn_type', type=str,
                        help='GNN type',
                        default="EGAT")

    parser.add_argument('--training_mode', type=str,
                        help='Training mode',
                        default="one_by_one")

    return parser.parse_args(argv)

if __name__ == '__main__':
    # Explicit stdlib imports: these names were otherwise only reachable
    # through the wildcard imports at the top of the file.
    import json
    import os

    args = parse_args()

    # Feature widths are derived here and override whatever was passed via
    # --num_node_features / --num_edge_features on the command line.
    args.num_node_features = 7
    if "IK" in args.edge_features or "GO" in args.edge_features:
        args.num_edge_features = 7
    else:
        args.num_edge_features = 2

    hyperparameters = {
        "lr": args.lr,
        "batch_size": args.batch_size,
        "n_epochs": args.n_epochs,
        "hidden_size": args.hidden_size,
        "augmentations": args.augmentations,
        "num_heads": args.num_heads,
        "n_message_passing": args.n_message_passing,
        "node_features": args.node_features,
        "edge_types": args.edge_types,
        "edge_features": args.edge_features,
    }

    model = GRNLit(args, hyperparameters).to(args.device)
    # NOTE(review): no weights are loaded, so the predictions below come from a
    # freshly initialised model. The disabled checkpoint load also referenced
    # ``args.dataset_path``, which parse_args does not define — TODO add a
    # checkpoint argument before re-enabling it.
    #model.model.load_state_dict(torch.load(os.path.join(ABS_PATH, "lightning_logs/GRN_" + args.dataset_path.split("/")[-1] + "_" + args.augmentations + "_" + args.edge_features + ".pt")))

    # Put the model in inference mode: torch.no_grad() only disables gradient
    # tracking, it does not turn off dropout or other training-only behaviour.
    model.eval()

    with open(args.scene_path, 'r', encoding='utf-8') as f:
        scene = json.load(f)
    with torch.no_grad():
        feasibility_preds, IK_preds, GO_preds, data = model.model.predict_from_scene(scene)

    # Drop the last N columns of edge_index, where N is the number of nodes
    # selected by data.movable_mask — presumably edges that predict_from_scene
    # appends per movable node and that should not be drawn; TODO confirm.
    n_movable = len(data.x[data.movable_mask])
    data.edge_index = data.edge_index[:, :(data.edge_index.shape[1] - n_movable)]

    robot_mesh_path = os.path.join(ABS_PATH, "assets/panda.obj")
    visualize_scene(data, robot_mesh_path=robot_mesh_path)
    visualize_action_predictions(data, feasibility_preds, robot_mesh_path)
    visualize_grasp_predictions(data, feasibility_preds, robot_mesh_path)
    # len(data.x) - 1 is the index of the last node; presumably the node whose
    # GO predictions are visualised — confirm against visualize_go_predictions.
    visualize_go_predictions(data, IK_preds, GO_preds, len(data.x) - 1, robot_mesh_path)