# -*- coding: utf-8 -*-
import argparse
import os
import time
from tqdm import tqdm

import torch
import open3d as o3d
from torch.utils.data import DataLoader
import sys; sys.path.append(os.getcwd())
import opencood.hypes_yaml.yaml_utils as yaml_utils
from opencood.tools import train_utils, inference_utils
from opencood.data_utils.datasets import build_dataset
from opencood.utils import eval_utils
from opencood.visualization import vis_utils
import matplotlib.pyplot as plt
from opencood.visualization import simple_vis
from opencood.tools.pytorch_mem_utils import MemTracker
from opencood.visualization.cppc_vis import VisUtil
from opencood.utils.fsd_metric import MetricUtil


def test_parser():
    parser = argparse.ArgumentParser(description="synthetic data generation")
    parser.add_argument('--model_dir', type=str, required=True,
                        help='Continued training path')
    parser.add_argument('--fusion_method', required=False, type=str,
                        default='intermediate',
                        help='late, early or intermediate')
    parser.add_argument('--com_metric', action='store_true',
                        help='whether compute detection metric of predicted 3d box')
    parser.add_argument('--save_vis_n', type=int, default=-1,
                        help='save visualization result interval')
    parser.add_argument('--save_vis_type', type=str, default='bev',
                        help='save visualization result on bev or 3d')
    parser.add_argument('--save_npy', action='store_true',
                        help='whether to save prediction and gt result'
                             'in npy_test file')
    parser.add_argument("--hypes_yaml", type=str, default=None,
                        help='data generation yaml file needed ')
    parser.add_argument('--debug', action='store_true',
                        help="open debug mode(gpu mem tracjer...)")
    parser.add_argument('--use_mask', action='store_true',
                        help="use coop label mask")
    opt = parser.parse_args()
    return opt


def _split_model_path(model_dir):
    """Split a ``--model_dir`` value into ``(directory, checkpoint_name)``.

    If *model_dir* ends with ``.pth`` it names a specific checkpoint file: return
    its parent directory and the file name. Otherwise return *model_dir*
    unchanged and ``None`` for the checkpoint name.
    """
    if model_dir.endswith(".pth"):
        return os.path.dirname(model_dir), os.path.basename(model_dir)
    return model_dir, None


def _run_inference(fusion_method, batch_data, model, dataset):
    """Run the inference routine that matches *fusion_method* on one batch.

    Raises:
        NotImplementedError: if *fusion_method* is not late/early/intermediate.
    """
    inference_fns = {
        'late': inference_utils.inference_late_fusion,
        'early': inference_utils.inference_early_fusion,
        'intermediate': inference_utils.inference_intermediate_fusion,
    }
    try:
        inference_fn = inference_fns[fusion_method]
    except KeyError:
        raise NotImplementedError('Only early, late and intermediate '
                                  'fusion is supported.') from None
    return inference_fn(batch_data, model, dataset)


def main():
    """Evaluate a saved model on the validation set.

    Parses CLI options, builds the dataset and model, loads the checkpoint,
    runs the selected fusion inference on every sample, and feeds the results
    to MetricUtil (and VisUtil when visualization is enabled). Finally it
    saves or loads per-iteration metric results and reports mAP.
    """
    opt = test_parser()
    assert opt.fusion_method in ['late', 'early', 'intermediate']

    # --model_dir may point either at a run directory or at a specific .pth file.
    opt.model_dir, model_pth = _split_model_path(opt.model_dir)

    if opt.hypes_yaml is not None:
        hypes = yaml_utils.load_yaml(opt.hypes_yaml)
    else:
        hypes = yaml_utils.load_yaml(None, opt)
    # OPV2V validation data is visualized with left_hand=True.
    left_hand = "OPV2V" in hypes['validate_dir']
    print(f"Left hand visualizing: {left_hand}")

    print('Dataset Building')
    opencood_dataset = build_dataset(hypes, visualize=True, train=False)
    print(f"{len(opencood_dataset)} samples found.")
    data_loader = DataLoader(opencood_dataset,
                             batch_size=1,
                             num_workers=8,
                             collate_fn=opencood_dataset.collate_batch_test,
                             shuffle=False,
                             pin_memory=True,
                             drop_last=False)

    print('Creating Model')
    model = train_utils.create_model(hypes)
    # we assume gpu is necessary
    if torch.cuda.is_available():
        model.cuda()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print('Loading Model from checkpoint')
    saved_path = opt.model_dir
    vis_dir = os.path.join(saved_path, 'val_vis')
    if opt.save_vis_n != -1:
        # exist_ok avoids the check-then-create race of exists() + mkdir().
        os.makedirs(vis_dir, exist_ok=True)
    _, model = train_utils.load_saved_model(
        saved_path,
        model,
        model_pth=model_pth,
        seg_pretrain=hypes['train_params'].get('seg_pretrain', False)
    )

    mem_tracker = None
    if opt.debug:
        mem_tracker = MemTracker()
        mem_tracker.create_track_thread()

    # try/finally so the memory-tracking thread is always told to stop,
    # even if inference or evaluation raises.
    try:
        model.eval()

        VisUtil.set_args(
            root_path=vis_dir,
            left_hand=left_hand,
            is_vis=opt.save_vis_n != -1,
            save_vis_n=opt.save_vis_n,
        )
        # thr=20 is passed straight through; its meaning is defined by MetricUtil.
        MetricUtil.set_args(
            thr=20,
            is_use_mask=opt.use_mask,
        )
        MetricUtil.load_from_json()

        for i, batch_data in tqdm(enumerate(data_loader), total=len(data_loader)):
            if mem_tracker is not None:
                mem_tracker.record_epoch(0, i, train=False)
            with torch.no_grad():
                batch_data = train_utils.to_device(batch_data, device)
                batch_data['ego']['metas'] = {
                    'model_dir': saved_path,
                    'vis_dir': vis_dir,
                    'batch_idx': i,
                    'vis_n': opt.save_vis_n,
                    'vis_type': opt.save_vis_type,
                    'proj_first': opencood_dataset.proj_first,
                }
                pred_dicts = _run_inference(opt.fusion_method, batch_data,
                                            model, opencood_dataset)

                MetricUtil.eval_iter(i, pred_dicts)
                if VisUtil.is_vis_now():
                    # A dict-valued 'pred_box' is treated as a two-stage output.
                    two_stage = isinstance(pred_dicts['pred_box'], dict)
                    pred_bboxes, gt_bboxes = VisUtil.get_preds_and_gts_with_format(pred_dicts)
                    VisUtil.set_args(
                        pred_bboxes=pred_bboxes,
                        gt_bboxes=gt_bboxes,
                    )
                    VisUtil.vis_with_stages(opt.save_vis_type, True, True, True, two_stage)

        # Only write the metric cache if results were not already loaded from it.
        if not MetricUtil.is_load_from_json:
            MetricUtil.save_json()
        MetricUtil.eval_mAp(saved_path)
        MetricUtil.print_volume()
    finally:
        if mem_tracker is not None:
            mem_tracker.end_track()


if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()
