import argparse
import warnings
import sys
import os
import numpy as np

def get_parser(argv=None):
    """Build the AIGC-detection command-line parser and parse the arguments.

    Args:
        argv: optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]`` exactly as before; passing a list makes
            the parser usable from tests or other scripts.

    Returns:
        argparse.Namespace holding every option defined below.
    """
    parser = argparse.ArgumentParser(description="AIGCDetection Training")
    # Loss weighting, dataset-mixing options and checkpoint cadence.
    parser.add_argument('--lamb', default=0.3, type=float,
                        help='Weight of the contrastive loss; the classification loss is weighted by 1 - lamb.')
    parser.add_argument('--mixratio', default=0.5, type=float,
                        help='Mix ratio forwarded to the dataset (mixed pipeline).')
    parser.add_argument('--mixprob', default=0.5, type=float,
                        help='Mix probability forwarded to the dataset (mixed pipeline).')
    parser.add_argument('--check', default=1000, type=int,
                        help='Evaluate and save a checkpoint every N training iterations.')
    parser.add_argument('--nomix', action='store_true', help='Whether to not using mixed pipeline?')
    # store_false: the noised pipeline is ON by default and passing this flag turns it off.
    parser.add_argument('--noised', action='store_false',
                        help='Pass to disable the noised pipeline (enabled by default).')
    parser.add_argument('--smashed', action='store_true', help='Whether to using smashed pipeline?')
    parser.add_argument('--mixsmashed', action='store_true', help='Whether to using mixsmashed pipeline?')
    parser.add_argument('--contrastive', action='store_true', help='Whether to using contrastive pipeline?')
    # type=int so a value given on the command line has the same type as the default.
    parser.add_argument('--end_step', default=9, type=int, help='When to end noise?')
    parser.add_argument("--model_name", default='resnet3d', help="Setting the model name", type=str)
    parser.add_argument("--embedding_size", default=None, help="Setting the embedding_size", type=int)
    parser.add_argument("--num_classes", default=2, help="Setting the num classes", type=int)
    parser.add_argument('--freeze_extractor', action='store_true', help='Whether to freeze extractor?')
    parser.add_argument("--model_path", default=None, help="Setting the model path", type=str)
    parser.add_argument('--no_strict', action='store_true', help='Whether to load model without strict?')
    parser.add_argument("--root_path", default='/path/to/GenImage',
                        help="Setting the root path for dataset loader", type=str)
    parser.add_argument("--fake_root_path", default='',
                        help="Setting the fake root path for dataset loader", type=str)
    parser.add_argument('--is_dire', action='store_true', help='Whether to using DIRE?')
    parser.add_argument("--regex", default='*.*', help="Setting the regex for dataset loader", type=str)
    parser.add_argument('--test_all', action='store_true', help='Whether to test_all?')
    parser.add_argument('--post_aug_mode', default=None, help='Stetting the post aug mode during test phase.')
    parser.add_argument('--save_txt', default=None, help='Stetting the save_txt path.')
    parser.add_argument("--fake_indexes", default='1',
                        help="Setting the fake indexes, multi class using '1,2,3,...' ", type=str)
    parser.add_argument("--train_fake_indexes", default='1',
                        help="Setting the fake indexes, multi class using '1,2,3,...' ", type=str)
    parser.add_argument("--eval_fake_indexes", default='3',
                        help="Setting the fake indexes, multi class using '1,2,3,...' ", type=str)
    parser.add_argument("--dataset_name", default='GenImage', help="Setting the dataset name", type=str)
    parser.add_argument("--device_id", default='0',
                        help="Setting the GPU id, multi gpu split by ',', such as '0,1,2,3'", type=str)
    parser.add_argument("--input_size", default=224, help="Image input size", type=int)
    parser.add_argument('--is_crop', action='store_true', help='Whether to crop image?')
    parser.add_argument("--batch_size", default=64, help="Setting the batch size", type=int)
    parser.add_argument("--epoch_start", default=0, help="Setting the epoch start", type=int)
    parser.add_argument("--num_epochs", default=1, help="Setting the num epochs", type=int)
    parser.add_argument("--num_workers", default=8, help="Setting the num workers", type=int)
    parser.add_argument('--is_warmup', action='store_true', help='Whether to using lr warmup')
    parser.add_argument("--lr", default=5e-4, help="Setting the learning rate", type=float)
    parser.add_argument("--save_flag", default='crop', help="Setting the save flag", type=str)
    parser.add_argument("--sampler_mode", default='', help="Setting the sampler mode", type=str)
    parser.add_argument('--is_test', action='store_true', help='Whether to predict the test set?')
    parser.add_argument('--is_amp', action='store_true', help='Whether to using amp autocast(使用混合精度加速)?')
    parser.add_argument("--inpainting_dir", default='full_inpainting', help="rec_image dir", type=str)
    # Options forwarded to the contrastive loss (CombinedLoss).
    parser.add_argument('--pos_margin', default=0.0, type=float)
    parser.add_argument('--neg_margin', default=1.0, type=float)
    parser.add_argument('--tau', default=0.5, type=float)
    parser.add_argument('--loss_name', default='ContrastiveLoss', type=str)
    parser.add_argument('--memory_size', default=None, type=int)
    parser.add_argument('--use_miner', action='store_true', help='Whether to using miner')
    parser.add_argument("--threshold", default=0.5, help="Setting the valid or testing threshold.", type=float)
    parser.add_argument("opts", help="Modify config options using the command-line", default=None,
                        nargs=argparse.REMAINDER)
    return parser.parse_args(argv)


# Silence all Python warnings for the rest of the run (this also hides library deprecation notices).
warnings.filterwarnings("ignore")
# Make modules in the parent directory importable (presumably for the `utils`, `network`
# and `data` packages imported below -- verify against the repo layout).
sys.path.append('..')
# Parse the command line once at import time; functions below read this module-level `args`.
args = get_parser()
# Must run before `torch` is imported below: CUDA_VISIBLE_DEVICES is only honoured if it is
# set before CUDA is initialised, so the GPU mask has to be in place first.
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.device_id)

import torch
import torch.nn as nn
import torch.optim as optim
from catalyst.data import BalanceClassSampler
import time
from torch.utils.data import DataLoader
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
import gc
from sklearn.metrics import roc_auc_score, recall_score, precision_score, accuracy_score, average_precision_score, f1_score
import pytorch_warmup as warmup

from utils.utils import Logger, AverageMeter, Test_time_agumentation, calculate_fnr
from network.models import get_models
from data.dataset_dino import AIGCDetectionDataset, CLASS2LABEL_MAPPING, GenImage_LIST
from data.transform import create_train_transforms, create_val_transforms
from utils.losses import LabelSmoothing, CombinedLoss, FocalLoss

def merge_tensor(img, label, is_train=True):
    """Join list-batched images and labels into single tensors.

    If both ``img`` and ``label`` are lists of tensors, each list is
    concatenated along dim 0. When ``is_train`` is true, the merged batch is
    then reordered by one shared random permutation, so image/label pairs
    stay aligned. Any other input is returned untouched.
    """
    if not (isinstance(img, list) and isinstance(label, list)):
        return img, label
    merged_img = torch.cat(img, dim=0)
    merged_label = torch.cat(label, dim=0)
    if not is_train:
        return merged_img, merged_label
    order = torch.randperm(merged_img.size(0))
    return merged_img[order], merged_label[order]

def merge_triple_tensor(img, label, gt, is_train=True):
    """Join list-batched (image, label, gt) triples into single tensors.

    Only when all three inputs are lists, each one is concatenated along
    dim 0. In training mode, the three merged tensors are then indexed by
    the same random permutation, so rows stay aligned. Otherwise the inputs
    are returned as given.
    """
    parts = (img, label, gt)
    if not all(isinstance(part, list) for part in parts):
        return img, label, gt
    merged = [torch.cat(part, dim=0) for part in parts]
    if is_train:
        perm = torch.randperm(merged[0].size(0))
        merged = [tensor[perm] for tensor in merged]
    return tuple(merged)

def TTA(model_, img, activation=nn.Softmax(dim=1), tta=None):
    """Average the activated model output over flipped/rotated views of ``img``.

    Views scored are:
    - the original image;
    - every flip returned by ``tta.tensor_flip(img)``;
    - every rotation returned by ``tta.tensor_rotation`` for both the
      original image and the first flip.

    The sum is divided by the number of views actually scored. The old code
    hard-coded 9, which is only right when the helper yields exactly
    2 flips and 3 rotations per image.

    Args:
        model_: callable returning logits, or a tuple/list whose first item is
            the logits (the models in this script return ``(logits, emb)``).
        img: input batch passed to ``model_``.
        activation: applied to each view's logits before averaging.
        tta: object providing ``tensor_flip`` and ``tensor_rotation``;
            defaults to a new ``Test_time_agumentation()``.

    Returns:
        Mean of ``activation(logits)`` over all views.
    """
    if tta is None:
        tta = Test_time_agumentation()

    def _score(view):
        out = model_(view)
        if isinstance(out, (tuple, list)):
            out = out[0]
        return activation(out)

    # Accumulate out of place: an in-place `+=` on the first result can alias
    # (and corrupt) the input when the activation returns its argument.
    total = _score(img)
    n_views = 1
    flip_imgs = tta.tensor_flip(img)
    for view in flip_imgs:
        total = total + _score(view)
        n_views += 1
    for base in (img, flip_imgs[0]):
        for view in tta.tensor_rotation(base):
            total = total + _score(view)
            n_views += 1
    return total / n_views


def eval_model(model, epoch, eval_loader, is_save=True, is_tta=False, threshold=0.5, save_txt=None):
    """Evaluate ``model`` on ``eval_loader`` and print (optionally log) detection metrics.

    Reads module-level globals set up in the ``__main__`` section:
    ``criterion`` and ``contrastive_loss``. When ``is_save`` is true it also
    reads ``train_logger`` and ``optimizer``.

    Args:
        model: network returning ``(logits, embedding)`` for an image batch.
        epoch: epoch index shown in progress/log output.
        eval_loader: yields ``(img, label, gt, _)``. ``img``/``label`` may be
            lists of tensors; they are concatenated without shuffling.
        is_save: log the averaged losses/accuracy through ``train_logger``.
        is_tta: score with test-time augmentation (``TTA``) instead of a
            single forward pass.
        threshold: decision threshold on the fake score ``1 - p(class 0)``.
        save_txt: when not None, return the full metric tuple.

    Returns:
        ``(binary_acc, auc, r_acc, f_acc, ap, recall, precision, f1, fnr)``
        if ``save_txt`` is not None. Otherwise, the sample-weighted mean of
        the per-batch argmax accuracy.
    """
    model.eval()
    losses = AverageMeter()
    c_losses = AverageMeter()
    r_losses = AverageMeter()
    f_losses = AverageMeter()
    accuracies = AverageMeter()
    eval_process = tqdm(eval_loader)
    labels = []
    outputs = []
    softmax = nn.Softmax(dim=1)
    with torch.no_grad():
        for i, (img, label, gt, _) in enumerate(eval_process):
            img, label = merge_tensor(img, label, is_train=False)
            if i > 0:
                eval_process.set_description(
                    "Epoch: %d, Loss: %.4f, C-Loss: %.4f, R-Loss:%.4f, F-Loss:%.4f, Acc: %.4f" %
                    (epoch, losses.avg, c_losses.avg, r_losses.avg, f_losses.avg, accuracies.avg))
            img, label = img.cuda().float(), label.cuda()

            # A plain forward pass always supplies the logits (for the
            # cross-entropy losses) and the embedding (for the contrastive
            # loss). Previously `emb` was undefined whenever is_tta was set.
            logits, emb = model(img)
            if is_tta:
                # TTA applies `activation` to what the callable returns, so hand
                # it a callable that returns the logits only.
                y_pred = TTA(lambda x: model(x)[0], img, activation=softmax)
            else:
                y_pred = softmax(logits)

            emb = emb.reshape(-1, 1024)
            gt = gt.reshape(-1)

            # Fake score = 1 - probability of class 0 (real).
            outputs.append(1 - y_pred[:, 0])
            labels.append(label)

            # `criterion` is nn.CrossEntropyLoss (see __main__). It expects raw
            # logits, as in training; feeding softmax output applied softmax twice.
            loss = criterion(logits, label)
            c_loss = contrastive_loss(emb, gt)

            fake_mask = label == 1
            real_mask = label == 0
            if fake_mask.any():
                f_loss = criterion(logits[fake_mask], label[fake_mask])
            else:
                f_loss = torch.tensor(0.0).cuda()
            if real_mask.any():
                r_loss = criterion(logits[real_mask], label[real_mask])
            else:
                r_loss = torch.tensor(0.0).cuda()

            n = img.size(0)
            acc = (torch.max(y_pred, 1)[1] == label).sum().item() / n
            losses.update(loss.item(), n)
            c_losses.update(c_loss.item(), n)
            r_losses.update(r_loss.item(), n)
            f_losses.update(f_loss.item(), n)
            accuracies.update(acc, n)

    outputs = torch.cat(outputs, dim=0).cpu().numpy()
    labels = torch.cat(labels, dim=0).cpu().numpy()
    # Collapse every non-zero (fake) class into a single positive label.
    labels[labels > 0] = 1
    binary_predictions = (outputs > threshold).astype(int)

    ap = average_precision_score(labels, outputs)
    auc = roc_auc_score(labels, outputs)
    recall = recall_score(labels, binary_predictions)
    precision = precision_score(labels, binary_predictions)
    binary_acc = accuracy_score(labels, binary_predictions)
    real_indices = np.where(labels == 0)[0]
    r_acc = accuracy_score(labels[real_indices], binary_predictions[real_indices])
    fake_indices = np.where(labels == 1)[0]
    f_acc = accuracy_score(labels[fake_indices], binary_predictions[fake_indices])
    f1 = f1_score(labels, binary_predictions)
    fnr = calculate_fnr(labels, outputs > threshold)
    print(f'AUC:{auc}-Recall:{recall}-Precision:{precision}-BinaryAccuracy:{binary_acc}, F-ACC:{f_acc}-R-ACC:{r_acc}, AP: {ap}, f1:{f1}, fnr:{fnr}')

    if is_save:
        train_logger.log(phase="val", values={
            'epoch': epoch,
            'loss': format(losses.avg, '.4f'),
            'c_loss': format(c_losses.avg, '.4f'),
            'r_loss': format(r_losses.avg, '.4f'),
            'f_loss': format(f_losses.avg, '.4f'),
            'acc': format(accuracies.avg, '.4f'),
            'lr': optimizer.param_groups[0]['lr']
        })
    # Each value gets its own positional index (previously every loss printed
    # {0} and "Acc" printed the contrastive loss).
    print("Val:\t Loss:{0:.4f} C-Loss:{1:.4f} R-Loss:{2:.4f} F-Loss:{3:.4f} \t Acc:{4:.4f}".format(
        losses.avg, c_losses.avg, r_losses.avg, f_losses.avg, accuracies.avg))
    acc_avg = accuracies.avg
    gc.collect()

    if save_txt is not None:
        return binary_acc, auc, r_acc, f_acc, ap, recall, precision, f1, fnr
    return acc_avg

def train_contrastive_model(model, criterion, optimizer, epoch, scaler=None, search_best_acc=True, lamb=0.3, check=200, loss_name='ContrastiveLoss'):
    """Run one training epoch with a weighted classification + contrastive loss.

    The total loss is ``(1 - lamb) * criterion(logits, label) +
    lamb * contrastive_loss(emb, gt)``.

    Reads module-level globals set up in ``__main__``: ``train_loader``,
    ``eval_loader``, ``store_name``, ``scheduler``, ``train_logger``,
    ``contrastive_loss`` and ``args``. With ``args.is_warmup`` it also reads
    ``warmup_scheduler``.

    Args:
        model: network returning ``(logits, embedding)``.
        criterion: classification loss applied to the logits.
        optimizer: optimizer stepped once per batch.
        epoch: current epoch index (for logging and checkpoint names).
        scaler: optional ``GradScaler``; when given, the forward pass runs
            under ``autocast``.
        search_best_acc: run ``eval_model`` before each intermediate checkpoint.
        lamb: contrastive-loss weight.
        check: evaluate/save every ``check`` iterations (skipping iteration 0).
        loss_name: contrastive loss name. ``'NTXentLoss'`` subsamples at most
            256 embeddings per batch.
    """
    model.train()
    losses = AverageMeter()
    r_losses = AverageMeter()
    f_losses = AverageMeter()
    accuracies = AverageMeter()
    training_process = tqdm(train_loader)
    weigh = 1.0 - lamb

    def _forward(x, label, gt):
        # Shared forward/loss computation for the AMP and full-precision paths.
        y_pred, emb = model(x)
        emb = emb.reshape(-1, 1024)
        gt = gt.reshape(-1)
        if loss_name == 'NTXentLoss':
            # Sampling without replacement raises if more rows are requested
            # than exist, so cap at the batch's embedding count.
            n_keep = min(256, len(emb))
            indices = np.random.choice(len(emb), size=n_keep, replace=False)
            emb = emb[indices]
            gt = gt[indices]
        loss = weigh * criterion(y_pred, label) + lamb * contrastive_loss(emb, gt)
        return y_pred, loss

    for i, (x, label, gt, _) in enumerate(training_process):
        x, label, gt = merge_triple_tensor(x, label, gt, is_train=True)
        optimizer.zero_grad()
        current_lr = optimizer.param_groups[0]['lr']
        if i > 0:
            training_process.set_description(
                "Epoch: %d, LR: %.8f, Loss: %.4f, R-Loss:%.4f F-Loss:%.4f, Acc: %.4f" % (
                    epoch, current_lr, losses.avg, r_losses.avg, f_losses.avg, accuracies.avg))
        x = x.cuda().float()
        label = label.cuda()

        if scaler is None:
            y_pred, loss = _forward(x, label, gt)
            loss.backward()
            optimizer.step()
        else:
            with autocast():
                y_pred, loss = _forward(x, label, gt)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        acc = (torch.max(y_pred.detach(), 1)[1] == label).sum().item() / x.size(0)
        losses.update(loss.item(), x.size(0))
        accuracies.update(acc, x.size(0))

        # Per-class losses are for monitoring only, so no autograd graph is needed.
        with torch.no_grad():
            fake_mask = label == 1
            real_mask = label == 0
            if fake_mask.any():
                f_loss = criterion(y_pred[fake_mask], label[fake_mask])
            else:
                f_loss = torch.tensor(0.0).cuda()
            if real_mask.any():
                r_loss = criterion(y_pred[real_mask], label[real_mask])
            else:
                r_loss = torch.tensor(0.0).cuda()
        r_losses.update(r_loss.item(), x.size(0))
        f_losses.update(f_loss.item(), x.size(0))

        # Skip iteration 0 (nothing trained yet), matching train_model.
        if i > 0 and i % check == 0:
            if search_best_acc:
                acc = eval_model(model, epoch, eval_loader, is_save=False)
                # eval_model() switched the model to eval mode; resume training mode.
                model.train()
            # NOTE: 18000 is a fixed per-epoch step offset used only in the file name.
            save_path = '{}/{}_loss{:.4f}_acc{:.4f}.pth'.format(store_name, i + epoch * 18000, loss.item(), acc)
            print(save_path)
            # Unwrap DataParallel so keys match the end-of-epoch checkpoints.
            net = model.module if hasattr(model, 'module') else model
            torch.save(net.state_dict(), save_path)

        if args.is_warmup:
            with warmup_scheduler.dampening():
                scheduler.step()
    if not args.is_warmup:
        scheduler.step()
    train_logger.log(phase="train", values={
        'epoch': epoch,
        'loss': format(losses.avg, '.4f'),
        'r_loss': format(r_losses.avg, '.4f'),
        'f_loss': format(f_losses.avg, '.4f'),
        'acc': format(accuracies.avg, '.4f'),
        'lr': optimizer.param_groups[0]['lr']
    })
    print("Train:\t Loss:{0:.4f} R-Loss:{1:.4f} F-Loss:{2:.4f} \t Acc:{3:.4f}".format(
        losses.avg, r_losses.avg, f_losses.avg, accuracies.avg))

    del losses, r_losses, f_losses, accuracies
    gc.collect()

def train_model(model, criterion, optimizer, epoch, scaler=None, search_best_acc=True, check=1000):
    """Run one training epoch with the plain classification loss.

    Reads module-level globals set up in ``__main__``: ``train_loader``,
    ``eval_loader``, ``store_name``, ``scheduler``, ``train_logger`` and
    ``args``. With ``args.is_warmup`` it also reads ``warmup_scheduler``.

    Args:
        model: network returning ``(logits, embedding)``; the embedding is ignored.
        criterion: classification loss applied to the logits.
        optimizer: optimizer stepped once per batch.
        epoch: current epoch index (for logging and checkpoint names).
        scaler: optional ``GradScaler``; when given, the forward pass runs
            under ``autocast``.
        search_best_acc: run ``eval_model`` before each intermediate checkpoint.
        check: evaluate/save every ``check`` iterations (skipping iteration 0).
    """
    model.train()
    losses = AverageMeter()
    accuracies = AverageMeter()
    training_process = tqdm(train_loader)
    for i, (x, label, _, _) in enumerate(training_process):
        x, label = merge_tensor(x, label, is_train=True)
        optimizer.zero_grad()
        current_lr = optimizer.param_groups[0]['lr']
        if i > 0:
            training_process.set_description(
                "Epoch: %d, LR: %.8f, Loss: %.4f, Acc: %.4f" % (
                    epoch, current_lr, losses.avg, accuracies.avg))
        x = x.cuda().float()
        label = label.cuda()

        if scaler is None:
            y_pred, _ = model(x)
            loss = criterion(y_pred, label)
            loss.backward()
            optimizer.step()
        else:
            with autocast():
                y_pred, _ = model(x)
                loss = criterion(y_pred, label)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        acc = (torch.max(y_pred.detach(), 1)[1] == label).sum().item() / x.size(0)
        losses.update(loss.item(), x.size(0))
        accuracies.update(acc, x.size(0))

        if i > 0 and i % check == 0:
            if search_best_acc:
                acc = eval_model(model, epoch, eval_loader, is_save=False)
                # eval_model() switched the model to eval mode; resume training mode.
                model.train()
            # NOTE: 18000 is a fixed per-epoch step offset used only in the file name.
            save_path = '{}/{}_loss{:.4f}_acc{:.4f}.pth'.format(store_name, i + epoch * 18000, loss.item(), acc)
            print(save_path)
            # Unwrap DataParallel so keys match the end-of-epoch checkpoints.
            net = model.module if hasattr(model, 'module') else model
            torch.save(net.state_dict(), save_path)

        if args.is_warmup:
            with warmup_scheduler.dampening():
                scheduler.step()
    if not args.is_warmup:
        scheduler.step()
    train_logger.log(phase="train", values={
        'epoch': epoch,
        'loss': format(losses.avg, '.4f'),
        'acc': format(accuracies.avg, '.4f'),
        'lr': optimizer.param_groups[0]['lr']
    })
    print("Train:\t Loss:{0:.4f} \t Acc:{1:.4f}".format(losses.avg, accuracies.avg))

    del losses, accuracies
    gc.collect()


# python train.py --device_id=0 --model_name=efficientnet-b0 --input_size=224 --batch_size=48 --fake_indexes=1 --is_amp --save_flag=
if __name__ == '__main__':
    # Note: everything here runs at module level on purpose. train_model,
    # train_contrastive_model and eval_model read these names as globals.
    batch_size = args.batch_size * torch.cuda.device_count()
    writeFile = f"../output/{args.dataset_name}/{args.fake_indexes.replace(',', '_')}/" \
                f"{args.model_name.split('/')[-1]}_{args.input_size}_{args.lr}_{args.save_flag}/logs"
    store_name = writeFile.replace('/logs', '/weights')
    print(f'Using gpus:{args.device_id},batch size:{batch_size},gpu_count:{torch.cuda.device_count()},num_classes:{args.num_classes}')

    is_train = not args.is_test
    # `--no_strict` relaxes state-dict key matching, so loading is strict unless the
    # flag is passed. The flag was previously forwarded as `strict=`, inverting it.
    strict_load = not args.no_strict

    # Load model
    model = get_models(model_name=args.model_name, num_classes=args.num_classes,
                       freeze_extractor=args.freeze_extractor, embedding_size=args.embedding_size)
    if args.model_path is not None:
        if is_train:
            model.load_state_dict(torch.load(args.model_path, map_location='cpu')['state_dict'], strict=strict_load)
            # Replace the classification head with a freshly initialised one.
            in_features = model.fc.in_features
            model.fc = nn.Linear(in_features, args.num_classes)
        else:
            # NOTE(review): the original test `not 'dino' or 'clip' in name` reduces
            # to `'clip' in name`, because `not 'dino'` is always False. That runtime
            # behaviour is kept here. If the intent was "neither dino nor clip", use
            # `not ('dino' in args.model_name or 'clip' in args.model_name)`.
            if 'clip' in args.model_name:
                in_features = model.fc.in_features
                model.fc = nn.Linear(in_features, args.num_classes)
            model.load_state_dict(torch.load(args.model_path, map_location='cpu'), strict=strict_load)
        print('Model found in {}'.format(args.model_path))
    else:
        print('No model found, initializing random model.')
    if torch.cuda.device_count() > 1:
        model = torch.nn.DataParallel(model).cuda()
    else:
        model = model.cuda()
    criterion = nn.CrossEntropyLoss()
    contrastive_loss = CombinedLoss(loss_name=args.loss_name, embedding_size=args.embedding_size,
                                    pos_margin=args.pos_margin, neg_margin=args.neg_margin, tau=args.tau,
                                    memory_size=args.memory_size, use_miner=args.use_miner, num_classes=args.num_classes)

    if is_train:
        if store_name:
            os.makedirs(store_name, exist_ok=True)
        train_logger = Logger(model_name=writeFile, header=['epoch', 'loss', 'acc', 'lr'])
        # setting data loader
        xdl = AIGCDetectionDataset(nomix=args.nomix, noised=args.noised, mixratio=args.mixratio, mixprob=args.mixprob, smashed=args.smashed, mixsmashed=args.mixsmashed, end_step=args.end_step, root_path=args.root_path, fake_root_path=args.fake_root_path, fake_indexes=args.train_fake_indexes, phase='train',
                                   num_classes=args.num_classes, inpainting_dir=args.inpainting_dir, is_dire=args.is_dire,
                                   transform=create_train_transforms(size=args.input_size, is_crop=args.is_crop)
                                   )
        sampler = BalanceClassSampler(labels=xdl.get_labels(), mode=args.sampler_mode) if args.sampler_mode != '' else None  # "upsampling"
        train_loader = DataLoader(xdl, batch_size=batch_size, shuffle=sampler is None, num_workers=args.num_workers, sampler=sampler)
        train_dataset_len = len(xdl)

        xdl_eval = AIGCDetectionDataset(nomix=args.nomix, noised=args.noised, mixratio=args.mixratio, mixprob=args.mixprob, smashed=args.smashed, end_step=args.end_step, root_path=args.root_path, fake_root_path=args.fake_root_path, fake_indexes=args.eval_fake_indexes, phase='val',
                                        num_classes=args.num_classes, inpainting_dir=args.inpainting_dir, is_dire=args.is_dire,
                                        transform=create_val_transforms(size=args.input_size, is_crop=args.is_crop)
                                        )
        eval_loader = DataLoader(xdl_eval, batch_size=batch_size, shuffle=False, num_workers=args.num_workers)
        eval_dataset_len = len(xdl_eval)
        print('train_dataset_len:', train_dataset_len, 'eval_dataset_len:', eval_dataset_len)

        optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=4e-5)
        if not args.is_warmup:
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 5)
        else:
            # With warmup, the scheduler is stepped once per batch, so T_max must
            # count batches (len(train_loader)), not samples (len(dataset)).
            num_steps = len(train_loader) * args.num_epochs
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_steps)
            warmup_scheduler = warmup.UntunedLinearWarmup(optimizer)

        # One GradScaler for the whole run, so its loss-scale state carries across epochs.
        scaler = GradScaler() if args.is_amp else None

        best_acc = 0.5 if args.epoch_start == 0 else eval_model(model, args.epoch_start - 1, eval_loader, is_save=False)
        # Defined up front so the final save works even if the loop runs zero times.
        acc = best_acc
        for epoch in range(args.epoch_start, args.num_epochs):
            if args.contrastive:
                train_contrastive_model(model, criterion, optimizer, epoch, lamb=args.lamb, check=args.check,
                                        loss_name=args.loss_name, scaler=scaler)
            else:
                train_model(model, criterion, optimizer, epoch, scaler=scaler, check=args.check)
            acc = eval_model(model, epoch, eval_loader)
            if best_acc < acc:
                best_acc = acc
                save_path = '{}/{}_acc{:.4f}.pth'.format(store_name, epoch, acc)
                if torch.cuda.device_count() > 1:
                    torch.save(model.module.state_dict(), save_path)
                else:
                    torch.save(model.state_dict(), save_path)
            print(f'Current best acc:{best_acc}')
        last_save_path = '{}/last_acc{:.4f}.pth'.format(store_name, acc)
        if torch.cuda.device_count() > 1:
            torch.save(model.module.state_dict(), last_save_path)
        else:
            torch.save(model.state_dict(), last_save_path)
    else:
        print("Wrong place!")
        start = time.time()
        epoch_start = 1
        xdl_test = AIGCDetectionDataset(noised=args.noised, smashed=args.smashed, end_step=args.end_step, root_path=args.root_path, fake_root_path=args.fake_root_path, fake_indexes=args.fake_indexes,
                                        phase='test', num_classes=args.num_classes, is_dire=args.is_dire,
                                        post_aug_mode=args.post_aug_mode, regex=args.regex, inpainting_dir=args.inpainting_dir,
                                        transform=create_val_transforms(size=args.input_size, is_crop=args.is_crop)
                                        )
        test_loader = DataLoader(xdl_test, batch_size=batch_size, shuffle=False, num_workers=4)
        test_dataset_len = len(xdl_test)
        print('test_dataset_len:', test_dataset_len)
        out_metrics = eval_model(model, epoch_start, test_loader, is_save=False, is_tta=False,
                                 threshold=args.threshold, save_txt=args.save_txt)
        print('Total time:', time.time() - start)
        if args.save_txt is not None:
            save_dir = os.path.dirname(args.save_txt)
            # os.makedirs('') raises, so only create a directory when the path has one.
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            acc, auc, r_acc, f_acc, ap, recall, precision, f1, fnr = out_metrics
            fake_index = int(args.fake_indexes)
            # One exclusive branch per dataset. Previously a trailing if/else
            # overwrote the GenImage class name with a CLASS2LABEL_MAPPING key.
            if args.dataset_name == 'GenImage':
                class_name = GenImage_LIST[fake_index - 1]
            elif args.dataset_name == 'ForenSynths':
                # NOTE(review): these lists were never imported; assumed to live
                # next to GenImage_LIST in data.dataset_dino -- confirm.
                from data.dataset_dino import ForenSynths_LIST
                class_name = ForenSynths_LIST[fake_index - 1]
            elif args.dataset_name == 'UnivFD':
                from data.dataset_dino import UnivFD_LIST
                class_name = UnivFD_LIST[fake_index - 1]
            else:
                class_name = list(CLASS2LABEL_MAPPING.keys())[fake_index]
            result_str = f'model_path:{args.model_path}, post_aug_mode:{args.post_aug_mode}, class_name:{class_name}\n' \
                         f'acc:{acc:.4f}, r_acc:{r_acc:.4f}, f_acc:{f_acc:.4f}, ap:{ap:.4f}, auc:{auc:.4f}, recall:{recall:.4f}, precision:{precision:.4f}, ' \
                         f'f1:{f1:.4f}, fnr: {fnr}\n'
            with open(args.save_txt, 'a') as file:
                file.write(result_str)
            print(f'The result was saved in {args.save_txt}')
