import torch.distributed as dist
from abc import ABC, abstractmethod
import json
import os
from tqdm import tqdm

import torch
from torchvision.ops import box_convert

import contextlib
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval

from .utils.general_utils import format_output
from .utils.coco_utils import convert_results_to_coco

import matplotlib.pyplot as plt


class Evaluator(ABC):
    """Base class for detection-model evaluators.

    Provides prediction collection over a data loader (optionally gathered
    across a ``torch.distributed`` process group onto rank 0) and a helper that
    renders prediction / ground-truth visualisations. Subclasses implement
    :meth:`evaluate`.

    Args:
        model: Detection model, called as ``model(list_of_image_tensors)``; its
            outputs are post-processed by ``format_output``.
        device: Device (``torch.device`` or string) images are moved to.
        distributed: Whether an initialised ``torch.distributed`` process group
            is in use.
        rank: Rank of this process; rank 0 receives gathered results.
        world_size: Number of processes in the group.
    """

    def __init__(self, model, device, distributed=False, rank=0, world_size=1):
        self.model = model
        self.device = device
        self.distributed = distributed
        self.rank = rank
        self.world_size = world_size

    @staticmethod
    def _flatten_gathered(gathered):
        """Concatenate the per-rank result lists collected by ``gather_object``."""
        return [item for rank_results in gathered for item in rank_results]

    def __get_predictions__(self, loader, current_box_format, required_box_format, remove_bd=False):
        """Run inference over ``loader`` and return per-image formatted results.

        Args:
            loader: Iterable yielding ``(images, targets, _)`` batches, where
                ``images`` is a sequence of tensors.
            current_box_format: Box format produced by the model.
            required_box_format: Box format the results are converted to.
            remove_bd: Forwarded unchanged to ``format_output``.

        Returns:
            List of per-image results from ``format_output``. In distributed
            mode rank 0 receives the results of all ranks concatenated in rank
            order; other ranks receive only their local results.

        Raises:
            ValueError: If the model returns a different number of outputs
                than there are targets in a batch.
        """
        self.model.eval()
        # Defensive: only matters if the model's eval() does not clear the flag.
        if self.model.training:
            self.model.training = False

        results = []
        for img, target, _ in loader:
            img = [value.to(self.device) for value in img]

            with torch.inference_mode():
                outputs = self.model(img)

            if len(outputs) != len(target):
                raise ValueError("Outputs and targets have different batch sizes!")

            try:
                result = format_output(outputs, target, img, current_box_format, required_box_format, remove_bd=remove_bd)
            except Exception as e:
                # Dump the offending batch context for debugging, then re-raise as-is.
                print(f"Error formatting output: {e}")
                print(f"Outputs: {outputs}")
                print(f"Targets: {target}")
                print(f"Images: {[image.shape for image in img]}")
                print(f"Current box format: {current_box_format}, Required box format: {required_box_format}")
                print(f"Batch size: {len(outputs)}")
                print(f'Model mode: {self.model.training}')
                raise

            results.extend(result)
            # Drop batch references before releasing cached allocator memory.
            del outputs, img, target
            torch.cuda.empty_cache()

        if not self.distributed:
            return results

        dist.barrier()
        if self.rank == 0:
            gathered = [None] * self.world_size
            dist.gather_object(results, gathered, dst=0)
            # NOTE(review): with a padding DistributedSampler some samples may be
            # duplicated across ranks — confirm downstream COCO conversion copes.
            results = self._flatten_gathered(gathered)
        else:
            dist.gather_object(results, dst=0)
        dist.barrier()

        return results

    @abstractmethod
    def evaluate(self, loader, output_format, save_path):
        """
        Evaluate the model on the given data loader and save the results.

        Args:
            loader: DataLoader for the dataset to evaluate.
            output_format: Box format produced by the model (the concrete
                evaluators in this module call it ``current_box_format`` and
                pass it to ``format_output`` as the source format).
            save_path: Directory used for evaluation artefacts.

        Returns:
            Implementation-defined; the evaluators in this module return a
            metrics dict on rank 0 (or when not distributed) and ``None`` on
            other ranks.
        """
        pass

    @staticmethod
    def _plot_boxes(image_np, boxes, labels, cmap, scores=None):
        """Return a new figure showing ``image_np`` with xywh ``boxes`` drawn on it.

        Each box is outlined in ``cmap(label % 20)`` and captioned with its
        label, plus the score to two decimals when ``scores`` is given.
        """
        fig, ax = plt.subplots(1, figsize=(12, 8))
        ax.imshow(image_np)

        if scores is None:
            scores = [None] * len(labels)

        for box, label, score in zip(boxes, labels, scores):
            x_min, y_min, width, height = box

            color = cmap(label % 20)  # Use modulo to cycle through colors
            rect = plt.Rectangle((x_min, y_min), width, height, linewidth=2,
                                 edgecolor=color, facecolor='none')
            ax.add_patch(rect)
            caption = f'{label}' if score is None else f'{label} {score:.2f}'
            ax.text(x_min, y_min - 5, caption, color='white', fontsize=12,
                    bbox=dict(facecolor=color, alpha=0.5, edgecolor='none'))
        return fig

    @staticmethod
    def _save_and_close(fig, path):
        """Save ``fig`` to ``path`` and always close it, even if saving fails."""
        try:
            fig.savefig(path)
        finally:
            plt.close(fig)

    def get_prediction_sample(self, loader, num_preds, save_path, current_box_format, remove_bd=False):
        """Write prediction and ground-truth visualisations for a few images.

        For each of the first ``num_preds`` images yielded by ``loader`` two
        files are written to ``save_path``: ``pred_{k}.png`` (predicted boxes
        with label and score) and ``gt_{k}.png`` (ground-truth boxes with
        label), where ``k`` counts from 0.

        Args:
            loader: Iterable yielding ``(images, targets, _)`` batches.
            num_preds: Number of images to visualise; nothing is written if <= 0.
            save_path: Output directory; created if it does not exist.
            current_box_format: Box format produced by the model (converted to
                ``'xywh'`` for plotting).
            remove_bd: Forwarded unchanged to ``format_output``.

        Raises:
            ValueError: If the model returns a different number of outputs
                than there are targets in a batch.
        """
        self.model.eval()
        if num_preds <= 0:
            return

        os.makedirs(save_path, exist_ok=True)
        cmap = plt.get_cmap("tab20")

        num_preds_made = 0
        for img, target, _ in loader:
            img = [value.to(self.device) for value in img]

            with torch.no_grad():
                outputs = self.model(img)

            if len(outputs) != len(target):
                raise ValueError("Outputs and targets have different batch sizes!")

            result = format_output(outputs, target, img, current_box_format, 'xywh', remove_bd=remove_bd)

            for image, sample in zip(img, result):
                # permute moves axis 0 last — presumably CHW -> HWC for imshow.
                image_np = image.cpu().permute(1, 2, 0).numpy()

                pred_fig = self._plot_boxes(image_np, sample["pred_boxes"], sample["pred_labels"],
                                            cmap, scores=sample["pred_scores"])
                print(f"Saving prediction {num_preds_made + 1}/{num_preds} to {save_path}")
                self._save_and_close(pred_fig, os.path.join(save_path, f"pred_{num_preds_made}.png"))

                gt_fig = self._plot_boxes(image_np, sample["gt_boxes"], sample["gt_labels"], cmap)
                self._save_and_close(gt_fig, os.path.join(save_path, f"gt_{num_preds_made}.png"))

                num_preds_made += 1
                # Stop immediately rather than pulling another batch from the loader.
                if num_preds_made >= num_preds:
                    return

class MTSDEvaluator(Evaluator):
    """Evaluator that scores detections with pycocotools' COCO bbox metrics."""

    # Metric names in ``COCOeval.stats`` order.
    # NOTE: 'ar50:95', 'ar50' and 'ar75' are historical names kept for
    # backward compatibility; per pycocotools they hold stats[6], stats[7],
    # stats[8] = AR@[.5:.95] with maxDets 1, 10 and 100 respectively.
    _METRIC_KEYS = (
        'ap50:95',
        'ap50',
        'ap75',
        'ap50:95_small',
        'ap50:95_medium',
        'ap50:95_large',
        'ar50:95',
        'ar50',
        'ar75',
        'ar50:95_small',
        'ar50:95_medium',
        'ar50:95_large',
    )

    def __get_results__(self, results, save_path):
        """Compute COCO bbox metrics for ``results``.

        Args:
            results: Per-image results (as returned by ``__get_predictions__``),
                converted with ``convert_results_to_coco``.
            save_path: Existing directory for the temporary
                ``ground_truth.json`` / ``predictions.json`` files; both are
                removed before returning, even if evaluation raises.

        Returns:
            Dict mapping each key of ``_METRIC_KEYS`` to the matching
            ``COCOeval.stats`` entry (pycocotools reports -1 where a metric is
            undefined, e.g. no ground truth in an area range). All values are
            0.0 when there are no predictions.
        """
        ground_truth_coco, predictions_coco = convert_results_to_coco(results)

        # COCO.loadRes cannot load an empty result list, so short-circuit.
        if len(predictions_coco) == 0:
            return {key: 0.0 for key in self._METRIC_KEYS}

        gt_path = os.path.join(save_path, "ground_truth.json")
        pred_path = os.path.join(save_path, "predictions.json")

        try:
            with open(gt_path, "w") as f:
                json.dump(ground_truth_coco, f)

            with open(pred_path, "w") as f:
                json.dump(predictions_coco, f)

            # Suppress pycocotools' prints from loading/evaluate/accumulate/summarize.
            with open(os.devnull, 'w') as devnull, contextlib.redirect_stdout(devnull):
                coco_gt = COCO(gt_path)
                # NOTE(review): presumably set because some pycocotools versions
                # expect an 'info' entry when building results — confirm.
                coco_gt.dataset['info'] = {
                    'description': 'COCO Evaluation',
                }

                coco_dt = coco_gt.loadRes(pred_path)
                coco_eval = COCOeval(coco_gt, coco_dt, "bbox")
                # NOTE: to exclude the "other-sign" category (ID 1) from scoring, set
                # coco_eval.params.catIds = [cid for cid in coco_gt.getCatIds() if cid != 1]
                coco_eval.evaluate()
                coco_eval.accumulate()
                coco_eval.summarize()
        finally:
            # Always remove the temporary files, even when evaluation fails.
            for path in (gt_path, pred_path):
                with contextlib.suppress(FileNotFoundError):
                    os.remove(path)

        return dict(zip(self._METRIC_KEYS, coco_eval.stats))

    def evaluate(self, loader, current_box_format, save_path, remove_bd=False):
        """Predict over ``loader`` and compute COCO bbox metrics.

        Args:
            loader: Iterable yielding ``(images, targets, _)`` batches.
            current_box_format: Box format produced by the model; results are
                converted to ``'xywh'`` before scoring.
            save_path: Directory for temporary files; created if missing.
            remove_bd: Forwarded unchanged to ``format_output``.

        Returns:
            Metrics dict (see ``__get_results__``) when not distributed or on
            rank 0; ``None`` on other ranks.
        """
        self.model.eval()
        predictions = self.__get_predictions__(loader, current_box_format, "xywh", remove_bd=remove_bd)

        metrics = None
        if not self.distributed or self.rank == 0:
            # exist_ok avoids a check-then-create race.
            os.makedirs(save_path, exist_ok=True)
            # NOTE(review): if this raises on rank 0, other ranks block at the barrier below.
            metrics = self.__get_results__(predictions, save_path)

        if self.distributed:
            dist.barrier()
            torch.cuda.empty_cache()
        # torch.device() also accepts a plain string such as "cuda:0".
        elif torch.device(self.device).type == "cuda":
            torch.cuda.empty_cache()

        return metrics

class COCOEvaluator(Evaluator):
    """Evaluator that scores detections with pycocotools' COCO bbox metrics."""

    # Metric names in ``COCOeval.stats`` order.
    # NOTE: 'ar50:95', 'ar50' and 'ar75' are historical names kept for
    # backward compatibility; per pycocotools they hold stats[6], stats[7],
    # stats[8] = AR@[.5:.95] with maxDets 1, 10 and 100 respectively.
    _METRIC_KEYS = (
        'ap50:95',
        'ap50',
        'ap75',
        'ap50:95_small',
        'ap50:95_medium',
        'ap50:95_large',
        'ar50:95',
        'ar50',
        'ar75',
        'ar50:95_small',
        'ar50:95_medium',
        'ar50:95_large',
    )

    def __get_results__(self, results, save_path):
        """Compute COCO bbox metrics for ``results``.

        Args:
            results: Per-image results (as returned by ``__get_predictions__``),
                converted with ``convert_results_to_coco``.
            save_path: Existing directory for the temporary
                ``ground_truth.json`` / ``predictions.json`` files; both are
                removed before returning, even if evaluation raises.

        Returns:
            Dict mapping each key of ``_METRIC_KEYS`` to the matching
            ``COCOeval.stats`` entry (pycocotools reports -1 where a metric is
            undefined, e.g. no ground truth in an area range). All values are
            0.0 when there are no predictions.
        """
        ground_truth_coco, predictions_coco = convert_results_to_coco(results)

        # COCO.loadRes cannot load an empty result list, so short-circuit.
        if len(predictions_coco) == 0:
            return {key: 0.0 for key in self._METRIC_KEYS}

        gt_path = os.path.join(save_path, "ground_truth.json")
        pred_path = os.path.join(save_path, "predictions.json")

        try:
            with open(gt_path, "w") as f:
                json.dump(ground_truth_coco, f)

            with open(pred_path, "w") as f:
                json.dump(predictions_coco, f)

            # Suppress pycocotools' prints from loading/evaluate/accumulate/summarize.
            with open(os.devnull, 'w') as devnull, contextlib.redirect_stdout(devnull):
                coco_gt = COCO(gt_path)
                # NOTE(review): presumably set because some pycocotools versions
                # expect an 'info' entry when building results — confirm.
                coco_gt.dataset['info'] = {
                    'description': 'COCO Evaluation',
                }

                coco_dt = coco_gt.loadRes(pred_path)
                coco_eval = COCOeval(coco_gt, coco_dt, "bbox")
                coco_eval.evaluate()
                coco_eval.accumulate()
                coco_eval.summarize()
        finally:
            # Always remove the temporary files, even when evaluation fails.
            for path in (gt_path, pred_path):
                with contextlib.suppress(FileNotFoundError):
                    os.remove(path)

        return dict(zip(self._METRIC_KEYS, coco_eval.stats))

    def evaluate(self, loader, current_box_format, save_path, remove_bd=False):
        """Predict over ``loader`` and compute COCO bbox metrics.

        Args:
            loader: Iterable yielding ``(images, targets, _)`` batches.
            current_box_format: Box format produced by the model; results are
                converted to ``'xywh'`` before scoring.
            save_path: Directory for temporary files; created if missing.
            remove_bd: Forwarded unchanged to ``format_output``.

        Returns:
            Metrics dict (see ``__get_results__``) when not distributed or on
            rank 0; ``None`` on other ranks.
        """
        self.model.eval()
        predictions = self.__get_predictions__(loader, current_box_format, "xywh", remove_bd=remove_bd)

        metrics = None
        if not self.distributed or self.rank == 0:
            # exist_ok avoids a check-then-create race.
            os.makedirs(save_path, exist_ok=True)
            # NOTE(review): if this raises on rank 0, other ranks block at the barrier below.
            metrics = self.__get_results__(predictions, save_path)

        if self.distributed:
            dist.barrier()
            torch.cuda.empty_cache()
        # torch.device() also accepts a plain string such as "cuda:0".
        elif torch.device(self.device).type == "cuda":
            torch.cuda.empty_cache()

        return metrics