# -*- coding: utf-8 -*-

import argparse
from tqdm import tqdm

import torch
from torch.utils.data import DataLoader

import opencood.hypes_yaml.yaml_utils as yaml_utils
from opencood.tools import train_utils, inference_utils
from opencood.data_utils.datasets import build_dataset
from opencood.utils.seg_utils import normalize_maps
from ignite.metrics import IoU
from ignite.metrics.confusion_matrix import ConfusionMatrix

CLASS_NAMES = ['Vehicle','Road','Road line']

def infer_parser():
    """Build the command-line parser for this inference script and parse argv.

    Returns
    -------
    opt : argparse.Namespace
        Parsed options with the attributes ``model_dir``, ``fusion_method``,
        ``show_vis``, ``show_sequence``, ``save_vis``, ``save_npy``,
        ``dataset_mode`` and ``epoch``.
    """
    parser = argparse.ArgumentParser(
        description="Run segmentation inference and report per-class IoU")
    parser.add_argument('--model_dir', type=str, default='',
                        help='directory of the saved model to load')
    parser.add_argument('--fusion_method', type=str,
                        default='intermediate',
                        help='late, early, intermediate or nofusion')
    parser.add_argument('--show_vis', action='store_true',
                        help='whether to show image visualization result')
    parser.add_argument('--show_sequence', action='store_true',
                        help='whether to show video visualization result. '
                             'It can not be set true together with show_vis')
    # Saving visualizations is on by default, so `--save_vis` itself is a
    # no-op kept for backward compatibility; `--no_save_vis` is the only way
    # to switch it off from the command line.
    parser.add_argument('--save_vis', action='store_true',
                        default=True,
                        help='whether to save visualization result '
                             '(enabled by default)')
    parser.add_argument('--no_save_vis', dest='save_vis',
                        action='store_false',
                        help='do not save visualization result')
    parser.add_argument('--save_npy', action='store_true',
                        help='whether to save prediction and gt result '
                             'in npy_test file')
    parser.add_argument('--dataset_mode', type=str, default="")
    # No `type=` is given, so a value supplied on the command line is passed
    # on as a string; the default stays None.
    parser.add_argument('--epoch', default=None,
                        help="epoch number to load model")
    opt = parser.parse_args()
    return opt


if __name__ == '__main__':
    # Script entry point: load a saved model, run it over the test split,
    # accumulate a confusion matrix and print the IoU of each named class.

    opt = infer_parser()

    # Validate the options with explicit errors rather than `assert`, which
    # is stripped when Python runs with -O.
    valid_fusion_methods = ('late', 'early', 'intermediate', 'nofusion')
    if opt.fusion_method not in valid_fusion_methods:
        raise ValueError(
            f'--fusion_method must be one of {valid_fusion_methods}, '
            f'got {opt.fusion_method!r}')
    if opt.show_vis and opt.show_sequence:
        raise ValueError('you can only visualize the results in single '
                         'image mode or video mode')

    hypes = yaml_utils.load_yaml(None, opt)
    # A non-empty --dataset_mode overrides the value from the loaded config.
    if opt.dataset_mode:
        hypes['dataset_mode'] = opt.dataset_mode

    print(hypes['dataset_mode'])

    print('Dataset Building')
    # This should be modified
    opencood_dataset = build_dataset(hypes, visualize=True, train=False)
    print(f"{len(opencood_dataset)} samples found.")
    data_loader = DataLoader(opencood_dataset,
                             batch_size=1,
                             num_workers=16,
                             collate_fn=opencood_dataset.collate_batch_test,
                             shuffle=False,
                             pin_memory=False,
                             drop_last=False)

    print('Creating Model')
    model = train_utils.create_model(hypes)

    # Use the GPU when one is available, otherwise stay on the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    # Creating evaluator
    # One class more than CLASS_NAMES: index 0 is never reported below
    # (presumably background -- TODO confirm against the model's output).
    num_classes = len(CLASS_NAMES) + 1
    cm = ConfusionMatrix(num_classes=num_classes)
    iou_metric = IoU(cm)

    print('Loading Model from checkpoint')
    saved_path = opt.model_dir
    _, model = train_utils.load_saved_model(saved_path, model, epoch=opt.epoch)
    model.eval()

    # Gradients are never needed during evaluation.
    with torch.no_grad():
        # Wrap the loader itself (not the enumerate object) so tqdm can read
        # its length and display a total / ETA.
        for i, batch_data in enumerate(tqdm(data_loader)):
            batch_data = train_utils.to_device(batch_data, device)
            ego_data = batch_data['ego']
            output_dict = model(ego_data)
            # Kept before the metric update, as in the original ordering, in
            # case normalize_maps touches output_dict.
            pred_map, gt_map = normalize_maps(output_dict,
                                              ego_data['seg_label'])
            iou_metric.update((output_dict['seg'],
                               ego_data['seg_label'].to(int)))
            if opt.save_vis:
                # folder_name: train_vis, test_vis
                inference_utils.camera_inference_visualization(
                    pred_map, gt_map, saved_path, i,
                    folder_name='test_vis', with_true_map=False)

    # Per-class IoU; skip index 0 and pair the rest with CLASS_NAMES.
    class_ious = iou_metric.compute().tolist()
    for class_name, class_iou in zip(CLASS_NAMES, class_ious[1:]):
        print('Mean IoU in %s : %.4f' % (class_name, class_iou))
