import torch
import torchvision
import torchvision.transforms as transforms
import argparse
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader, Subset
import clip
import torch.nn.functional as F
import torch.nn as nn
from diffusion_clip import SDE_DIFFUSION_CIFAR10_Creator,AP_Forward_MANI,Guided_Diffusion_ImageNet_Creator
from autoattack import AutoAttack
from attacks.bpda import BPDA
from attacks.pgd import eot_pgd_attack

# Device string shared by every model and tensor created in this script:
# CUDA when PyTorch reports it as available, otherwise the CPU.
device: str = "cuda" if torch.cuda.is_available() else "cpu"

# Synset id -> first listed class name; starts empty and is filled in
# evaluate_all() when the ImageNet branch reads ./data/class_map.txt.
class_map: dict = {}

def load_clip_model():
    """Load the CLIP model named by ``args.clip_model`` and tokenise ten class prompts.

    Reads the module-level ``args`` and ``device``.

    Returns:
        A ``(model, tokens, preprocess)`` tuple: the CLIP model switched to
        eval mode, the tokenised prompts moved to ``device``, and the
        preprocessing transform returned by ``clip.load``.
    """
    clip_net, preprocess = clip.load(args.clip_model, device=device)
    clip_net.eval()

    # Ten prompts; their order fixes which logit index belongs to which class.
    class_nouns = (
        "an airplane", "an automobile", "a bird", "a cat", "a deer",
        "a dog", "a frog", "a horse", "a ship", "a truck",
    )
    prompts = [f"a photo of {noun}" for noun in class_nouns]
    return clip_net, clip.tokenize(prompts).to(device), preprocess


def parse_args(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        argparse.Namespace with one attribute per option. ``phase`` and
        ``amplitude`` are always lists of ints. ``iters`` mirrors
        ``pgd_iters`` because the evaluation code reads the iteration count
        under that name.
    """
    parser = argparse.ArgumentParser(description="CLIP Evaluation with Diffusion-based Purification")
    parser.add_argument("--num_test", type=int, default=512, help="Number of test samples")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
    parser.add_argument("--epsilon", type=float, default=0.03, help="Attack epsilon")
    parser.add_argument("--alpha", type=float, default=0.005, help="Attack step size")
    parser.add_argument("--pgd_iters", type=int, default=10, help="Number of PGD iterations")
    parser.add_argument("--eot_iter", type=int, default=10, help="Number of EOT iterations")
    parser.add_argument("--attack_t", type=int, default=100, help="Attack diffusion steps")
    parser.add_argument("--reverse_t", type=int, default=5, help="Reverse diffusion steps")
    parser.add_argument("--purify_t", type=int, default=100, help="Purification diffusion steps")
    parser.add_argument("--purify_reverse_t", type=int, default=100, help="Purification reverse diffusion steps")
    parser.add_argument("--data_seed", type=int, default=0, help="Random seed for data sampling")
    parser.add_argument("--noise_schedule", type=str, default="Original", help="Noise schedule")
    parser.add_argument("--sample_schedule", type=str, default="DDPM", help="Sampling schedule")
    parser.add_argument("--Imagenet", type=int, default=0, help="Use ImageNet dataset")
    # With nargs='+' a user-supplied value is a list, so the defaults must be
    # lists too; a bare int default cannot be iterated by the sweep loops.
    parser.add_argument("--phase", type=int, nargs='+', default=[9], help="Phase parameter for purification")
    parser.add_argument("--amplitude", type=int, nargs='+', default=[13], help="Amplitude parameter for purification")
    # The previous default ("attack norm") was help text, not a norm name.
    # NOTE(review): "Linf" is the spelling upstream AutoAttack accepts;
    # confirm eot_pgd_attack uses the same spelling.
    parser.add_argument("--norm", type=str, default="Linf", help="Attack norm")
    parser.add_argument("--clip_model", type=str, default="ViT-L/14")
    # evaluate_all() branches on args.method; these are the three values it handles.
    parser.add_argument("--method", type=str, default="pgd",
                        choices=("pgd", "autoattack", "bpda"),
                        help="Attack used to craft adversarial examples")

    parsed = parser.parse_args(argv)
    # The evaluation code reads the iteration count as ``args.iters``.
    parsed.iters = parsed.pgd_iters
    return parsed


def set_seed(seed):
    """Seed the NumPy and PyTorch random number generators with *seed*.

    Args:
        seed: Integer seed passed to ``np.random.seed``,
            ``torch.manual_seed`` and, when CUDA is available,
            ``torch.cuda.manual_seed_all``.
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

class CLIPWrapper(nn.Module):
    """Expose a CLIP model as a classifier that returns image-to-text logits.

    ``forward`` normalises its input with the stored per-channel mean and
    standard deviation, calls the wrapped CLIP model with the stored text
    tokens, and returns the first element of the result.

    Args:
        clip_model: Callable taking ``(images, text_tokens)`` and returning a
            pair whose first element is the logits tensor.
        text_tokens: Tokenised prompts forwarded to ``clip_model`` on every call.
    """

    def __init__(self, clip_model, text_tokens):
        super().__init__()
        self.clip_model = clip_model
        self.text_tokens = text_tokens
        # Registered as buffers so that ``.to()`` / ``.cuda()`` on the wrapper
        # moves them with the module; non-persistent so the state_dict keys
        # are the same as when these were plain attributes.
        self.register_buffer(
            "clip_mean",
            torch.tensor([0.48145466, 0.4578275, 0.40821073], device=device).view(1, 3, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "clip_std",
            torch.tensor([0.26862954, 0.26130258, 0.27577711], device=device).view(1, 3, 1, 1),
            persistent=False,
        )

    def forward(self, x):
        """Return the logits for image batch *x*.

        The (1, 3, 1, 1) statistics broadcast over a batch with three channels.
        """
        x = (x - self.clip_mean) / self.clip_std
        logits, _ = self.clip_model(x, self.text_tokens)
        return logits
    


def _load_class_map(path):
    """Fill the module-level ``class_map`` (synset id -> first class name) from *path*."""
    with open(path, 'r') as f:
        for line in f:
            parts = line.strip().split()
            if len(parts) < 4:
                continue
            synset_id = parts[1]
            names = ' '.join(parts[3:]).split(',')
            class_map[synset_id] = names[0].strip()


def _as_list(value):
    """Return *value* as a list, wrapping a scalar in a one-element list."""
    return list(value) if isinstance(value, (list, tuple)) else [value]


def _count_correct(logits, labels):
    """Return how many rows of *logits* have their argmax equal to *labels*."""
    return (logits.argmax(1) == labels).sum().item()


def _build_attack(method, clip_model, wrapped_model, text_tokens, iters):
    """Return a callable ``attack(x, y) -> x_adv`` for the requested *method*.

    Only the attacker that is actually needed is constructed.

    Raises:
        ValueError: if *method* is not ``"pgd"``, ``"autoattack"`` or ``"bpda"``.
    """
    if method == "pgd":
        def pgd_attack(x, y):
            return eot_pgd_attack(clip_model, x, y, eps=args.epsilon, alpha=args.alpha,
                                  iters=iters, eot=args.eot_iter,
                                  text_tokens=text_tokens, norm=args.norm)
        return pgd_attack

    if method == "autoattack":
        # NOTE(review): ``alpha`` is not a constructor argument of upstream
        # AutoAttack; presumably the vendored copy accepts it - confirm.
        aa_attacker = AutoAttack(
            model=wrapped_model,
            norm=args.norm,
            eps=args.epsilon,
            version='standard',
            verbose=True,
            alpha=2
        )

        def auto_attack(x, y):
            return aa_attacker.run_standard_evaluation(x, y, bs=args.batch_size).clamp(0, 1)
        return auto_attack

    if method == "bpda":
        bpda_attacker = BPDA(wrapped_model, attack_steps=iters, eps=args.epsilon,
                             step_size=args.alpha, eot=args.eot_iter)
        return bpda_attacker.forward

    raise ValueError(f"Unknown attack method: {method!r} (expected 'pgd', 'autoattack' or 'bpda')")


def evaluate_all():
    """Report clean and adversarial accuracy, with and without purification.

    Reads the module-level ``args``. For every (phase, amplitude) pair a
    purifier is built, and four accuracies are printed over the sampled test
    subset: CLIP on clean images, CLIP on purified clean images, CLIP on
    adversarial images, and CLIP on purified adversarial images.

    Raises:
        ValueError: if the attack method is not recognised, or if the
            ImageNet class map yields no class names.
    """
    set_seed(args.data_seed)
    clip_model, text_tokens, _ = load_clip_model()
    if args.Imagenet == 1:
        print("ImageNet")
        # The class map must be read BEFORE the prompts are built from it;
        # otherwise the prompt list is empty.
        _load_class_map("./data/class_map.txt")
        if not class_map:
            raise ValueError("No class names could be read from ./data/class_map.txt")
        imagenet_labels = [f"a photo of a {name}" for name in class_map.values()]
        text_tokens = clip.tokenize(imagenet_labels).to(device)
        diffusion = Guided_Diffusion_ImageNet_Creator()
    else:
        diffusion = SDE_DIFFUSION_CIFAR10_Creator()
        print("CIFAR10")

    clip_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073], device=device).view(1, 3, 1, 1)
    clip_std = torch.tensor([0.26862954, 0.26130258, 0.27577711], device=device).view(1, 3, 1, 1)

    wrapped_model = CLIPWrapper(clip_model, text_tokens)

    # The parser defines ``--pgd_iters``; ``--method`` may be absent from
    # older parsers, in which case PGD (the first branch) is used.
    attack = _build_attack(getattr(args, "method", "pgd"), clip_model, wrapped_model,
                           text_tokens, args.pgd_iters)

    # The test subset does not depend on phase/amplitude, so build it once.
    # NOTE(review): CIFAR-10 is loaded even when --Imagenet=1; an ImageNet
    # loader is presumably intended for that branch - confirm.
    # NOTE(review): images are used at dataset resolution with no resize;
    # confirm that the attack/purifier/CLIP path handles the input size.
    transform = transforms.ToTensor()
    dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    indices = np.random.RandomState(seed=args.data_seed).choice(len(dataset), args.num_test, False)
    test_loader = DataLoader(Subset(dataset, indices), batch_size=args.batch_size, shuffle=False)

    # _as_list tolerates a scalar phase/amplitude (e.g. a non-list default).
    for phase in _as_list(args.phase):
        for amplitude in _as_list(args.amplitude):
            purify_model = AP_Forward_MANI(
                diffusion=diffusion,
                attack_steps=args.purify_t,
                denoising_steps=args.purify_reverse_t,
                sampling_method=args.sample_schedule,
                clip_mean=clip_mean,
                clip_std=clip_std,
                amplitude=amplitude,
                phase=phase,
                Imagenet=args.Imagenet
            ).to(device)
            purify_model.eval()

            metrics = {'clean_clip': 0, 'adv_clip': 0, 'clean_purified': 0, 'adv_purified': 0, 'total1': 0}

            for x, y in tqdm(test_loader, desc="Evaluating clean data"):
                x, y = x.to(device), y.to(device)
                metrics['total1'] += x.size(0)

                # Normalise exactly as the adversarial path does, so that the
                # clean and adversarial accuracies are comparable.
                # NOTE(review): the purifier now receives the normalised
                # batch on both paths (previously raw pixels on the clean
                # path only) - confirm AP_Forward_MANI's expected input range.
                x_norm = (x - clip_mean) / clip_std
                with torch.no_grad():
                    logits_clean, _ = clip_model(x_norm, text_tokens)
                    metrics['clean_clip'] += _count_correct(logits_clean, y)
                    purified_clean = purify_model(x_norm)
                    logits_clean_purified, _ = clip_model(purified_clean, text_tokens)
                    metrics['clean_purified'] += _count_correct(logits_clean_purified, y)

                # Attacks need gradients, so they run outside torch.no_grad().
                x_adv = attack(x, y)

                x_adv_norm = (x_adv - clip_mean) / clip_std
                with torch.no_grad():
                    logits_adv, _ = clip_model(x_adv_norm, text_tokens)
                    metrics['adv_clip'] += _count_correct(logits_adv, y)
                    purified_adv = purify_model(x_adv_norm)
                    logits_adv_purified, _ = clip_model(purified_adv, text_tokens)
                    metrics['adv_purified'] += _count_correct(logits_adv_purified, y)

            total1 = metrics['total1']
            if total1 == 0:
                # An empty subset (e.g. --num_test 0) would otherwise divide by zero.
                print("No samples evaluated; skipping accuracy report.")
                continue
            # Identify the configuration so results of a sweep can be told apart.
            print(f"phase={phase}, amplitude={amplitude}")
            print(f" Clean:{100 * metrics['clean_clip'] / total1:.2f}%")
            print(f"StandardAcc: {100 * metrics['clean_purified'] / total1:.2f}%")
            print(f"Adversarial:{100 * metrics['adv_clip'] / total1:.2f}%")
            print(f"Robust:{100 * metrics['adv_purified'] / total1:.2f}%")
          
if __name__ == "__main__":
    # `args` is intentionally bound at module level: load_clip_model() and
    # evaluate_all() read it as a global instead of taking it as a parameter.
    args = parse_args()
    evaluate_all()
