'''
# author: Zhiyuan Yan
# email: zhiyuanyan@link.cuhk.edu.cn
# date: 2023-0706
# description: Class for the EfficientDetector

Functions in the Class are summarized as:
1. __init__: Initialization
2. build_backbone: Backbone-building
3. build_loss: Loss-function-building
4. features: Feature-extraction
5. classifier: Classification
6. get_losses: Loss-computation
7. get_train_metrics: Training-metrics-computation
8. get_test_metrics: Testing-metrics-computation
9. forward: Forward-propagation

Reference:
@inproceedings{tan2019efficientnet,
  title={Efficientnet: Rethinking model scaling for convolutional neural networks},
  author={Tan, Mingxing and Le, Quoc},
  booktitle={International conference on machine learning},
  pages={6105--6114},
  year={2019},
  organization={PMLR}
}
'''

import os
import datetime
import logging
import io
import numpy as np
from sklearn import metrics
from typing import Union
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from base_metrics_class import calculate_metrics_for_train

from base_detector import AbstractDetector
from efficientnetb4 import EfficientNetB4
import random
from datasets import load_dataset, DownloadConfig

logger = logging.getLogger(__name__)

class EfficientDetector(AbstractDetector):
    """Two-class detector that delegates to an EfficientNetB4 backbone.

    Besides the forward pass, the instance keeps running buffers
    (``prob``, ``label``, ``correct``, ``total``) that are filled by
    ``forward(..., inference=True)`` and consumed/reset by
    ``get_test_metrics``.
    """

    def __init__(self, config):
        """Build the backbone and loss and initialise the test accumulators.

        Args:
            config: Mapping that must contain ``'backbone_config'``.
        """
        super().__init__()
        self.config = config
        self.backbone = self.build_backbone(config)
        # get_losses() reads self.loss_func, so it has to exist before training.
        self.loss_func = self.build_loss(config)
        # Per-batch probability / label arrays collected during inference.
        self.prob, self.label = [], []
        # Running counts used for accuracy in get_test_metrics().
        self.correct, self.total = 0, 0

    def build_backbone(self, config):
        """Instantiate ``EfficientNetB4`` from ``config['backbone_config']``."""
        model_config = config['backbone_config']
        backbone = EfficientNetB4(model_config)
        # FIXME: pretrained weights are currently loaded only inside the
        # backbone, not here.
        return backbone

    def build_loss(self, config):
        """Return the criterion used by ``get_losses``.

        This file imports no loss registry, so the standard cross-entropy
        criterion from ``torch.nn`` is used. ``config`` is accepted for
        symmetry with ``build_backbone`` and is not read.
        """
        return nn.CrossEntropyLoss()

    def features(self, data_dict: dict) -> torch.Tensor:
        """Run ``backbone.features`` on ``data_dict['image']``."""
        return self.backbone.features(data_dict['image'])

    def classifier(self, features: torch.Tensor) -> torch.Tensor:
        """Run ``backbone.classifier`` on the extracted features."""
        return self.backbone.classifier(features)

    def get_losses(self, data_dict: dict, pred_dict: dict) -> dict:
        """Compute the loss between ``pred_dict['cls']`` and ``data_dict['label']``.

        Returns:
            ``{'overall': loss}``.
        """
        label = data_dict['label']
        pred = pred_dict['cls']
        loss = self.loss_func(pred, label)
        return {'overall': loss}

    def get_train_metrics(self, data_dict: dict, pred_dict: dict) -> dict:
        """Compute per-batch metrics via ``calculate_metrics_for_train``.

        Returns:
            Dict with keys ``'acc'``, ``'auc'``, ``'eer'`` and ``'ap'``.
        """
        label = data_dict['label']
        pred = pred_dict['cls']
        auc, eer, acc, ap = calculate_metrics_for_train(label.detach(), pred.detach())
        return {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap}

    def get_test_metrics(self):
        """Aggregate everything collected by ``forward(..., inference=True)``.

        The accumulators are cleared afterwards, so each call reports on the
        batches seen since the previous call.

        Returns:
            Dict with ``'acc'``, ``'auc'``, ``'eer'``, ``'ap'`` plus the raw
            ``'pred'`` scores and ``'label'`` array.

        Raises:
            ValueError: If no inference batch has been accumulated.
        """
        if not self.prob:
            raise ValueError(
                "get_test_metrics() called before any forward(..., inference=True) batch was accumulated"
            )
        y_pred = np.concatenate(self.prob)
        y_true = np.concatenate(self.label)
        # auc
        fpr, tpr, _ = metrics.roc_curve(y_true, y_pred, pos_label=1)
        auc = metrics.auc(fpr, tpr)
        # eer: false-positive rate at the point where FNR and FPR are closest
        fnr = 1 - tpr
        eer = fpr[np.nanargmin(np.absolute(fnr - fpr))]
        # ap
        ap = metrics.average_precision_score(y_true, y_pred)
        # acc
        acc = self.correct / self.total
        # reset the accumulators for the next evaluation round
        self.prob, self.label = [], []
        self.correct, self.total = 0, 0
        return {'acc': acc, 'auc': auc, 'eer': eer, 'ap': ap, 'pred': y_pred, 'label': y_true}

    def forward(self, data_dict: dict, inference=False) -> dict:
        """Run backbone features + classifier and optionally record test stats.

        Args:
            data_dict: Must contain ``'image'``; ``'label'`` is read only when
                ``inference`` is true.
            inference: When true, append this batch's probabilities/labels to
                the accumulators and update the correct/total counters.

        Returns:
            Dict with ``'cls'`` (classifier output), ``'prob'`` (softmax
            probability of class index 1) and ``'feat'`` (backbone features).
        """
        features = self.features(data_dict)
        pred = self.classifier(features)
        prob = torch.softmax(pred, dim=1)[:, 1]
        pred_dict = {'cls': pred, 'prob': prob, 'feat': features}
        if inference:
            # Flatten rather than squeeze: squeeze() collapses a batch of one
            # into a 0-d array, which np.concatenate refuses to join.
            label = data_dict['label'].detach().reshape(-1)
            self.prob.append(prob.detach().reshape(-1).cpu().numpy())
            self.label.append(label.cpu().numpy())
            # Accuracy bookkeeping; comparing against the flattened label keeps
            # the comparison element-wise even if labels arrive as (B, 1).
            _, prediction_class = torch.max(pred, 1)
            self.correct += (prediction_class == label).sum().item()
            self.total += label.size(0)
        return pred_dict


if __name__ == "__main__":
    # Example usage: build the detector, optionally restore a checkpoint,
    # and prepare the Hugging Face test split plus preprocessing.
    config = {
        'backbone_name': 'efficientnetb4',
        'backbone_config': {
            'num_classes': 2,
            'inc': 3,
            'dropout': False,
            'mode': 'original'  # or 'adjust_channel'
        }
    }
    model = EfficientDetector(config)

    # Load model weights if available.
    model_path = 'efficientnetb4.pth'
    if os.path.exists(model_path):
        # Deserialize onto the CPU so a checkpoint saved from a CUDA run can
        # still be opened on a CPU-only host; the model is moved to the
        # selected device further down.
        state_dict = torch.load(model_path, map_location="cpu")
        model.load_state_dict(state_dict)
        print(f"Model weights loaded from {model_path}")
    else:
        # Make it visible that the evaluation runs without this checkpoint.
        print(f"Warning: {model_path} not found; no detector checkpoint was loaded.")

    import os
    import numpy as np
    import torch
    from torch.utils.data import DataLoader, Dataset
    from torchvision import transforms
    from PIL import Image
    from sklearn.metrics import (
        roc_auc_score,
        f1_score,
        accuracy_score,
        precision_recall_curve,
        average_precision_score,
    )

    # Route the Hugging Face download cache to $SCRATCH/.cache when SCRATCH
    # is set; otherwise fall back to the library defaults.
    scratch_dir = os.environ.get("SCRATCH")
    cache_dir = os.path.join(scratch_dir, ".cache") if scratch_dir else None
    download_config = DownloadConfig(cache_dir=cache_dir) if cache_dir else None
    load_kwargs = {"download_config": download_config} if download_config else {}
    print("Loading Anonymous460/OpenFake test split from Hugging Face...")
    hf_dataset = load_dataset("Anonymous460/OpenFake", split="test", **load_kwargs)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.to(device)
    # Resize to 256x256, convert to a tensor and normalise each channel with
    # mean 0.5 / std 0.5.
    test_transforms = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    # Dataset backed by Hugging Face records
    class HuggingFaceOpenFakeDataset(Dataset):
        """Map-style dataset yielding ``(tensor, label, model_name)`` triples.

        Each record is read with ``.get`` for the keys ``"image"``,
        ``"label"`` and ``"model"``. Records whose image cannot be decoded
        yield ``None`` instead of a triple.
        """

        def __init__(self, dataset, transforms):
            self.dataset = dataset
            self.transforms = transforms

        def __len__(self):
            return len(self.dataset)

        def _to_pil(self, image_entry):
            """Convert a supported image representation to an RGB PIL image.

            Accepts a PIL image, a NumPy array, raw bytes/bytearray, or a dict
            holding either ``"bytes"`` or ``"path"``.

            Raises:
                ValueError: For any other input (including a dict with
                    neither usable key).
            """
            if isinstance(image_entry, Image.Image):
                pil_image = image_entry
            elif isinstance(image_entry, np.ndarray):
                pil_image = Image.fromarray(image_entry)
            elif isinstance(image_entry, (bytes, bytearray)):
                pil_image = Image.open(io.BytesIO(image_entry))
            elif isinstance(image_entry, dict) and image_entry.get("bytes") is not None:
                pil_image = Image.open(io.BytesIO(image_entry["bytes"]))
            elif isinstance(image_entry, dict) and image_entry.get("path"):
                pil_image = Image.open(image_entry["path"])
            else:
                raise ValueError(f"Unsupported image entry type: {type(image_entry)}")
            return pil_image.convert("RGB")

        def __getitem__(self, idx):
            """Return ``(transformed_image, label, model_name)`` or ``None``.

            String labels map ``"real"`` (case-insensitive) to 0 and anything
            else to 1; other labels go through ``int()``. A missing label
            defaults to 0. Samples with label 0 are always named ``"real"``;
            otherwise a falsy ``"model"`` value becomes ``"unknown"``.
            """
            record = self.dataset[idx]
            try:
                pil_image = self._to_pil(record.get("image"))
            except (OSError, ValueError) as exc:
                print(f"Warning: skipping corrupted image entry at index {idx}: {exc}")
                return None

            tensor = self.transforms(pil_image)

            raw_label = record.get("label", 0)
            if isinstance(raw_label, str):
                label = int(raw_label.lower() != "real")
            else:
                label = int(raw_label)

            if label == 0:
                model_name = "real"
            else:
                model_name = record.get("model") or "unknown"

            return tensor, label, model_name

    def collate_fn(batch):
        """Collate samples into ``(stacked_images, labels, model_names)``.

        ``None`` entries (samples that failed to load) are dropped. If no
        sample survives, a shape-``(0,)`` tensor and two empty lists are
        returned.
        """
        images, labels, model_names = [], [], []
        for sample in batch:
            if sample is None:
                continue
            images.append(sample[0])
            labels.append(sample[1])
            model_names.append(sample[2])
        if not images:
            return torch.empty((0,)), [], []
        return torch.stack(images, dim=0), labels, model_names

    # Prepare DataLoader
    dataset = HuggingFaceOpenFakeDataset(hf_dataset, test_transforms)
    loader = DataLoader(dataset, batch_size=128, num_workers=4, collate_fn=collate_fn)

    from tqdm import tqdm
    # Global accumulators, extended batch by batch and kept index-aligned.
    all_labels = []
    all_preds = []
    all_scores = []
    all_model_names = []

    # Put the model in evaluation mode so that any train-only layers
    # (batch normalisation, dropout) behave deterministically during scoring.
    model.eval()

    # Run inference and accumulate predictions
    with torch.no_grad():
        for pixel_values, labels, model_names in tqdm(loader):
            if not labels:
                # Every sample in this batch failed to load; collate_fn then
                # hands back a shape-(0,) placeholder that must not be fed
                # to the model.
                continue
            pixel_values = pixel_values.to(device)
            data_dict = {'image': pixel_values, 'label': torch.tensor(labels).to(device)}
            outputs = model(data_dict)
            # 'prob' is the softmax probability of class index 1.
            probs = outputs['prob'].cpu().numpy()
            preds = torch.argmax(outputs['cls'], dim=1).cpu().numpy()
            all_labels.extend(labels)
            all_preds.extend(preds.tolist())
            all_scores.extend(probs.tolist())
            all_model_names.extend(model_names)

    # === Find probability threshold that maximizes F1 ===
    if not all_labels:
        # Nothing was scored (e.g. every image failed to load); the metric
        # functions below cannot work on empty inputs.
        raise SystemExit("No valid samples were evaluated; cannot compute metrics.")

    precisions, recalls, pr_thresholds = precision_recall_curve(all_labels, all_scores)
    # The last precision/recall pair has no corresponding threshold
    f1s = 2 * precisions[:-1] * recalls[:-1] / (precisions[:-1] + recalls[:-1] + 1e-8)
    best_idx = int(np.argmax(f1s))
    best_threshold = float(pr_thresholds[best_idx])
    thresh_preds = (np.array(all_scores) >= best_threshold).astype(int)

    # === Per-model metrics for argmax vs. optimal threshold ===
    metrics_argmax = defaultdict(lambda: {"TP": 0, "TN": 0, "FP": 0, "FN": 0})
    metrics_best   = defaultdict(lambda: {"TP": 0, "TN": 0, "FP": 0, "FN": 0})

    def _confusion_cell(label, pred):
        """Name the confusion-matrix cell for one (label, prediction) pair."""
        if label == 1:
            return "TP" if pred == 1 else "FN"
        return "TN" if pred == 0 else "FP"

    for label, arg_pred, best_pred, model_name in zip(all_labels, all_preds, thresh_preds.tolist(), all_model_names):
        metrics_argmax[model_name][_confusion_cell(label, arg_pred)] += 1
        metrics_best[model_name][_confusion_cell(label, best_pred)] += 1

    # Print side-by-side comparison, one entry per generator name
    for model_name in sorted(set(all_model_names)):
        m_arg  = metrics_argmax[model_name]
        m_best = metrics_best[model_name]

        # Argmax stats
        tot_a = sum(m_arg.values())
        hit_a = m_arg["TP"] + m_arg["TN"]
        acc_a = hit_a / tot_a if tot_a else 0.0

        # Best-threshold stats
        tot_b = sum(m_best.values())
        hit_b = m_best["TP"] + m_best["TN"]
        acc_b = hit_b / tot_b if tot_b else 0.0

        print(f"Model: {model_name}")
        print(f"  Argmax        — Acc: {acc_a:.4f}, {hit_a}/{tot_a}")
        print(f"  Thr {best_threshold:.4f} — Acc: {acc_b:.4f}, {hit_b}/{tot_b}")

    # === Overall metrics ===
    overall_auc = roc_auc_score(all_labels, all_scores)
    overall_ap  = average_precision_score(all_labels, all_scores)

    # Metrics with default 0.5 threshold (argmax)
    argmax_f1 = f1_score(all_labels, all_preds)
    argmax_acc = accuracy_score(all_labels, all_preds)

    # Metrics with optimal threshold
    best_f1 = f1_score(all_labels, thresh_preds)
    best_acc = accuracy_score(all_labels, thresh_preds)

    print(f"Overall — AUC-ROC: {overall_auc:.4f}, AUC-PR: {overall_ap:.4f}")
    print(f"  Argmax (0.5)   — F1: {argmax_f1:.4f}, Acc: {argmax_acc:.4f}")
    print(f"  Best thr {best_threshold:.4f} — F1: {best_f1:.4f}, Acc: {best_acc:.4f}")
