# ------------------------------------------------------------------------
# Modified from Deformable DETR
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# -----------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
 
"""
Train and eval functions used in main.py
"""
import math
import os
import sys
from typing import Iterable
 
import torch
import util.misc as utils
from datasets.coco_eval import CocoEvaluator
from datasets.open_world_eval import OWEvaluator
from datasets.panoptic_eval import PanopticEvaluator
from datasets.data_prefetcher import data_prefetcher
from util.box_ops import box_xyxy_to_cxcywh, box_cxcywh_to_xyxy
from util.plot_utils import plot_prediction, plot_prediction_GT
import matplotlib.pyplot as plt
from copy import deepcopy

import json
import cv2


def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, nc_epoch: int, max_norm: float = 0, wandb: object = None):
    """Train ``model`` for one epoch and return the epoch-averaged meters.

    Args:
        model: network being trained; called as ``model(samples)``.
        criterion: loss module called as ``criterion(samples, outputs, targets, epoch)``;
            its ``weight_dict`` maps loss names to scalar weights.
        data_loader: training loader, consumed through ``data_prefetcher``.
        optimizer: optimizer that is zeroed and stepped once per batch.
        device: device handed to the prefetcher for the batches.
        epoch: current epoch index (used in the log header and passed to the criterion).
        nc_epoch: losses whose name contains ``'NC'`` get weight 0 while
            ``epoch < nc_epoch``.
        max_norm: gradient-clipping max norm; ``<= 0`` disables clipping (the
            total gradient norm is still computed for logging).
        wandb: optional object exposing ``log(dict)``; when given, one ``log``
            call is made per iteration.

    Returns:
        dict mapping each meter name to its global average over the epoch.
    """
    model.train()
    criterion.train()
    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    metric_logger.add_meter('grad_norm', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 10
    # The first batch is fetched up front; each iteration fetches the next one at its end.
    prefetcher = data_prefetcher(data_loader, device, prefetch=True)
    samples, targets = prefetcher.next()

    for _ in metric_logger.log_every(range(len(data_loader)), print_freq, header):
        outputs = model(samples)
        # `samples` is passed to the criterion because it is needed for feature selection.
        loss_dict = criterion(samples, outputs, targets, epoch)
        # Work on a copy so the NC-loss zeroing below never mutates criterion.weight_dict.
        weight_dict = deepcopy(criterion.weight_dict)

        # Start the NC losses only from `nc_epoch` on, so the F_cls branch has time
        # to learn the within-class separation first.
        if epoch < nc_epoch:
            for k in weight_dict:
                if 'NC' in k:
                    weight_dict[k] = 0

        # Weighted sum of the losses that have a weight; this is what is backpropagated.
        losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)

        # Reduce losses over all GPUs for logging purposes only (does not affect `losses`).
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_unscaled = {f'{k}_unscaled': v
                                      for k, v in loss_dict_reduced.items()}
        loss_dict_reduced_scaled = {k: v * weight_dict[k]
                                    for k, v in loss_dict_reduced.items() if k in weight_dict}
        losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())

        loss_value = losses_reduced_scaled.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        optimizer.zero_grad()
        losses.backward()
        if max_norm > 0:
            grad_total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        else:
            grad_total_norm = utils.get_total_grad_norm(model.parameters(), max_norm)
        optimizer.step()

        if wandb is not None:
            # Log everything in a single call: each wandb.log() call advances the
            # step counter by default, so separate calls would record one training
            # step as several steps with the metrics split across them.
            wandb.log({"total_loss": loss_value,
                       **loss_dict_reduced_scaled,
                       **loss_dict_reduced_unscaled})

        metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled)
        metric_logger.update(class_error=loss_dict_reduced['class_error'])
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        metric_logger.update(grad_norm=grad_total_norm)

        samples, targets = prefetcher.next()
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

## ORIGINAL FUNCTION
@torch.no_grad()
def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, output_dir, args):
    """Evaluate ``model`` on ``data_loader`` and dump the detections to JSON.

    Every batch updates an ``OWEvaluator`` (and a ``PanopticEvaluator`` when a
    ``'panoptic'`` postprocessor is present). Every predicted box is also
    collected as a COCO-style record ``{"image_id", "category_id", "bbox"
    (x, y, w, h), "score"}``. All records are written to
    ``json_output_dir/OW-DETR-50e/OW-DETR-swinT-coda_instances_results.json``.

    Args:
        model: detector; called as ``model(samples)``.
        criterion: only switched to eval mode here.
        postprocessors: dict of postprocessors; ``'bbox'`` is required,
            ``'segm'`` and ``'panoptic'`` are optional.
        data_loader: evaluation loader yielding ``(samples, targets)``.
        base_ds: ground-truth dataset handed to ``OWEvaluator``.
        device: device the batch is moved to.
        output_dir: base directory for panoptic evaluation output.
        args: run configuration forwarded to ``OWEvaluator``.

    Returns:
        ``(stats, coco_evaluator)``. ``stats`` holds the averaged metric-logger
        meters, the evaluator summary under ``'metrics'`` and, when available,
        the COCO / panoptic summary statistics.
    """
    model.eval()
    criterion.eval()
    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Test:'
    iou_types = tuple(k for k in ('segm', 'bbox') if k in postprocessors.keys())
    coco_evaluator = OWEvaluator(base_ds, iou_types, args=args)
    panoptic_evaluator = None
    if 'panoptic' in postprocessors.keys():
        panoptic_evaluator = PanopticEvaluator(
            data_loader.dataset.ann_file,
            data_loader.dataset.ann_folder,
            output_dir=os.path.join(output_dir, "panoptic_eval"),
        )
    final_output_list = []
    for samples, targets in metric_logger.log_every(data_loader, 10, header):
        samples = samples.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        outputs = model(samples)

        orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
        results = postprocessors['bbox'](outputs, orig_target_sizes)

        # Collect one COCO-style record per predicted box for the JSON dump.
        for target, result in zip(targets, results):
            # image_id is decoded from character codes; assumes the decoded
            # string is numeric (int() raises otherwise).
            img_id = int(ord_to_chr(target["image_id"]))
            boxes = box_xyxy_to_xywh(result['boxes']).cpu().numpy().tolist()
            scores = result['scores'].cpu().tolist()
            classes = result['labels'].cpu().tolist()
            final_output_list.extend(
                {
                    "image_id": img_id,
                    "category_id": label,  # the label of *this* box
                    "bbox": box,
                    "score": score,
                }
                for box, score, label in zip(boxes, scores, classes)
            )

        # Needed by both the segm and panoptic postprocessors.
        target_sizes = None
        if 'segm' in postprocessors.keys() or panoptic_evaluator is not None:
            target_sizes = torch.stack([t["size"] for t in targets], dim=0)
        if 'segm' in postprocessors.keys():
            results = postprocessors['segm'](results, outputs, orig_target_sizes, target_sizes)
        batch_res = {ord_to_chr(target['image_id']): output for target, output in zip(targets, results)}
        if coco_evaluator is not None:
            coco_evaluator.update(batch_res)

        if panoptic_evaluator is not None:
            res_pano = postprocessors["panoptic"](outputs, target_sizes, orig_target_sizes)
            for i, target in enumerate(targets):
                # NOTE(review): .item() requires a single-element image_id; the
                # ord-encoded ids used above would need ord_to_chr here — confirm.
                image_id = target["image_id"].item()
                file_name = f"{image_id:012d}.png"
                res_pano[i]["image_id"] = image_id
                res_pano[i]["file_name"] = file_name

            panoptic_evaluator.update(res_pano)

    # Create the output directory so the first run does not fail on open().
    # NOTE(review): in multi-process runs every process writes this same file
    # with only its local detections.
    json_dir = os.path.join('json_output_dir', 'OW-DETR-50e')
    os.makedirs(json_dir, exist_ok=True)
    with open(os.path.join(json_dir, '{}_instances_results.json'.format("OW-DETR-swinT-coda")), 'w') as fp:
        json.dump(final_output_list, fp, indent=4, separators=(',', ': '))
    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    if coco_evaluator is not None:
        coco_evaluator.synchronize_between_processes()
    if panoptic_evaluator is not None:
        panoptic_evaluator.synchronize_between_processes()
    # accumulate predictions from all images
    metrics = None
    if coco_evaluator is not None:
        coco_evaluator.accumulate()
        metrics = coco_evaluator.summarize()
    panoptic_res = None
    if panoptic_evaluator is not None:
        panoptic_res = panoptic_evaluator.summarize()
    stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
    stats['metrics'] = metrics
    if coco_evaluator is not None:
        if 'bbox' in postprocessors.keys():
            stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats.tolist()
        if 'segm' in postprocessors.keys():
            stats['coco_eval_masks'] = coco_evaluator.coco_eval['segm'].stats.tolist()
    if panoptic_res is not None:
        stats['PQ_all'] = panoptic_res["All"]
        stats['PQ_th'] = panoptic_res["Things"]
        stats['PQ_st'] = panoptic_res["Stuff"]
    return stats, coco_evaluator
 
    
@torch.no_grad()
def get_exemplar_replay(model, exemplar_selection, device, data_loader):
    """Collect exemplar-selection scores for every image in ``data_loader``.

    Each prefetched batch is run through ``model`` with gradients disabled. The
    samples, outputs and targets go to ``exemplar_selection``, and its result is
    passed to ``utils.combine_dict``. The first element of each item yielded is
    merged into one dict (later keys overwrite earlier ones).

    Returns:
        The merged dict.
    """
    logger = utils.MetricLogger(delimiter="  ")
    prefetcher = data_prefetcher(data_loader, device, prefetch=True)
    samples, targets = prefetcher.next()
    merged = {}
    for _ in logger.log_every(range(len(data_loader)), 10, '[ExempReplay]'):
        outputs = model(samples)
        per_batch = exemplar_selection(samples, outputs, targets)
        for item in utils.combine_dict(per_batch):
            merged.update(item[0])

        # The "loss" meter is reused to display how many images have been collected so far.
        logger.update(loss=len(merged))
        samples, targets = prefetcher.next()

    print(f'found a total of {len(merged)} images')
    return merged

def box_xyxy_to_xywh(x):
    """Convert boxes from corner form ``(x0, y0, x1, y1)`` to ``(x0, y0, w, h)``.

    The last dimension of ``x`` must have exactly 4 entries; leading dimensions
    are preserved.
    """
    left, top, right, bottom = x.unbind(-1)
    return torch.stack((left, top, right - left, bottom - top), dim=-1)

def ord_to_chr(image_id):
    """Decode an iterable of character codes (ints or 0-d tensors) into a string."""
    return ''.join(chr(int(code)) for code in image_id)

def nms(bboxes, scores, thresh=0.5):
    """Greedy, class-agnostic non-maximum suppression in plain PyTorch.

    Args:
        bboxes: tensor indexed as ``bboxes[:, 0..3]`` holding ``(x1, y1, x2, y2)``
            corners in continuous coordinates.
        scores: 1-D tensor with one score per box.
        thresh: a box is suppressed when its IoU with an already-kept,
            higher-scoring box is greater than this value.

    Returns:
        CPU ``torch.long`` tensor with the indices of the kept boxes, in
        descending score order (empty when there are no boxes).
    """
    x1 = bboxes[:, 0]
    y1 = bboxes[:, 1]
    x2 = bboxes[:, 2]
    y2 = bboxes[:, 3]
    # Areas and intersections both use continuous coordinates (no "+1" pixel
    # offset), so identical boxes get IoU exactly 1 and are suppressed.
    areas = (x2 - x1) * (y2 - y1)
    # Candidate indices, highest score first.
    _, order = scores.sort(0, descending=True)
    keep = []
    while order.numel() > 0:
        # The best remaining box is always kept.
        i = order[0].item()
        keep.append(i)
        rest = order[1:]
        if rest.numel() == 0:
            break
        # Intersection of box i with every remaining candidate.
        xx1 = x1[rest].clamp(min=x1[i])
        yy1 = y1[rest].clamp(min=y1[i])
        xx2 = x2[rest].clamp(max=x2[i])
        yy2 = y2[rest].clamp(max=y2[i])
        inter = (xx2 - xx1).clamp(min=0) * (yy2 - yy1).clamp(min=0)
        union = areas[i] + areas[rest] - inter
        iou = inter / union
        # Keep candidates whose IoU does not exceed the threshold. Written as
        # ~(iou > thresh) so a NaN IoU (0/0 for degenerate boxes) keeps the box.
        order = rest[~(iou > thresh)]
    return torch.tensor(keep, dtype=torch.long)

@torch.no_grad()
def viz(model, criterion, postprocessors, data_loader, base_ds, device, output_dir,
        known_viz_thre=0.3, unknown_viz_thre=0.6, unknown_label=6):
    """Save a prediction-vs-ground-truth figure for every batch in ``data_loader``.

    For each batch the first image is drawn (NOTE(review): presumably the
    loader uses batch size 1 — confirm). The processing steps are:

    * Boxes are rescaled to the padded input tensor size via ``postprocessors['bbox']``.
    * Predictions labelled ``unknown_label`` are kept when their score exceeds
      ``unknown_viz_thre``; all other labels need ``known_viz_thre``.
    * Class-agnostic NMS is applied.
    * A two-panel figure is written to ``output_dir/img_<image_id>.jpg``.

    ``base_ds`` is accepted for interface compatibility and not used.
    """
    os.makedirs(output_dir, exist_ok=True)
    model.eval()
    criterion.eval()

    for samples, targets in data_loader:
        samples = samples.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        outputs = model(samples)
        # Rescale to the padded input size rather than targets' "orig_size".
        # Assumes samples.tensors is laid out (..., H, W).
        h, w = samples.tensors.shape[-2:]
        target_sizes = torch.tensor([[h, w]], dtype=torch.float32,
                                    device=outputs['pred_logits'].device)

        results = postprocessors['bbox'](outputs, target_sizes)[0]
        scores = results['scores']
        boxes = results['boxes']
        labels = results['labels']

        # Known and unknown predictions use different confidence cut-offs.
        is_unknown = labels == unknown_label
        keep_known = ~is_unknown & (scores > known_viz_thre)
        keep_unknown = is_unknown & (scores > unknown_viz_thre)
        # Known predictions first, then unknown ones.
        scores = torch.cat([scores[keep_known], scores[keep_unknown]])
        boxes = torch.cat([boxes[keep_known], boxes[keep_unknown]])
        labels = torch.cat([labels[keep_known], labels[keep_unknown]])

        # Remove boxes overlapping each other (class-agnostic NMS).
        nms_pick = nms(boxes, scores)
        boxes = boxes[nms_pick, :]
        scores = scores[nms_pick]
        labels = labels[nms_pick]

        fig, ax = plt.subplots(1, 2, figsize=(10,3), dpi=200)

        # Prediction results
        plot_prediction(samples.tensors[0:1], scores, boxes, labels, ax[0], plot_prob=False)
        ax[0].set_title('Prediction (Ours)')
        # GT Results
        plot_prediction_GT(samples.tensors[0:1], targets[0]['boxes'], targets[0]['labels'], ax[1], plot_prob=False)
        ax[1].set_title('GT')

        for panel in ax:
            panel.set_aspect('equal')
            panel.set_axis_off()
        chr_image_id = ord_to_chr(targets[0]["image_id"])
        plt.savefig(os.path.join(output_dir, f'img_{chr_image_id}.jpg'))
        # Release the figure; otherwise every iteration leaks an open figure.
        plt.close(fig)