from argparse import ArgumentParser

import torch
from tqdm import tqdm
import numpy as np
import cv2
import torchvision
import matplotlib as mpl
import os
from PIL import Image
mpl.use('Agg')
from fiery.utils.network import NormalizeInverse
from fiery.data import prepare_dataloaders
from fiery.trainer_polar_16 import TrainingModule
from fiery.metrics import IntersectionOverUnion, PanopticMetric
from fiery.utils.network import preprocess_batch
from fiery.utils.instance import predict_instance_segmentation_and_trajectories
import matplotlib.pyplot as plt
from fiery.utils.instance import predict_instance_segmentation_and_trajectories
from fiery.utils.visualisation import plot_instance_map, generate_instance_colours, make_contour, convert_figure_numpy


# Evaluation windows (30m x 30m and 100m x 100m). Each value is a
# (start, stop) pair that the evaluation loop turns into a slice and applies
# to both of the last two dimensions of the predictions and labels.
EVALUATION_RANGES = {
    '30x30': (70, 130),
    '100x100': (0, 200),
}


def _binary_mask_to_rgb(mask):
    """Turn a 2D class-index tensor into an inverted 3-channel integer image.

    The mask is replicated over three channels and mapped with
    ``(1 - value) * 255``, so index 0 becomes 255 and index 1 becomes 0.
    The spatial size is read from the tensor instead of being hard-coded.

    Args:
        mask: torch tensor whose last two dimensions are the image height
            and width (assumed to hold only 0/1 values -- TODO confirm).

    Returns:
        Integer numpy array of shape (height, width, 3).
    """
    height, width = mask.shape[-2:]
    channels = mask.reshape(height, width, 1).repeat(1, 1, 3).cpu().long().numpy()
    return (1 - channels) * 255


def plot_prediction(image, output, labels, cfg):
    """Render camera frames, predictions and ground truth into one figure.

    The figure is a 2 x 6 grid: the three left-hand columns hold the camera
    images (two rows of three), and the three right-hand columns each span
    both rows and show, in order, the predicted segmentation, the predicted
    instance map with trajectories overlaid, and the ground-truth
    segmentation.

    Args:
        image: batched camera tensor; ``image[0, -1]`` is iterated to obtain
            one normalised image per camera (presumably the last/present
            time step of the first sample -- TODO confirm). The grid has room
            for six cameras.
        output: model output dict; ``output['segmentation']`` is read here and
            the whole dict is passed to
            ``predict_instance_segmentation_and_trajectories``.
        labels: ground-truth segmentation tensor; only ``labels[0, 0, 0]`` is
            drawn (assumed layout (batch, time, 1, H, W) -- TODO confirm).
        cfg: config providing ``IMAGE.NAMES`` (camera names used as captions)
            and ``IMAGE.FINAL_DIM`` (used for the panel aspect ratio).

    Returns:
        The value returned by ``convert_figure_numpy`` for the drawn figure.

    Raises:
        AssertionError: if the smallest id in the predicted instance map of
            the first frame is not 0.
    """
    label_rgb = _binary_mask_to_rgb(labels[0, 0, 0])

    # Process predictions
    consistent_instance_seg, matched_centers = predict_instance_segmentation_and_trajectories(
        output, compute_matched_centers=True
    )
    present_ids = torch.unique(consistent_instance_seg[0, 0]).cpu().long().numpy()
    # torch.unique returns sorted values, so the first entry must be id 0
    # before it can be dropped from the list of instances to draw.
    assert present_ids[0] == 0, 'instance map of the first frame does not contain id 0'
    unique_ids = present_ids[1:]

    # Plot future trajectories
    instance_map = dict(zip(unique_ids, unique_ids))
    instance_colours = generate_instance_colours(instance_map)
    vis_image = plot_instance_map(consistent_instance_seg[0, 0].cpu().numpy(), instance_map)
    trajectory_img = np.zeros(vis_image.shape, dtype=np.uint8)
    for instance_id in unique_ids:
        path = matched_centers[instance_id]
        colour = instance_colours[instance_id].tolist()
        # Join consecutive centres with line segments.
        for start, end in zip(path[:-1], path[1:]):
            cv2.line(trajectory_img, tuple(start), tuple(end), colour, 4)

    # Blend the trajectories into the instance map, touching only the pixels
    # where a trajectory was actually drawn.
    blended = cv2.addWeighted(vis_image, 0.7, trajectory_img, 0.3, 1.0)
    drawn = ~ np.all(trajectory_img == 0, axis=2)
    vis_image[drawn] = blended[drawn]

    # Plot present RGB frames and predictions
    val_w = 2.99
    cameras = cfg.IMAGE.NAMES
    image_ratio = cfg.IMAGE.FINAL_DIM[0] / cfg.IMAGE.FINAL_DIM[1]
    val_h = val_w * image_ratio
    fig = plt.figure(figsize=(6 * val_w, 2 * val_h))
    gs = mpl.gridspec.GridSpec(2, 6, width_ratios=(val_w,) * 6)
    gs.update(wspace=0.0, hspace=0.0, left=0.0, right=1.0, top=1.0, bottom=0.0)

    denormalise_img = torchvision.transforms.Compose(
        (NormalizeInverse(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
         torchvision.transforms.ToPILImage(),)
    )
    for imgi, img in enumerate(image[0, -1]):
        plt.subplot(gs[imgi // 3, imgi % 3])
        showimg = denormalise_img(img.cpu())
        if imgi > 2:
            # Cameras on the second row are mirrored horizontally.
            showimg = showimg.transpose(Image.FLIP_LEFT_RIGHT)

        plt.annotate(cameras[imgi].replace('_', ' ').replace('CAM ', ''), (0.01, 0.87), c='white',
                     xycoords='axes fraction', fontsize=14)
        plt.imshow(showimg)
        plt.axis('off')

    segmentation_pred = torch.argmax(output['segmentation'].detach(), dim=2, keepdims=True)
    segmentation_rgb = _binary_mask_to_rgb(segmentation_pred[0, 0, 0])

    # Columns 3-5: predicted segmentation, predicted instances + trajectories,
    # ground-truth segmentation. Every map is flipped along both image axes
    # before contouring.
    bev_panels = (segmentation_rgb, vis_image, label_rgb)
    for column, panel in enumerate(bev_panels, start=3):
        plt.subplot(gs[:, column])
        plt.imshow(make_contour(panel[::-1, ::-1]))
        plt.axis('off')

    plt.draw()
    figure_numpy = convert_figure_numpy(fig)
    # Close this specific figure so figures do not pile up across calls.
    plt.close(fig)
    return figure_numpy





def eval(checkpoint_path, dataroot, version):
    """Evaluate a checkpoint on the validation split and dump visualisations.

    Loads a ``TrainingModule`` from ``checkpoint_path``, runs it on every
    validation batch with a zero noise vector, saves one visualisation PNG
    per batch to ``./output_vis/<batch index>.png`` and accumulates panoptic
    and IoU metrics for each window in ``EVALUATION_RANGES``. At the end the
    ``iou``, ``pq``, ``sq`` and ``rq`` scores (scaled by 100, entry 1 of each
    per-class result) are printed, one ``&``-separated line per metric.

    NOTE: the function name shadows the builtin ``eval``; it is kept because
    the script entry point calls it by this name.

    Args:
        checkpoint_path: path of the checkpoint to load.
        dataroot: dataset root, written to ``cfg.DATASET.DATAROOT``.
        version: dataset version, written to ``cfg.DATASET.VERSION``.

    Returns:
        None. Results are printed and images are written to disk.
    """
    trainer = TrainingModule.load_from_checkpoint(checkpoint_path, strict=True)
    print(f'Loaded weights from \n {checkpoint_path}')
    trainer.eval()

    # The script always runs on the first CUDA device.
    device = torch.device('cuda:0')
    trainer.to(device)
    model = trainer.model

    cfg = model.cfg
    cfg.GPUS = "[0]"
    cfg.BATCHSIZE = 1

    cfg.DATASET.DATAROOT = dataroot
    cfg.DATASET.VERSION = version

    _, valloader = prepare_dataloaders(cfg)

    # One metric accumulator per evaluation window.
    n_classes = len(cfg.SEMANTIC_SEG.WEIGHTS)
    panoptic_metrics = {
        key: PanopticMetric(n_classes=n_classes, temporally_consistent=True).to(device)
        for key in EVALUATION_RANGES
    }
    iou_metrics = {
        key: IntersectionOverUnion(n_classes).to(device)
        for key in EVALUATION_RANGES
    }

    # Create the output directory once instead of on every batch.
    output_dir = './output_vis'
    os.makedirs(output_dir, exist_ok=True)

    for i, batch in enumerate(tqdm(valloader)):
        preprocess_batch(batch, device)
        image = batch['image']
        intrinsics = batch['intrinsics']
        extrinsics = batch['extrinsics']
        lidar2imgs = batch['lidar2imgs']
        future_egomotion = batch['future_egomotion']

        batch_size = image.shape[0]

        labels, future_distribution_inputs = trainer.prepare_future_labels(batch)

        with torch.no_grad():
            # Evaluate with mean prediction
            noise = torch.zeros((batch_size, 1, model.latent_dim), device=device)
            # The model returns four values; only the first is used here.
            output, _, _, _ = model(
                image, intrinsics, extrinsics, lidar2imgs, future_egomotion,
                future_distribution_inputs, noise=noise,
            )

            figure_numpy = plot_prediction(image, output, labels['segmentation'], trainer.cfg)
            output_filename = os.path.join(output_dir, f'{i}.png')
            Image.fromarray(figure_numpy).save(output_filename)
            print(f'Saved output in {output_filename}')

        #  Consistent instance seg
        pred_consistent_instance_seg = predict_instance_segmentation_and_trajectories(
            output, compute_matched_centers=False, make_consistent=True
        )

        segmentation_pred = output['segmentation'].detach()
        segmentation_pred = torch.argmax(segmentation_pred, dim=2, keepdims=True)

        for key, (start, stop) in EVALUATION_RANGES.items():
            # Crop the same square window out of the last two dimensions of
            # predictions and labels before updating the metrics.
            limits = slice(start, stop)
            panoptic_metrics[key](
                pred_consistent_instance_seg[..., limits, limits].contiguous().detach(),
                labels['instance'][..., limits, limits].contiguous(),
            )
            iou_metrics[key](
                segmentation_pred[..., limits, limits].contiguous(),
                labels['segmentation'][..., limits, limits].contiguous(),
            )

    # Collect one score per evaluation window for every metric, in the order
    # of EVALUATION_RANGES.
    results = {}
    for key in EVALUATION_RANGES:
        panoptic_scores = panoptic_metrics[key].compute()
        for panoptic_key, value in panoptic_scores.items():
            results.setdefault(panoptic_key, []).append(100 * value[1].item())

        iou_scores = iou_metrics[key].compute()
        results.setdefault('iou', []).append(100 * iou_scores[1].item())

    for panoptic_key in ['iou', 'pq', 'sq', 'rq']:
        print(panoptic_key)
        print(' & '.join([f'{x:.1f}' for x in results[panoptic_key]]))


if __name__ == '__main__':
    # Script entry point: read the checkpoint path, dataset root and dataset
    # version from the command line and run the evaluation with them.
    cli = ArgumentParser(description='Fiery evaluation')
    cli.add_argument('--checkpoint', type=str, default='./fiery.ckpt', help='path to checkpoint')
    cli.add_argument('--dataroot', type=str, default='./nuscenes', help='path to the dataset')
    cli.add_argument(
        '--version',
        type=str,
        default='trainval',
        choices=['mini', 'trainval'],
        help='dataset version',
    )
    options = cli.parse_args()

    eval(options.checkpoint, options.dataroot, options.version)
