import csv
import os
from typing import Dict, List

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

from utils import Config

from .base_evaluator import BaseEvaluator
from sklearn import metrics

import cv2 

class SP_Evaluator(BaseEvaluator):
    """Evaluate networks on a ``'Real'`` split plus any number of other splits.

    Every split other than ``'Real'`` is a mapping of subset name to data
    loader. Accuracy and loss are computed per subset, and each subset's AUC
    is computed on that subset's samples combined with the ``'Real'`` samples.
    """

    def __init__(self, config: Config):
        super(SP_Evaluator, self).__init__(config)
        # Initialised here; no method visible in this file reads or writes them.
        self.id_pred = None
        self.id_conf = None
        self.id_gt = None

    def auc_simple(self, logit, label):
        """Return the ROC AUC with label ``0`` treated as the positive class.

        ``sklearn.metrics.roc_auc_score`` has no ``pos_label`` argument, so
        both the labels and the scores are flipped (``1 - x``) before scoring.

        Args:
            logit: Array-like of per-sample scores for class 1.
            label: Array-like of per-sample labels; the ``1 - label`` flip
                assumes they are binary 0/1 -- TODO confirm against caller.

        Returns:
            The AUC value returned by ``roc_auc_score``.
        """
        # np.asarray lets plain sequences through; ndarrays are not copied.
        inverted_label = 1 - np.asarray(label)
        inverted_logit = 1 - np.asarray(logit)
        return metrics.roc_auc_score(inverted_label, inverted_logit)

    def eval_acc_SP(self, nets, data_loader_split,
                    postprocessor: BaseEvaluator = None,
                    epoch_idx: int = -1, batch_idx: int = -1):
        """Evaluate every split and aggregate loss, accuracy and AUC.

        Args:
            nets: Mapping of network name to a module, or to a dict holding
                the module under ``'backbone'``. All are put in eval mode.
            data_loader_split: Mapping of split name to loaders. It must
                contain ``'Real'`` (a single loader); every other entry is a
                mapping of subset name to loader.
            postprocessor: Object exposing ``inference(...)``. Required
                despite the ``None`` default.
            epoch_idx: Stored verbatim in the result.
            batch_idx: Stored verbatim in the result.

        Returns:
            A dict with ``'epoch_idx'``, ``'batch_idx'``, ``'Real'``
            (``acc``/``loss`` and ``auc`` set to ``None``), one entry per
            other split (per-subset metrics plus the split-level
            ``loss``/``acc``/``auc``) and the overall ``'loss'``, ``'acc'``
            and ``'auc'``. Any average whose denominator is zero is reported
            as ``0``.
        """
        for net_name in nets:
            net = nets[net_name]
            # Dict entries are switched to eval mode via their 'backbone' only.
            if isinstance(net, dict):
                net['backbone'].eval()
            else:
                net.eval()

        # Named `results` so the imported `sklearn.metrics` module stays visible.
        results = {'epoch_idx': epoch_idx, 'batch_idx': batch_idx}

        results['Real'], logit_real, gt_real = self._eval_acc_Real(
            nets, data_loader_split['Real'], postprocessor,
            detailed_return=True)
        # AUC is not computed for the 'Real' split on its own.
        results['Real']['auc'] = None

        loss_sum = 0.0
        acc_sum = 0.0
        loss_num = 0
        acc_num = 0
        auc_sum = 0
        auc_num = 0

        for split_name in data_loader_split:
            if split_name == 'Real':
                continue

            split_metrics = self._eval_acc_SP(nets,
                                              data_loader_split[split_name],
                                              logit_real, gt_real,
                                              postprocessor)
            results[split_name] = split_metrics

            # Snapshot the per-subset dicts before the split-level keys are
            # added to the same mapping below.
            subsets = list(split_metrics.values())
            loss_sum_set = sum((m['loss'] * m['num_loss'] for m in subsets), 0.0)
            acc_sum_set = sum((m['acc'] * m['num_acc'] for m in subsets), 0.0)
            loss_num_set = sum(m['num_loss'] for m in subsets)
            acc_num_set = sum(m['num_acc'] for m in subsets)
            auc_sum_set = sum(m['auc'] for m in subsets)
            auc_num_set = len(subsets)

            # Loss is weighted by 'num_loss', accuracy by 'num_acc'; AUC is a
            # plain mean over subsets.
            split_metrics['loss'] = loss_sum_set / loss_num_set if loss_num_set != 0 else 0
            split_metrics['acc'] = acc_sum_set / acc_num_set if acc_num_set != 0 else 0
            split_metrics['auc'] = auc_sum_set / auc_num_set if auc_num_set != 0 else 0

            loss_sum += loss_sum_set
            acc_sum += acc_sum_set
            loss_num += loss_num_set
            acc_num += acc_num_set

            auc_sum += split_metrics['auc']
            auc_num += 1

        # Same zero-denominator convention as the per-split averages, so a
        # run with only the 'Real' split no longer raises ZeroDivisionError.
        results['loss'] = loss_sum / loss_num if loss_num != 0 else 0
        other_acc = acc_sum / acc_num if acc_num != 0 else 0
        # Overall accuracy weights 'Real' and all other splits equally.
        results['acc'] = other_acc * 0.5 + results['Real']['acc'] * 0.5
        results['auc'] = auc_sum / auc_num if auc_num != 0 else 0

        return results

    def _eval_acc_SP(self, nets, data_loader_set, logit_real, gt_real,
                     postprocessor: BaseEvaluator = None):
        """Compute accuracy, loss and AUC for every subset of one split.

        Args:
            nets: Passed through to ``postprocessor.inference``.
            data_loader_set: Mapping of subset name to data loader.
            logit_real: Scores of the ``'Real'`` samples; column 1 is used.
            gt_real: Labels of the ``'Real'`` samples.
            postprocessor: Object exposing ``inference(...)``. Required
                despite the ``None`` default.

        Returns:
            A dict mapping subset name to a dict with ``acc``, ``num_acc``,
            ``loss``, ``num_loss`` and ``auc``. It is empty when
            ``data_loader_set`` is empty.
        """
        metrics_set = {}
        real_scores = logit_real[:, 1]

        for subset_name, data_loader_subset in data_loader_set.items():
            pred_subset, logit_subset, gt_subset, loss_subset = postprocessor.inference(
                nets, data_loader_subset, 'loss', cat=subset_name)

            num_correct = sum(pred_subset == gt_subset)
            # Loss is averaged over the loader length (batch count); assumes
            # `loss_subset` holds one value per batch -- TODO confirm.
            num_batches = len(data_loader_subset)

            subset_metrics = {}
            subset_metrics['acc'] = num_correct / len(gt_subset)
            subset_metrics['num_acc'] = len(gt_subset)
            subset_metrics['loss'] = sum(loss_subset) / num_batches
            subset_metrics['num_loss'] = num_batches
            print(f"Subset: {subset_name} tpr: {num_correct}")

            # AUC is computed on 'Real' samples plus this subset, using
            # column 1 of the scores.
            pred_p = np.concatenate((real_scores, logit_subset[:, 1]), axis=0)
            gt_cat = np.concatenate((gt_real, gt_subset), axis=0)
            subset_metrics['auc'] = self.auc_simple(pred_p, gt_cat)

            metrics_set[subset_name] = subset_metrics

        return metrics_set

    def _eval_acc_Real(self, nets, data_loader_set,
                       postprocessor: BaseEvaluator = None,
                       detailed_return: bool = False):
        """Compute accuracy and loss on the ``'Real'`` loader.

        Args:
            nets: Passed through to ``postprocessor.inference``.
            data_loader_set: The data loader for the ``'Real'`` split.
            postprocessor: Object exposing ``inference(...)``. Required
                despite the ``None`` default.
            detailed_return: When true, also return the scores and labels.

        Returns:
            ``{'acc': ..., 'loss': ...}``, or the tuple
            ``(metrics, logits, labels)`` when ``detailed_return`` is true.
        """
        pred_set, logit_set, gt_set, loss_set = postprocessor.inference(
            nets, data_loader_set, 'loss', cat='real')

        num_correct = sum(pred_set == gt_set)
        metrics_set = {}
        metrics_set['acc'] = num_correct / len(gt_set)
        # Loss is averaged over the loader length (batch count).
        metrics_set['loss'] = sum(loss_set) / len(data_loader_set)
        print(f"Subset: Real tpr: {num_correct}")

        if detailed_return:
            return metrics_set, logit_set, gt_set
        return metrics_set

import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
import numpy as np

def draw_roc_curve(y_true, y_scores, threshold_to_mark=0.5):
    """Plot a ROC curve and highlight two operating points.

    The first point is the ROC threshold closest to ``threshold_to_mark``;
    the second maximises Youden's J statistic (``TPR - FPR``). The figure is
    written to ``roc_curve_with_thresholds.png`` in the current working
    directory and then shown with ``plt.show()``. Both points are also
    printed to stdout.

    Args:
        y_true: Ground-truth binary labels, as accepted by
            ``sklearn.metrics.roc_curve``.
        y_scores: Scores for the positive class.
        threshold_to_mark: Threshold whose nearest ROC point is marked in red.

    Returns:
        None.
    """
    # Compute ROC curve and AUC
    fpr, tpr, thresholds = roc_curve(y_true, y_scores)
    roc_auc = auc(fpr, tpr)

    # ROC point whose threshold is nearest to the requested one.
    idx_mark = np.argmin(np.abs(thresholds - threshold_to_mark))
    fpr_mark = fpr[idx_mark]
    tpr_mark = tpr[idx_mark]

    # Operating point that maximises Youden's J = TPR - FPR.
    idx_best = np.argmax(tpr - fpr)
    fpr_best = fpr[idx_best]
    tpr_best = tpr[idx_best]
    threshold_best = thresholds[idx_best]

    # NOTE(review): the figure is never closed explicitly; repeated calls
    # under a non-interactive backend will accumulate open figures.
    plt.figure()
    plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.4f})')
    plt.plot([0, 1], [0, 1], color='navy', lw=1.5, linestyle='--')

    # Labels use the requested threshold rather than a hard-coded "0.5", so
    # they stay correct for non-default `threshold_to_mark` values.
    plt.plot(fpr_mark, tpr_mark, 'ro', label=f'Threshold = {threshold_to_mark}')
    plt.annotate(f'Th={threshold_to_mark}\nTPR={tpr_mark:.4f}\nFPR={fpr_mark:.4f}',
                 (fpr_mark, tpr_mark),
                 textcoords="offset points", xytext=(10, -25), ha='left',
                 fontsize=9, bbox=dict(boxstyle="round", fc="white", ec="gray"))

    plt.plot(fpr_best, tpr_best, 'go', label=f'Best Threshold = {threshold_best:.4f}')
    plt.annotate(f'Th={threshold_best:.4f}\nTPR={tpr_best:.4f}\nFPR={fpr_best:.4f}',
                 (fpr_best, tpr_best),
                 textcoords="offset points", xytext=(10, 10), ha='left',
                 fontsize=9, bbox=dict(boxstyle="round", fc="white", ec="gray"))

    # Labels
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curve with Thresholds')
    plt.legend(loc="lower right")
    plt.grid(True)
    plt.tight_layout()
    plt.savefig("roc_curve_with_thresholds.png")
    plt.show()

    print(f"Threshold = {threshold_to_mark}: TPR = {tpr_mark:.3f}, FPR = {fpr_mark:.3f}")
    print(f"Best threshold = {threshold_best:.3f}: TPR = {tpr_best:.3f}, FPR = {fpr_best:.3f}")
