import os
import torch
import traceback
import numpy as np

from sklearn import metrics as skl_metrics
from xad.utils.logger import ROC
from xad.counterfactual.pytorch_fid.fid_score import calculate_fid_given_paths


def get_anomaly_scores(class_labels: torch.Tensor,
                       orig_anomaly_scores: torch.Tensor,
                       counterfactual_anomaly_scores: torch.Tensor = None):
    """Split anomaly scores by ground-truth class.

    Parameters
    ----------
    class_labels: torch.Tensor
        One label per sample; zero marks a normal sample, any non-zero value
        an anomalous one. May be 1-D or a row/column vector (it is flattened).
    orig_anomaly_scores: torch.Tensor
        Anomaly scores of the original samples, indexed along dim 0.
        Singleton dimensions are squeezed away (as before).
    counterfactual_anomaly_scores: torch.Tensor, optional
        Anomaly scores of the counterfactuals, indexed along dim 0.
        When ``None`` the third returned value is ``None``.

    Returns
    -------
    tuple
        ``(scores of true normals, scores of true anomalies,
        counterfactual scores of the normal samples or None)``.
    """
    # Flatten so that (N,), (N, 1) and (1, N) labels all yield a 1-D index;
    # nonzero(...).squeeze() on 2-D labels produced a (k, 2) index that
    # index_select rejects.
    labels = class_labels.reshape(-1)
    anomaly_class_indices = torch.nonzero(labels, as_tuple=True)[0]
    normal_class_indices = torch.nonzero(labels == 0, as_tuple=True)[0]

    orig_scores = orig_anomaly_scores.squeeze()
    if orig_scores.dim() == 0:
        # A single-sample input squeezes to a scalar; keep a 1-D tensor so
        # index_select along dim 0 stays well defined.
        orig_scores = orig_scores.reshape(1)

    ascores_true_normal = torch.index_select(orig_scores, 0, normal_class_indices)
    ascores_true_anomaly = torch.index_select(orig_scores, 0, anomaly_class_indices)

    # The counterfactual scores are optional (default None); previously this
    # crashed inside index_select.
    if counterfactual_anomaly_scores is None:
        ascores_counterfactuals_of_normal = None
    else:
        ascores_counterfactuals_of_normal = torch.index_select(
            counterfactual_anomaly_scores, 0, normal_class_indices
        )
    return ascores_true_normal, ascores_true_anomaly, ascores_counterfactuals_of_normal


def get_roc(ascores_normal: torch.Tensor, ascore_anomalous: torch.Tensor) -> ROC:
    """Build a ROC curve separating two groups of anomaly scores.

    Samples in ``ascores_normal`` are labelled 0 and samples in
    ``ascore_anomalous`` are labelled 1; higher scores are treated as more
    anomalous by ``sklearn.metrics.roc_curve``.

    Parameters
    ----------
    ascores_normal: torch.Tensor
        Scores of the negative (normal) group, concatenated along dim 0.
    ascore_anomalous: torch.Tensor
        Scores of the positive group, concatenated along dim 0.

    Returns
    -------
    ROC
        Constructed as ``ROC(tpr, fpr, thresholds, auc)``.
    """
    scores = torch.cat((ascores_normal, ascore_anomalous), 0)
    labels = torch.cat(
        (torch.zeros_like(ascores_normal), torch.ones_like(ascore_anomalous)), 0
    )
    # scikit-learn needs host NumPy arrays; implicit conversion fails for
    # CUDA tensors and for tensors that require grad.
    scores_np = scores.detach().cpu().numpy()
    labels_np = labels.detach().cpu().numpy()
    fpr, tpr, thresholds = skl_metrics.roc_curve(labels_np, scores_np)
    auc = skl_metrics.auc(fpr, tpr)
    return ROC(tpr, fpr, thresholds, auc)



def compute_fid_scores(path_actual_imgs: str | os.PathLike,
                       path_counterfactual_imgs: str | os.PathLike,
                       path_anamalous_imgs: str | os.PathLike,
                       path_test_subset_1: str | os.PathLike,
                       path_test_subset_2: str | os.PathLike,
                       device: torch.device,
                       xtrainer,
                       cstr:str,
                       seed):
    """Compute the desired fid scores using the paths of the stored images
    
    Parameters
    ----------
    path_actual_imgs: str
        Path to the folder where normal test images are stored
    path_counterfactual_imgs: str
        Path to the folder where counterfactuals are stored
    path_test_subset_1:
        Path to the folder where a random half of the test set is stored
    path_test_subset_2:
        Path to the folder where the other random half of the test set is stored
    
    Returns
    -------
    list
        a list of the fid score values
    """
    try:
        fid_score = calculate_fid_given_paths(
            path1=path_actual_imgs, path2=path_counterfactual_imgs, batch_size=50, device=device, dims=2048, num_workers=4
        )
        fid_score_upper_bound = calculate_fid_given_paths(
            path1=path_actual_imgs, path2=path_anamalous_imgs, batch_size=50, device=device, dims=2048, num_workers=4
        )
        fid_score_lower_bound = calculate_fid_given_paths(
            path1=path_test_subset_1, path2=path_test_subset_2, batch_size=50, device=device, dims=2048, num_workers=4
        )
        
    except ValueError as err:
        try:
            xtrainer.logger.warning(
                f"ValueError when computing FID score for {cstr} (seed {seed}) with 2048 dimensions: "
                f"\n{''.join(traceback.format_exception(err.__class__, err, err.__traceback__))}"
            )
            xtrainer.logger.print("Recomputing FID scores with reducing the feature dimension to 768")
            fid_score = calculate_fid_given_paths(
                path1=path_counterfactual_imgs, path2=path_actual_imgs, batch_size=50, device=device, dims=768, num_workers=4
            )
            fid_score_upper_bound = calculate_fid_given_paths(
                path1=path_anamalous_imgs, path2=path_actual_imgs, batch_size=50, device=device, dims=768, num_workers=4
            )
            fid_score_lower_bound = calculate_fid_given_paths(
                path1=path_test_subset_1, path2=path_test_subset_2, batch_size=50, device=device, dims=768, num_workers=4
            )
        except ValueError as verr:
            try:
                xtrainer.logger.warning(
                    f"ValueError when computing FID score for {cstr} (seed {seed}) with 768 dimensions: "
                    f"\n{''.join(traceback.format_exception(verr.__class__, verr, verr.__traceback__))}"
                )
                xtrainer.logger.print("Recomputing FID scores with reducing the feature dimension to 192")
                fid_score = calculate_fid_given_paths(
                    path1=path_counterfactual_imgs, path2=path_actual_imgs, batch_size=50, device=device, dims=192, num_workers=4
                )
                fid_score_upper_bound = calculate_fid_given_paths(
                    path1=path_anamalous_imgs, path2=path_actual_imgs, batch_size=50, device=device, dims=192, num_workers=4
                )
                fid_score_lower_bound = calculate_fid_given_paths(
                    path1=path_test_subset_1, path2=path_test_subset_2, batch_size=50, device=device, dims=192, num_workers=4
                )
            except ValueError as verr2:
                try:
                    xtrainer.logger.warning(
                        f"ValueError when computing FID score for {cstr} (seed {seed}) with 192 dimensions: "
                        f"\n{''.join(traceback.format_exception(verr2.__class__, verr2, verr2.__traceback__))}"
                    )
                    xtrainer.logger.print("Recomputing FID scores with reducing the feature dimension to 64")
                    fid_score = calculate_fid_given_paths(
                        path1=path_counterfactual_imgs, path2=path_actual_imgs, batch_size=50, device=device, dims=64,
                        num_workers=4
                    )
                    fid_score_upper_bound = calculate_fid_given_paths(
                        path1=path_anamalous_imgs, path2=path_actual_imgs, batch_size=50, device=device, dims=64, num_workers=4
                    )
                    fid_score_lower_bound = calculate_fid_given_paths(
                        path1=path_test_subset_1, path2=path_test_subset_2, batch_size=50, device=device, dims=64, num_workers=4
                    )
                except ValueError as verr3:
                    fid_score = np.nan
                    fid_score_upper_bound = np.nan
                    fid_score_lower_bound = np.nan
                    xtrainer.logger.warning(
                        f"ValueError when computing FID score for {cstr} (seed {seed}) with 64 dimensions: "
                        f"\n{''.join(traceback.format_exception(verr3.__class__, verr3, verr3.__traceback__))}"
                    )
    return [fid_score, fid_score_lower_bound, fid_score_upper_bound]

