import argparse
import os

import torch
import sys
sys.path.append(os.getcwd())
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
from baselines.ViT_LRP import vit_base_patch16_224
from baselines.ViT_LRP import deit_base_patch16_224
from torchvision.models import swin_b, Swin_B_Weights


from src.algorithms import dds, classification_with_dds, batch_pgd_attack
from src.diffusion import create_diffusion_model

def get_transforms():
    """Build the evaluation preprocessing pipeline.

    The image's shorter side is resized to 256, then the image is
    center-cropped to 224x224 and converted to a tensor. Finally each
    channel is normalized with mean 0.5 and std 0.5.

    Returns:
        A ``transforms.Compose`` chaining the steps above.
    """
    # Mean/std of 0.5 per channel follows the authors' notebook. The usual
    # ImageNet statistics would instead be mean=[0.485, 0.456, 0.406],
    # std=[0.229, 0.224, 0.225].
    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
    return transforms.Compose(steps)


def compute_accuracy(model, diffusion_model, dataloader, device, attack=True, dds=True):
    """Evaluate top-1 / top-5 accuracy of ``model`` over ``dataloader``.

    For every batch, the inputs are optionally perturbed with
    ``batch_pgd_attack`` (noise level 2/255). They are then classified either
    through ``classification_with_dds`` (when ``dds`` is true) or by a plain
    forward pass of ``model``.

    Args:
        model: Classifier. Its outputs are assumed to be per-class scores of
            shape (batch, num_classes); this is inferred from the
            ``topk(dim=1)`` usage, so confirm it for custom models.
        diffusion_model: Passed through to ``classification_with_dds``.
        dataloader: Iterable of ``(inputs, labels)`` batches.
        device: Device each batch is moved to.
        attack: If true, run the PGD attack on each batch before classifying.
        dds: If true, classify via DDS. Inside this function the parameter
            shadows the module-level ``dds`` import.

    Returns:
        ``(top1_acc, top5_acc)`` as percentages. Both are also printed.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    print("Evaluating the model...")
    model.eval()
    total = 0
    top1_correct = 0
    top5_correct = 0

    for inputs, labels in tqdm(dataloader, desc=f"Evaluating {model.__class__.__name__}"):
        inputs, labels = inputs.to(device), labels.to(device)

        if attack:
            # The attack needs gradients, so it must stay outside no_grad.
            inputs = batch_pgd_attack(inputs, model, noise_level=2/255)

        if dds:
            outputs = classification_with_dds(inputs, model, diffusion_model)
        else:
            # A plain evaluation forward pass needs no autograd graph.
            with torch.no_grad():
                outputs = model(inputs)

        # topk returns indices sorted by score, so one call serves both
        # metrics: column 0 is the top-1 prediction. Clamp k so models with
        # fewer than 5 classes do not make topk fail.
        k = min(5, outputs.size(1))
        topk_pred = outputs.topk(k, dim=1).indices
        # (batch, k) boolean matrix; indices are distinct, so at most one
        # True per row.
        hits = topk_pred == labels.view(-1, 1)

        top1_correct += hits[:, 0].sum().item()
        top5_correct += hits.any(dim=1).sum().item()
        total += labels.size(0)

    if total == 0:
        raise ValueError("compute_accuracy: dataloader yielded no samples")

    top1_acc = 100 * top1_correct / total
    top5_acc = 100 * top5_correct / total

    print(f"Top-1 accuracy: {top1_acc}%.")
    print(f"Top-5 accuracy: {top5_acc}%.")
    return top1_acc, top5_acc


def get_model(model_name: str, device: torch.device):
    """Instantiate a pretrained classifier by short name and move it to ``device``.

    Args:
        model_name: One of ``"vit"``, ``"deit"`` or ``"swin"``.
        device: Target device for the model.

    Returns:
        The pretrained model, moved to ``device``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    if model_name == "vit":
        model = vit_base_patch16_224(pretrained=True)
    elif model_name == "deit":
        model = deit_base_patch16_224(pretrained=True)
    elif model_name == "swin":
        # `weights` must be a Swin_B_Weights member (or its string name).
        # torchvision rejects the enum class itself with a TypeError.
        # NOTE(review): torchvision weights ship their own preprocessing
        # (Swin_B_Weights.DEFAULT.transforms()), which likely differs from
        # get_transforms(). Confirm before comparing accuracies across models.
        model = swin_b(weights=Swin_B_Weights.DEFAULT)
    else:
        raise ValueError(
            f"Unknown model name {model_name!r}; expected one of 'vit', 'deit', 'swin'."
        )

    return model.to(device)


def run_model(args):
    """Load the classifier and diffusion model, then evaluate on an ImageFolder dataset.

    Args:
        args: Parsed command-line namespace. Reads ``model``, ``diffusion``,
            ``dataset_path``, ``batch_size``, ``attack`` and ``dds``.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    classifier = get_model(args.model, device)
    denoiser = create_diffusion_model(args.diffusion)

    dataset = datasets.ImageFolder(args.dataset_path, transform=get_transforms())
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,  # keep a fixed sample order for evaluation
        num_workers=4,
    )

    compute_accuracy(
        classifier,
        denoiser,
        loader,
        device,
        attack=args.attack,
        dds=args.dds,
    )


if __name__ == "__main__":
    # Command-line entry point: parse options and run one evaluation.
    parser = argparse.ArgumentParser()
    # Restrict --model to the names get_model() handles, so argparse rejects
    # a typo up front with a usage message instead of failing during model
    # loading.
    parser.add_argument("--model", type=str, choices=["vit", "deit", "swin"], default="vit", help="Name of the vit to be used")
    parser.add_argument("--attack", action='store_true', default=False, help="True if PGD attack is to be used")
    parser.add_argument("--batch-size", type=int, default=4, help='Batch size')
    parser.add_argument("--dds", action='store_true', default=False, help='To run DDS algorithm')
    parser.add_argument("--dataset-path", type=str, default="data_dir/val", help="Path to the dataset.")
    parser.add_argument("--diffusion", type=str, choices=['original', 'latent', 'stable'], default="original", help='which diffusion model to use')

    args = parser.parse_args()
    run_model(args)