import argparse
import warnings
import sys
import os
import io
import albumentations as A
from torch.utils.data import Dataset
from PIL import Image
import cv2
import numpy as np
from albumentations.pytorch import ToTensorV2
from datasets import load_dataset, DownloadConfig


def get_parser():
    """Declare the script's command-line options and parse ``sys.argv``.

    Options are registered in a fixed order (which is also the order shown by
    ``--help``). Any command-line tokens left after the declared options are
    gathered into the positional ``opts`` list.

    Returns:
        argparse.Namespace: the parsed command-line options.
    """
    cli = argparse.ArgumentParser(description="AIGCDetection @cby Training")

    def option(flag, default, help_text, kind=None):
        # Valued option; ``kind`` is the argparse ``type`` converter
        # (None leaves the raw string untouched).
        cli.add_argument(flag, default=default, help=help_text, type=kind)

    def switch(flag, help_text):
        # Boolean switch: False unless present on the command line.
        cli.add_argument(flag, action='store_true', help=help_text)

    # --- model ---------------------------------------------------------
    option("--model_name", 'convnext_base_in22k', "Setting the model name", str)
    option("--embedding_size", 1024, "Setting the embedding_size", int)
    option("--num_classes", 2, "Setting the num classes", int)
    switch('--freeze_extractor', 'Whether to freeze extractor?')
    option("--model_path", None, "Setting the model path", str)
    switch('--no_strict', 'Whether to load model without strict?')

    # --- data ----------------------------------------------------------
    option("--root_path", '/disk4/chenby/dataset/MSCOCO',
           "Setting the root path for dataset loader", str)
    option("--fake_root_path", '/disk4/chenby/dataset/DRCT-2M',
           "Setting the fake root path for dataset loader", str)
    switch('--is_dire', 'Whether to using DIRE?')
    option("--regex", '*.*', "Setting the regex for dataset loader", str)
    switch('--test_all', 'Whether to test_all?')
    option('--post_aug_mode', None, 'Stetting the post aug mode during test phase.')
    option('--save_txt', None, 'Stetting the save_txt path.')
    option("--fake_indexes", '1',
           "Setting the fake indexes, multi class using '1,2,3,...' ", str)
    option("--dataset_name", 'MSCOCO', "Setting the dataset name", str)

    # --- runtime / optimisation ---------------------------------------
    option("--device_id", '0',
           "Setting the GPU id, multi gpu split by ',', such as '0,1,2,3'", str)
    option("--input_size", 224, "Image input size", int)
    switch('--is_crop', 'Whether to crop image?')
    option("--batch_size", 64, "Setting the batch size", int)
    option("--epoch_start", 0, "Setting the epoch start", int)
    option("--num_epochs", 50, "Setting the num epochs", int)
    option("--num_workers", 4, "Setting the num workers", int)
    switch('--is_warmup', 'Whether to using lr warmup')
    option("--lr", 1e-3, "Setting the learning rate", float)
    option("--save_flag", '', "Setting the save flag", str)
    option("--sampler_mode", '', "Setting the sampler mode", str)
    switch('--is_test', 'Whether to predict the test set?')
    switch('--is_amp', 'Whether to using amp autocast(使用混合精度加速)?')
    option("--inpainting_dir", 'full_inpainting', "rec_image dir", str)
    option("--threshold", 0.5, "Setting the valid or testing threshold.", float)

    # Everything that follows the declared options ends up in ``opts``.
    cli.add_argument("opts", help="Modify config options using the command-line", default=None,
                     nargs=argparse.REMAINDER)

    return cli.parse_args()


# Silence every Python warning for the whole process (this also hides
# deprecation and undefined-metric notices from the libraries used below).
warnings.filterwarnings("ignore")
# Add the parent directory (resolved against the current working directory) to
# the import path; presumably this is where the `utils` and `network` packages
# imported below live — TODO confirm.
sys.path.append('..')
# Parsed at import time: merely importing this module consumes sys.argv.
args = get_parser()
# Restrict the GPUs visible to this process.
# NOTE(review): this is placed ahead of the `import torch` below, presumably so
# that it is in effect before CUDA gets initialised — confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.device_id)

import torch
import torch.nn as nn
import time
from torch.utils.data import DataLoader
from tqdm import tqdm
import gc
from sklearn.metrics import roc_auc_score, recall_score, precision_score, accuracy_score, f1_score

from utils.utils import AverageMeter, Test_time_agumentation, calculate_fnr
from network.models import get_models


def merge_tensor(img, label, is_train=True):
    """Normalise a batch to tensors and optionally shuffle it.

    Args:
        img: a tensor, or a list of tensors that is concatenated along dim 0.
        label: a tensor, or a list that is converted to a ``torch.long`` tensor.
        is_train: when True, images and labels are permuted with one shared
            random order so that each image keeps its label.

    Returns:
        tuple: ``(img, label)`` as tensors.
    """
    images = torch.cat(img, dim=0) if isinstance(img, list) else img
    targets = torch.tensor(label, dtype=torch.long) if isinstance(label, list) else label

    if not is_train:
        return images, targets

    # A single permutation indexes both tensors, keeping pairs aligned.
    order = torch.randperm(images.size(0))
    return images[order], targets[order]

# Test-time augmentation: 9 forward passes with the helper's current output
# (1 original + 2 flips + 2 * 3 rotations, per the counts noted below).
def TTA(model_, img, activation=nn.Softmax(dim=1)):
    """Average ``activation(model_(view))`` over ``img`` and its augmented views.

    Views evaluated, in order:
      1. the original batch;
      2. every tensor returned by ``Test_time_agumentation.tensor_flip(img)``;
      3. every tensor returned by ``tensor_rotation`` for the original batch
         and for the first flipped batch.

    The accumulated sum is divided by the number of views that were actually
    evaluated instead of a hard-coded 9, so the result remains a true mean if
    the helper ever returns a different number of views.

    Args:
        model_: callable applied to each view.
        img: input batch tensor.
        activation: applied to every model output before accumulation.

    Returns:
        The mean activated output over all evaluated views.
    """
    # Original view: 1
    outputs = activation(model_(img))
    n_views = 1
    tta = Test_time_agumentation()
    # Horizontal flip + vertical flip: 2 (count taken from the original note)
    flip_imgs = tta.tensor_flip(img)
    for flip_img in flip_imgs:
        outputs += activation(model_(flip_img))
        n_views += 1
    # Rotations of the original and of the first flip: 2 * 3 = 6
    for base_img in (img, flip_imgs[0]):
        for rot_img in tta.tensor_rotation(base_img):
            outputs += activation(model_(rot_img))
            n_views += 1

    outputs /= n_views

    return outputs


def _safe_metric(metric_fn, y_true, y_pred):
    """Return ``metric_fn(y_true, y_pred)``, or NaN when the metric is undefined."""
    try:
        return metric_fn(y_true, y_pred)
    except (ValueError, ZeroDivisionError):
        # e.g. ROC-AUC on a group that contains a single class.
        return float('nan')


def eval_model(model, epoch, eval_loader, is_save=True, is_tta=False, threshold=0.5, save_txt=None):
    """Evaluate ``model`` on ``eval_loader`` and print global and per-generator metrics.

    Every batch of the loader is unpacked as ``(img, label, model_name)`` where
    ``model_name`` holds one entry per sample. A sample's score is
    ``1 - P(class 0)``; it is compared with ``threshold`` for the binary
    metrics, and labels greater than 0 are collapsed to 1.

    The function reads the module-level ``criterion`` (created in the
    ``__main__`` section) and moves each batch to the GPU with ``.cuda()``.

    Args:
        model: network to evaluate; switched to eval mode.
        epoch: only shown in the progress-bar text.
        eval_loader: iterable of batches as described above.
        is_save: unused; kept so existing call sites keep working.
        is_tta: when True, predictions are averaged through ``TTA``.
        threshold: decision threshold applied to the score.
        save_txt: only selects the return shape; nothing is written to this
            path by this function.

    Returns:
        ``(binary_acc, auc, recall, precision, f1, fnr)`` when ``save_txt`` is
        not None, otherwise ``accuracies.avg`` (the meter's average of the
        per-batch argmax accuracy).
    """
    model.eval()
    losses = AverageMeter()
    accuracies = AverageMeter()
    eval_process = tqdm(eval_loader)
    softmax = nn.Softmax(dim=1)
    all_labels = []
    all_outputs = []
    per_model_indices = {}  # model_name -> positions of its samples in the global arrays
    seen = 0  # number of samples accumulated so far

    with torch.no_grad():
        for img, label, model_name in eval_process:
            img, label = merge_tensor(img, label, is_train=False)
            eval_process.set_description("Epoch: %d, Loss: %.4f, Acc: %.4f" %
                                         (epoch, losses.avg, accuracies.avg))
            batch_len = img.size(0)
            if batch_len == 0:
                # Every sample of this batch failed to load; nothing to score.
                continue
            img, label = img.cuda(), label.cuda()
            if is_tta:
                y_pred = TTA(model, img, activation=softmax)
                # Only averaged probabilities exist here; their log is an
                # equivalent logit vector (softmax(log p) == p).
                logits = torch.log(y_pred.clamp_min(1e-12))
            else:
                logits = model(img)
                y_pred = softmax(logits)
            batch_outputs = 1 - y_pred[:, 0]
            # Keep results on the CPU so GPU memory does not grow with the dataset.
            all_outputs.append(batch_outputs.cpu())
            all_labels.append(label.cpu())
            # The loss gets logits, not probabilities: CrossEntropyLoss applies
            # log-softmax itself, so feeding softmax output would normalise twice.
            loss = criterion(logits, label)
            acc = (torch.max(y_pred, 1)[1] == label).sum().item() / batch_len
            losses.update(loss.item(), batch_len)
            accuracies.update(acc, batch_len)

            # Remember which global positions belong to which generator.
            for offset, mname in enumerate(model_name):
                per_model_indices.setdefault(mname, []).append(seen + offset)
            seen += batch_len

    # Global metrics
    outputs_np = torch.cat(all_outputs, dim=0).numpy()
    labels_np = torch.cat(all_labels, dim=0).numpy()
    # Ensure binary labels
    labels_np[labels_np > 0] = 1
    preds_np = outputs_np > threshold
    auc = roc_auc_score(labels_np, outputs_np)
    recall = recall_score(labels_np, preds_np)
    precision = precision_score(labels_np, preds_np)
    binary_acc = accuracy_score(labels_np, preds_np)
    f1 = f1_score(labels_np, preds_np)
    fnr = calculate_fnr(labels_np, preds_np)
    print(f'Global -> AUC:{auc:.4f}-Recall:{recall:.4f}-Precision:{precision:.4f}-BinaryAccuracy:{binary_acc:.4f}, f1: {f1:.4f}, FNR: {fnr:.4f}')
    print("Val:\t Loss:{0:.4f} \t Acc:{1:.4f}".format(losses.avg, accuracies.avg))
    acc_avg = accuracies.avg

    # Metrics per model_name. A group may contain a single class, so every
    # metric is computed on its own and shown as NaN when it is undefined,
    # instead of one failure hiding all the others.
    for mname, indices in per_model_indices.items():
        m_labels_np = labels_np[indices]
        m_outputs_np = outputs_np[indices]
        m_preds_np = m_outputs_np > threshold
        m_auc = _safe_metric(roc_auc_score, m_labels_np, m_outputs_np)
        m_recall = _safe_metric(recall_score, m_labels_np, m_preds_np)
        m_precision = _safe_metric(precision_score, m_labels_np, m_preds_np)
        m_binary_acc = _safe_metric(accuracy_score, m_labels_np, m_preds_np)
        m_f1 = _safe_metric(f1_score, m_labels_np, m_preds_np)
        m_fnr = _safe_metric(calculate_fnr, m_labels_np, m_preds_np)
        print(f"Model: {mname} -> AUC:{m_auc:.4f} - Recall:{m_recall:.4f} - Precision:{m_precision:.4f} - Accuracy:{m_binary_acc:.4f} - f1:{m_f1:.4f} - FNR:{m_fnr:.4f}")

    del all_outputs, all_labels, per_model_indices, losses, accuracies
    gc.collect()

    if save_txt is not None:
        return binary_acc, auc, recall, precision, f1, fnr
    return acc_avg


def create_val_transforms(size=300, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), is_crop=False):
    """Build the deterministic preprocessing pipeline used for evaluation.

    Args:
        size: target height and width.
        mean: per-channel mean passed to ``A.Normalize``.
        std: per-channel std passed to ``A.Normalize``.
        is_crop: use a centre crop of ``size`` instead of resizing the longest
            side to ``size``.

    Returns:
        An ``A.Compose`` pipeline: resize-or-crop, zero padding up to
        ``size`` x ``size``, normalisation, conversion to a tensor.
    """
    resize_fuc = A.CenterCrop(height=size, width=size) if is_crop else A.LongestMaxSize(max_size=size)
    return A.Compose([
        resize_fuc,
        A.PadIfNeeded(min_height=size, min_width=size, border_mode=cv2.BORDER_CONSTANT, value=0),
        A.Normalize(mean=mean, std=std),
        ToTensorV2()
    ], additional_targets={'rec_image': 'image'})


# Defined at module level (not inside the ``__main__`` guard) so that
# DataLoader worker processes can resolve it on every start method.
class HuggingFaceOpenFakeDataset(Dataset):
    """Dataset backed by Hugging Face records.

    Each item is ``(pixel_values, label, model_name)``, or ``None`` when the
    image of the record cannot be decoded (``collate_fn`` drops those).
    """

    def __init__(self, dataset, transforms):
        self.dataset = dataset
        self.transforms = transforms

    def __len__(self):
        return len(self.dataset)

    def _to_pil(self, image_entry):
        """Convert a raw image entry to an RGB ``PIL.Image``.

        Accepts a PIL image, a numpy array, raw bytes, or a dict carrying
        ``"bytes"`` or ``"path"``.

        Raises:
            ValueError: for any other entry (including ``None``).
        """
        if isinstance(image_entry, Image.Image):
            return image_entry.convert("RGB")
        if isinstance(image_entry, np.ndarray):
            return Image.fromarray(image_entry).convert("RGB")
        # ``convert`` returns an independent, fully loaded image, so the
        # handle opened by ``Image.open`` can be closed right away.
        if isinstance(image_entry, (bytes, bytearray)):
            with Image.open(io.BytesIO(image_entry)) as opened:
                return opened.convert("RGB")
        if isinstance(image_entry, dict):
            if image_entry.get("bytes") is not None:
                with Image.open(io.BytesIO(image_entry["bytes"])) as opened:
                    return opened.convert("RGB")
            if image_entry.get("path"):
                with Image.open(image_entry["path"]) as opened:
                    return opened.convert("RGB")
        raise ValueError(f"Unsupported image entry type: {type(image_entry)}")

    def __getitem__(self, idx):
        example = self.dataset[idx]
        try:
            image = self._to_pil(example.get("image"))
        except (OSError, ValueError) as exc:
            print(f"Warning: skipping corrupted image entry at index {idx}: {exc}")
            return None
        image_np = np.array(image)
        transformed = self.transforms(image=image_np)
        pixel_values = transformed['image']

        # String labels: "real" -> 0, anything else -> 1.
        # NOTE(review): integer labels are used as-is, which assumes 0 means
        # "real" in the dataset's own encoding — confirm against the dataset.
        raw_label = example.get("label", 0)
        if isinstance(raw_label, str):
            label = 0 if raw_label.lower() == "real" else 1
        else:
            label = int(raw_label)
        model_name = example.get("model")
        if label == 0:
            model_name = "real"
        elif not model_name:
            model_name = "unknown"

        return pixel_values, label, model_name


def collate_fn(batch):
    """Collate dataset items into ``(pixel tensor, labels, model names)``.

    Items that are ``None`` (failed image loads) are dropped. When nothing is
    left, an empty tensor and two empty lists are returned.
    """
    batch = [item for item in batch if item is not None]
    if not batch:
        return torch.empty((0,)), [], []
    pixel_vals = torch.stack([item[0] for item in batch], dim=0)
    labels = [item[1] for item in batch]
    model_names = [item[2] for item in batch]
    return pixel_vals, labels, model_names


# python train.py --device_id=0 --model_name=efficientnet-b0 --input_size=224 --batch_size=48 --fake_indexes=1 --is_amp --save_flag=
if __name__ == '__main__':
    # ``--batch_size`` is per GPU; scale it by the number of visible devices.
    batch_size = args.batch_size * max(torch.cuda.device_count(), 1)
    print(f'Using gpus:{args.device_id},batch size:{batch_size},gpu_count:{torch.cuda.device_count()},num_classes:{args.num_classes}')
    # Load model
    model = get_models(model_name=args.model_name, num_classes=args.num_classes,
                       freeze_extractor=args.freeze_extractor, embedding_size=args.embedding_size)
    if args.model_path is not None:
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        model.load_state_dict(torch.load(args.model_path, map_location='cpu'), strict=not args.no_strict)
        print('Model found in {}'.format(args.model_path))
    else:
        print('No model found, initializing random model.')
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model).cuda()
    else:
        model = model.cuda()
    # Read as a module-level global by ``eval_model``.
    criterion = nn.CrossEntropyLoss()

    start = time.time()
    epoch_start = 1

    # Use $SCRATCH/.cache for downloads when the SCRATCH variable is set.
    scratch_dir = os.environ.get("SCRATCH")
    cache_dir = os.path.join(scratch_dir, ".cache") if scratch_dir else None
    download_config = DownloadConfig(cache_dir=cache_dir) if cache_dir else None
    load_kwargs = {"download_config": download_config} if download_config else {}
    print("Loading Anonymous460/OpenFake test split from Hugging Face...")
    hf_dataset = load_dataset("Anonymous460/OpenFake", split="test", **load_kwargs)

    test_transforms = create_val_transforms(size=args.input_size, is_crop=args.is_crop)

    # Prepare DataLoader. The batch size printed above is the one actually
    # used, so ``--batch_size`` is honoured instead of a fixed 32.
    dataset = HuggingFaceOpenFakeDataset(hf_dataset, test_transforms)
    test_loader = DataLoader(dataset, batch_size=batch_size, num_workers=args.num_workers, collate_fn=collate_fn)

    test_dataset_len = len(dataset)
    print('test_dataset_len:', test_dataset_len)
    out_metrics = eval_model(model, epoch_start, test_loader, is_save=False, is_tta=False,
                             threshold=args.threshold, save_txt=args.save_txt)
    print('Total time:', time.time() - start)
