import torch

from ..evaluator import Evaluator
from ..utils.bd_utils import bd_max_iou, plot_iou_scores, get_ra_asr_matrix_oda, plot_asr_ra_heat
import torchvision
import os

import pandas as pd

class ODAEvaluator(Evaluator):
    """Evaluator that reports RA / ASR metrics over a range of IoU thresholds.

    Predictions are collected with the model's post-processing score threshold
    temporarily lowered to ``_EVAL_SCORE_THRESH``; the model's original
    threshold is handed to ``bd_max_iou`` when the metrics are computed and is
    always restored on the model before ``evaluate`` returns.

    NOTE(review): ``self.device``, ``self.distributed``, ``self.rank`` and
    ``self.model`` are read below but never assigned in this class; they are
    presumably stored by ``Evaluator.__init__`` -- confirm against the base class.
    """

    # Score threshold applied to the model's post-processing while predictions
    # are being collected.
    _EVAL_SCORE_THRESH = 0.05

    def __init__(self, model, model_name, device, is_test=False, is_multi=False, distributed=False, rank=0, world_size=1):
        """Store the evaluation options and forward the rest to ``Evaluator``.

        Args:
            model: Detection model to evaluate.
            model_name: One of ``'faster_rcnn'``, ``'fcos'``, ``'dino'`` or
                ``'yolo'``; selects where the score threshold lives on the model.
            device: Device used for the metric tensors.
            is_test: Stored on the instance; only referenced by the currently
                disabled CSV export in ``evaluate``.
            is_multi: Forwarded to ``bd_max_iou`` as ``multi_trigger``.
            distributed: Whether evaluation runs under ``torch.distributed``.
            rank: Rank of this process; metrics are computed on rank 0 only.
            world_size: Forwarded to the base class.
        """
        super().__init__(model, device, distributed, rank, world_size)
        self.is_multi = is_multi
        self.is_test = is_test
        self.model_name = model_name

    def _score_thresh_slot(self):
        """Return ``(owner, attribute_name)`` of the model's score threshold.

        Raises:
            NotImplementedError: If ``self.model_name`` is not supported.
        """
        if self.model_name == 'faster_rcnn':
            return self.model.roi_heads, 'score_thresh'
        if self.model_name == 'fcos':
            return self.model, 'score_thresh'
        if self.model_name == 'dino':
            return self.model.postprocessors['bbox'], 'score_threshold'
        if self.model_name == 'yolo':
            return self.model, 'conf_thres'
        raise NotImplementedError(f'Model {self.model_name} is not supported for ODA evaluation.')

    def __get_results__(self, model_score_threshold, results):
        """Compute RA / ASR at IoU thresholds 0.50, 0.55, ..., 0.95.

        Args:
            model_score_threshold: Score threshold forwarded to ``bd_max_iou``.
            results: Predictions forwarded to ``bd_max_iou``.

        Returns:
            dict with the keys ``avg_asr``, ``avg_ra`` (means over all IoU
            thresholds) and ``asr_50``, ``asr_75``, ``ra_50``, ``ra_75``
            (values at IoU 0.50 and 0.75).
        """
        results_tensor = bd_max_iou(results, model_score_threshold, self.device, multi_trigger=self.is_multi)
        iou_range = torch.arange(0.5, 1.0, 0.05, device=self.device)

        if len(results_tensor) == 0:
            # No rows to score: RA is reported as 0 and ASR as 1 at every threshold.
            ra_list = [0.0] * len(iou_range)
            asr_list = [1.0] * len(iou_range)
        else:
            # Column 0 is treated as the best IoU with the original class.
            # NOTE(review): an earlier comment here described the columns as
            # "0 - area, 1 - best_org", which disagrees with this indexing --
            # confirm the column layout against bd_max_iou.
            best_org = results_tensor[:, 0]

            ra_list, asr_list = [], []
            for iou_thr in iou_range:
                # RA: fraction of rows whose column-0 IoU reaches the threshold.
                # ASR: the complementary fraction.
                ra = (best_org >= iou_thr).float().mean()
                asr = 1.0 - ra

                ra_list.append(ra.item())
                asr_list.append(asr.item())

        avg_ra = sum(ra_list) / len(ra_list)
        avg_asr = sum(asr_list) / len(asr_list)

        # Pick the entry closest to 0.75 instead of testing floats for exact
        # equality, which breaks if arange rounds that element differently.
        index_75 = int(torch.argmin(torch.abs(iou_range - 0.75)).item())
        ra_50, ra_75 = ra_list[0], ra_list[index_75]
        asr_50, asr_75 = asr_list[0], asr_list[index_75]

        metrics = {
            "avg_asr": avg_asr,
            "avg_ra": avg_ra,
            "asr_50": asr_50,
            "asr_75": asr_75,
            "ra_50": ra_50,
            "ra_75": ra_75
        }

        return metrics

    def evaluate(self, loader, current_box_format, save_path, current_epoch):
        """Run the model on ``loader`` and return the RA / ASR metrics.

        Args:
            loader: Data loader forwarded to ``__get_predictions__``.
            current_box_format: Box format of the data, forwarded to
                ``__get_predictions__`` (boxes are requested as ``"xyxy"``).
            save_path: Output directory; created if missing.
            current_epoch: Only referenced by the currently disabled plotting
                code.

        Returns:
            The metrics dict from ``__get_results__`` when not distributed or
            on rank 0, otherwise ``None``.

        Raises:
            NotImplementedError: If ``self.model_name`` is not supported.
        """
        is_main_process = (not self.distributed) or self.rank == 0

        # Make save_path if it does not exist
        if is_main_process:
            os.makedirs(save_path, exist_ok=True)

        if self.distributed:
            torch.distributed.barrier()

        self.model.eval()

        # Lower the score threshold used to post process predictions
        thresh_owner, thresh_attr = self._score_thresh_slot()
        saved_score_thresh = getattr(thresh_owner, thresh_attr)
        setattr(thresh_owner, thresh_attr, self._EVAL_SCORE_THRESH)

        try:
            predictions = self.__get_predictions__(loader, current_box_format, "xyxy", remove_bd=False)

            if is_main_process:
                metrics = self.__get_results__(saved_score_thresh, predictions)

                # plot_save_path = os.path.join(save_path, "plots")

                # if not os.path.exists(plot_save_path):
                #     os.makedirs(plot_save_path)

                # plot_iou_scores(predictions, self.device, plot_save_path, current_epoch)

                # ra_matrix, asr_matrix, conf_th, iou_th = get_ra_asr_matrix_oda(predictions, conf_th=torch.arange(0.00, 1.01, 0.05), iou_th=torch.arange(0.00, 1.01, 0.05))
                # plot_asr_ra_heat(ra_matrix, asr_matrix, conf_th, iou_th, plot_save_path, current_epoch=current_epoch, title="ASR vs RA at epoch {}".format(current_epoch))

                # # If is_test is True, save the ra_matrix and asr_matrix as a csv file
                # if self.is_test:

                #     ra_matrix_df = pd.DataFrame(ra_matrix, index=conf_th, columns=iou_th)
                #     asr_matrix_df = pd.DataFrame(asr_matrix, index=conf_th, columns=iou_th)

                #     ra_matrix_df.to_csv(os.path.join(save_path, "ra_matrix_epoch_{}.csv".format(current_epoch)))
                #     asr_matrix_df.to_csv(os.path.join(save_path, "asr_matrix_epoch_{}.csv".format(current_epoch)))
            else:
                metrics = None

            if self.distributed:
                torch.distributed.barrier()
                torch.cuda.empty_cache()
            elif self.device.type == "cuda":
                torch.cuda.empty_cache()
        finally:
            # Restore the original score threshold even if evaluation failed,
            # so the model is never left with the lowered threshold.
            setattr(thresh_owner, thresh_attr, saved_score_thresh)

        return metrics