import sys
# Extend the import path so the `train.*` / `CLIP_eval.*` imports below can resolve.
# NOTE(review): '/YOUR_ROOT_PATH' looks like a placeholder that must be edited
# to point at the real checkout before running — confirm.
sys.path.append('/YOUR_ROOT_PATH/src') 
sys.path.append('/YOUR_ROOT_PATH/src/train') 
# Cache directory passed as `cache_dir` to open_clip when creating models.
# The spelling "CHACHE" is kept as-is because main() refers to this exact name.
CHACHE_DIR = '/mnt/raid10/ak-research-01/ak-research-01/codes/.cache'

import numpy as np
# Re-create the `np.float_` / `np.complex_` aliases, which NumPy 2.0 removed.
# Presumably needed by one of the libraries imported afterwards — TODO confirm which.
np.float_ = np.float64
np.complex_ = np.complex128

from train.datasets import COCOFlickrDataset, ImageNetDataset 
from CLIP_eval.eval_utils import load_clip_model, load_clip_model_convnext, load_clip_model_conv 

import os
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
import open_clip
from open_flamingo.eval.classification_utils import IMAGENET_1K_CLASS_ID_TO_LABEL, IMAGENET_100_CLASS_ID_TO_LABEL
from train.apgd_train import apgd_train as apgd
from train.utils import str2bool, AverageMeter
import argparse
from torchvision.datasets import CIFAR10, CIFAR100
from tqdm import tqdm 
import time

# Command-line interface of the benchmark script.
parser = argparse.ArgumentParser(description='Benchmark a CLIP vision model on the entire test set')
parser.add_argument('--model_path', type=str, required=True, help='Path to the model checkpoint')
parser.add_argument('--clip_model_name', type=str, default='ViT-L-14', help='ViT-L-14, ViT-B-32') 
parser.add_argument('--dataset', type=str, default='imagenet', help='Dataset to evaluate on (imagenet, imagenet100, cifar10, cifar100)')
parser.add_argument('--imagenet_root', type=str, default='/mnt/datasets/imagenet', help='Imagenet dataset root directory')
parser.add_argument('--per_device_batch_size', type=int, default=128, help='Batch size per device')
# NOTE: within this file `--attack` is only echoed in the printed report;
# adversarial examples are always produced by the imported `apgd` routine.
parser.add_argument('--attack', type=str, default='apgd', help='Adversarial attack type')
parser.add_argument('--norm', type=str, default='linf', help='Norm for adversarial perturbation (linf, l2)')
# Given on the 0-255 scale; the `__main__` block divides it by 255 before calling main().
parser.add_argument('--eps', type=float, default=4, help='Epsilon for adversarial perturbation (divided by 255 internally)')
parser.add_argument('--iterations_adv', type=int, default=50, help='Iterations for adversarial attack')
parser.add_argument('--num_workers', type=int, default=4, help='Number of workers for data loading')

# Class-id -> label mapping. Rebound by get_dataset() (via `global classes`)
# and read by main() to build the text prompts.
classes = {}

class ClipVisionModel(torch.nn.Module):
    """Vision encoder bundled with its input normalization step.

    Inputs are passed through ``normalize`` and then through ``model``; the
    resulting embedding can optionally be normalized along its last dimension.
    """

    def __init__(self, model, normalize):
        """Store the encoder (``model``) and the callable applied to its inputs."""
        super().__init__()
        self.model = model
        self.normalize = normalize

    def forward(self, vision, output_normalize=True):
        """Embed a batch of images.

        Args:
            vision: input handed to ``self.normalize`` before the encoder.
            output_normalize: when true, the embedding is rescaled with
                ``F.normalize(..., dim=-1)``; otherwise it is returned raw.

        Returns:
            The encoder output, normalized or raw depending on the flag.
        """
        features = self.model(self.normalize(vision))
        if not output_normalize:
            return features
        return F.normalize(features, dim=-1)

class ComputeLossWrapper:
    """Callable cross-entropy loss over scaled image/text similarity logits.

    Logits are obtained by multiplying the image embedding with the stored
    text-label embedding matrix, pre-multiplied by ``logit_scale``.
    """

    def __init__(self, embedding_text_labels_norm, reduction='none', logit_scale=100.):
        """Keep the text embedding matrix, the reduction mode and the logit scale."""
        self.embedding_text_labels_norm = embedding_text_labels_norm
        self.reduction = reduction
        self.logit_scale = logit_scale

    def __call__(self, embedding, targets):
        """Return the cross-entropy of ``embedding`` against ``targets``.

        If ``embedding`` is a tuple, only its first element is used.
        """
        if isinstance(embedding, tuple):
            embedding = embedding[0]
        scaled_text = self.logit_scale * self.embedding_text_labels_norm
        return F.cross_entropy(embedding @ scaled_text, targets, reduction=self.reduction)

@torch.no_grad()
def compute_acc(logits, targets):
    """Return the top-1 accuracy of ``logits`` against ``targets`` as a percentage.

    The prediction for each row is the index of its maximum along dim 1; the
    number of matches is divided by ``targets.shape[0]`` and scaled by 100.
    """
    _, predicted = logits.max(dim=1)
    num_correct = predicted.detach().eq(targets).sum()
    return (num_correct / targets.shape[0]).item() * 100

@torch.no_grad()
def compute_embedding_distances(embedding_clean, embedding_adv, embedding_orig=None, embedding_orig_adv=None):
    """
    Summarize how far apart several sets of embeddings are.

    All values are batch means (reduced over dim 1, then averaged) returned as
    Python floats.

    Args:
        embedding_clean: evaluated model's embeddings of the clean inputs.
        embedding_adv: evaluated model's embeddings of the adversarial inputs.
        embedding_orig: optional reference embeddings of the clean inputs.
        embedding_orig_adv: optional reference embeddings of the adversarial
            inputs; only used when ``embedding_orig`` is also given.

    Returns:
        dict mapping metric name to float. ``L2_*`` entries are mean L2
        distances and ``Cosine_sim_*`` entries are mean cosine similarities.
        ``L2_clean_orig_normalized`` is the mean clean/orig L2 distance divided
        by the mean L2 norm of ``embedding_orig`` (a ratio of batch means).
    """
    def mean_l2(a, b):
        return torch.norm(a - b, dim=1).mean().item()

    def mean_cosine(a, b):
        return F.cosine_similarity(a, b, dim=1).mean().item()

    has_orig = embedding_orig is not None

    # L2 block first, cosine block second, so the key order matches the report layout.
    metrics = {'L2_adv_clean': mean_l2(embedding_adv, embedding_clean)}
    if has_orig:
        metrics['L2_adv_orig'] = mean_l2(embedding_adv, embedding_orig)
        metrics['L2_clean_orig'] = mean_l2(embedding_clean, embedding_orig)
        mean_orig_norm = torch.norm(embedding_orig, dim=1).mean().item()
        metrics['L2_clean_orig_normalized'] = metrics['L2_clean_orig'] / mean_orig_norm

    metrics['Cosine_sim_adv_clean'] = mean_cosine(embedding_adv, embedding_clean)
    if has_orig:
        metrics['Cosine_sim_adv_orig'] = mean_cosine(embedding_adv, embedding_orig)
        metrics['Cosine_sim_clean_orig'] = mean_cosine(embedding_clean, embedding_orig)

    # Drift of the reference model itself between clean and adversarial inputs.
    if has_orig and embedding_orig_adv is not None:
        metrics['L2_orig_adv_orig_clean'] = mean_l2(embedding_orig_adv, embedding_orig)
        metrics['Cosine_sim_orig_adv_orig_clean'] = mean_cosine(embedding_orig_adv, embedding_orig)

    return metrics

def _accumulate_confusion(confusion, targets, preds):
    """Add one count to ``confusion[target, pred]`` for every sample of a batch."""
    rows = targets.long()
    cols = preds.long()
    # accumulate=True makes repeated (target, pred) pairs add up instead of
    # overwriting each other, so the whole batch is handled in one tensor op.
    confusion.index_put_((rows, cols), torch.ones_like(rows, dtype=confusion.dtype), accumulate=True)


def evaluate_model(model, model_orig, dataloader, embedding_text_labels_norm, args):
    """
    Evaluate the model on the entire test set, computing clean and robust accuracy.

    For every batch the function measures clean accuracy, generates adversarial
    examples with ``apgd`` against ``model``, measures robust accuracy on them,
    and accumulates embedding-distance statistics and confusion matrices.

    Args:
        model: model under evaluation, called as ``model(x, output_normalize=False)``.
        model_orig: optional reference model with the same call signature; when
            ``None`` the reference-model distance metrics are omitted.
        dataloader: iterable yielding ``(data, targets)`` batches.
        embedding_text_labels_norm: text-label embedding matrix. Image
            embeddings are right-multiplied by it to obtain logits, and its
            second dimension is taken as the number of classes. Batches are
            moved to the device this tensor lives on.
        args: namespace providing ``dataset``, ``norm``, ``eps`` (already scaled
            to [0, 1]) and ``iterations_adv``.

    Returns:
        dict with overall accuracies, averaged distance metrics, timings, the
        sample count, per-class accuracies and the best/worst classes (at most
        5 each, fewer when there are fewer classes).
    """
    model.eval()
    if model_orig is not None:
        model_orig.eval()

    # Follow the text embeddings' device instead of hard-coding `.cuda()`.
    device = embedding_text_labels_norm.device

    # Meters for tracking metrics
    acc_meter = AverageMeter('Clean Accuracy')
    racc_meter = AverageMeter('Robust Accuracy')

    # One meter per distance metric; names equal the keys produced by
    # compute_embedding_distances for the given set of inputs.
    distance_names = ['L2_adv_clean', 'Cosine_sim_adv_clean']
    if model_orig is not None:
        distance_names += [
            'L2_adv_orig',
            'L2_clean_orig',
            'L2_clean_orig_normalized',
            'Cosine_sim_adv_orig',
            'Cosine_sim_clean_orig',
            'L2_orig_adv_orig_clean',
            'Cosine_sim_orig_adv_orig_clean',
        ]
    distance_meters = {name: AverageMeter(name) for name in distance_names}

    # Confusion matrices (rows: true class, columns: predicted class)
    num_classes = embedding_text_labels_norm.shape[1]
    clean_confusion = torch.zeros(num_classes, num_classes, device=device)
    robust_confusion = torch.zeros(num_classes, num_classes, device=device)

    # The loss wrapper does not depend on the batch, so build it once.
    loss_eval_wrapper = ComputeLossWrapper(
        embedding_text_labels_norm=embedding_text_labels_norm,
        reduction='none'
    )

    # Track computation time
    clean_time = 0
    robust_time = 0
    attack_time = 0

    total_samples = 0

    print(f"Starting evaluation on {args.dataset} with {args.norm} norm and ε={args.eps*255:.1f}/255...")
    for data, targets in tqdm(dataloader, desc="Evaluating"):
        data = data.to(device)
        targets = targets.to(device)
        batch_size = data.shape[0]
        total_samples += batch_size

        # Get original model embedding if available
        embedding_orig = None
        if model_orig is not None:
            with torch.no_grad():
                embedding_orig = model_orig(data, output_normalize=False)

        # Clean accuracy
        clean_start = time.time()
        with torch.no_grad():
            embedding_clean = model(data, output_normalize=False)
            logits_clean = embedding_clean @ embedding_text_labels_norm
            acc_meter.update(compute_acc(logits_clean, targets), batch_size)
            _accumulate_confusion(clean_confusion, targets, logits_clean.argmax(dim=1))
        clean_time += time.time() - clean_start

        # Generate adversarial examples (needs gradients, so outside no_grad)
        attack_start = time.time()
        data_adv = apgd(
            model=model,
            loss_fn=loss_eval_wrapper,
            x=data,
            y=targets,
            norm=args.norm,
            eps=args.eps,
            n_iter=args.iterations_adv,
            verbose=False
        )
        attack_time += time.time() - attack_start

        # Robust accuracy
        robust_start = time.time()
        with torch.no_grad():
            embedding_adv = model(data_adv, output_normalize=False)
            logits_adv = embedding_adv @ embedding_text_labels_norm

            # Get original model embedding for adversarial examples if available
            embedding_orig_adv = None
            if model_orig is not None:
                embedding_orig_adv = model_orig(data_adv, output_normalize=False)

            racc_meter.update(compute_acc(logits_adv, targets), batch_size)
            _accumulate_confusion(robust_confusion, targets, logits_adv.argmax(dim=1))

            # Compute embedding distances and fold them into the running averages
            distance_metrics = compute_embedding_distances(
                embedding_clean,
                embedding_adv,
                embedding_orig,
                embedding_orig_adv
            )
            for name, meter in distance_meters.items():
                meter.update(distance_metrics[name], batch_size)
        robust_time += time.time() - robust_start

    # Per-class accuracy; clamp avoids dividing by zero for classes never seen.
    per_class_clean_acc = torch.diag(clean_confusion) / clean_confusion.sum(dim=1).clamp(min=1)
    per_class_robust_acc = torch.diag(robust_confusion) / robust_confusion.sum(dim=1).clamp(min=1)

    # topk raises when k exceeds the number of classes, so cap it.
    top_k = min(5, num_classes)

    results = {
        'Clean Accuracy': acc_meter.avg,
        'Robust Accuracy': racc_meter.avg,
        'Robustness Drop': acc_meter.avg - racc_meter.avg,
        'Clean Evaluation Time': clean_time,
        'Robust Evaluation Time': robust_time,
        'Attack Generation Time': attack_time,
        'Total Samples': total_samples,
        'Worst-5 Clean Classes': per_class_clean_acc.topk(k=top_k, largest=False),
        'Worst-5 Robust Classes': per_class_robust_acc.topk(k=top_k, largest=False),
        'Best-5 Clean Classes': per_class_clean_acc.topk(k=top_k, largest=True),
        'Best-5 Robust Classes': per_class_robust_acc.topk(k=top_k, largest=True),
        'Per-Class Clean Accuracy': per_class_clean_acc,
        'Per-Class Robust Accuracy': per_class_robust_acc,
    }
    # Averaged distance metrics (reference-model ones only if model_orig was given)
    results.update({name: meter.avg for name, meter in distance_meters.items()})

    return results

def get_dataset(dataset_name, transform, imagenet_root=None):
    """
    Build the evaluation dataset and publish its label mapping.

    As a side effect the module-level ``classes`` is rebound to a mapping from
    class index to label for the chosen dataset.

    Args:
        dataset_name: one of 'imagenet', 'imagenet100', 'cifar10', 'cifar100'.
        transform: transform forwarded to the dataset constructor.
        imagenet_root: ImageNet root directory; its 'val' subfolder is used.
            When ``None``, the value is read from the module-level ``args``,
            which is only assigned when this file runs as a script.

    Returns:
        The constructed dataset object.

    Raises:
        ValueError: if ``dataset_name`` is not one of the supported names.
    """
    global classes
    if imagenet_root is None:
        imagenet_root = args.imagenet_root

    if dataset_name in ('imagenet', 'imagenet100'):
        # Both variants read the same validation folder; only the label table differs.
        dataset_eval = ImageNetDataset(
            root=imagenet_root + '/val',
            transform=transform,
        )
        if dataset_name == 'imagenet':
            classes = IMAGENET_1K_CLASS_ID_TO_LABEL
        else:
            classes = IMAGENET_100_CLASS_ID_TO_LABEL
    elif dataset_name in ('cifar10', 'cifar100'):
        if dataset_name == 'cifar10':
            dataset_cls = CIFAR10
            cifar_root = '/mnt/raid10/ak-research-01/ak-research-01/codes/.cache/cifar'
        else:
            dataset_cls = CIFAR100
            cifar_root = './cifar'
        dataset_eval = dataset_cls(
            root=cifar_root,
            train=False,
            download=True,
            transform=transform,
        )
        # Index -> label, in the order the dataset lists its classes.
        classes = dict(enumerate(dataset_eval.classes))
    else:
        raise ValueError(f'Dataset {dataset_name} not supported')

    return dataset_eval

def _print_class_ranking(title, ranking, results, classes):
    """Print one ranked class list with clean/robust accuracy and the drop between them."""
    values, indices = ranking
    print(title)
    for i, idx in enumerate(indices):
        class_name = classes[idx.item()]
        clean_acc = results['Per-Class Clean Accuracy'][idx].item() * 100
        robust_acc = values[i].item() * 100
        drop = clean_acc - robust_acc
        print(f"{i+1}. {class_name:20} - Clean: {clean_acc:.2f}% | Robust: {robust_acc:.2f}% | Drop: {drop:.2f}%")


def print_results(results, args, classes):
    """
    Print evaluation results in a well-formatted way.

    Args:
        results: dict of metrics as returned by ``evaluate_model``. The
            reference-model distance entries are optional and printed only
            when present.
        args: namespace providing ``model_path``, ``dataset``,
            ``clip_model_name``, ``attack``, ``norm`` and ``eps`` (on the
            [0, 1] scale; it is multiplied by 255 for display).
        classes: mapping from class index to label. The class-wise section is
            printed only when it has at most 100 entries.

    Ratios whose denominator is zero (retention with 0% clean accuracy, time
    per sample with no samples) are reported as "N/A" instead of raising.
    """
    print("\n" + "="*80)
    print(f"BENCHMARK RESULTS FOR MODEL: {args.model_path}")
    print(f"Dataset: {args.dataset} | Model: {args.clip_model_name} | Attack: {args.attack}-{args.norm} (ε={args.eps*255:.1f}/255)")
    print("="*80)

    print(f"\n{'OVERALL METRICS':^80}")
    print("-"*80)
    print(f"Clean Accuracy:          {results['Clean Accuracy']:.2f}%")
    print(f"Robust Accuracy:         {results['Robust Accuracy']:.2f}%")
    print(f"Robustness Drop:         {results['Robustness Drop']:.2f}%")
    clean_accuracy = results['Clean Accuracy']
    if clean_accuracy:
        retention = f"{(results['Robust Accuracy'] / clean_accuracy * 100):.2f}%"
    else:
        # Retention is undefined when nothing was classified correctly on clean data.
        retention = "N/A"
    print(f"Robustness Retention:    {retention}")

    print(f"\n{'EMBEDDING DISTANCE METRICS':^80}")
    print("-"*80)
    print(f"L2 (Adv, Clean):                   {results['L2_adv_clean']:.4f}")
    print(f"Cosine Similarity (Adv, Clean):    {results['Cosine_sim_adv_clean']:.4f}")

    if 'L2_adv_orig' in results:
        print(f"L2 (Adv, Original):                {results['L2_adv_orig']:.4f}")
        print(f"L2 (Clean, Original):              {results['L2_clean_orig']:.4f}")
        print(f"Normalized L2 (Clean, Original):   {results['L2_clean_orig_normalized']:.4f}")
        print(f"Cosine Similarity (Adv, Original): {results['Cosine_sim_adv_orig']:.4f}")
        print(f"Cosine Similarity (Clean, Orig):   {results['Cosine_sim_clean_orig']:.4f}")

        # Print new metrics if available
        if 'L2_orig_adv_orig_clean' in results:
            print(f"L2 (Orig-Adv, Orig-Clean):        {results['L2_orig_adv_orig_clean']:.4f}")
            print(f"Cosine Sim (Orig-Adv, Orig-Clean): {results['Cosine_sim_orig_adv_orig_clean']:.4f}")

    print(f"\n{'TIMING INFORMATION':^80}")
    print("-"*80)
    print(f"Clean Evaluation Time:   {results['Clean Evaluation Time']:.2f}s")
    print(f"Attack Generation Time:  {results['Attack Generation Time']:.2f}s")
    print(f"Robust Evaluation Time:  {results['Robust Evaluation Time']:.2f}s")
    print(f"Total Samples Evaluated: {results['Total Samples']}")
    total_samples = results['Total Samples']
    elapsed = (results['Clean Evaluation Time']
               + results['Attack Generation Time']
               + results['Robust Evaluation Time'])
    per_sample = f"{elapsed / total_samples:.4f}s" if total_samples else "N/A"
    print(f"Time per Sample:         {per_sample}")

    # Print worst and best classes if there aren't too many
    if len(classes) <= 100:
        print(f"\n{'CLASS-WISE PERFORMANCE':^80}")
        print("-"*80)

        _print_class_ranking(
            "\nWorst 5 Classes (Robust Accuracy):",
            results['Worst-5 Robust Classes'], results, classes)
        _print_class_ranking(
            "\nBest 5 Classes (Robust Accuracy):",
            results['Best-5 Robust Classes'], results, classes)

    print("\n" + "="*80)

def main(args):
    """
    Load the reference and the checkpointed CLIP models and benchmark the latter.

    Steps: create the reference model and its image transforms with open_clip,
    build the evaluation dataset (without the final normalization transform,
    which is applied inside the model wrapper instead), encode one text prompt
    per class with the reference model, load the checkpoint under test, wrap
    both vision towers in ``ClipVisionModel``, run ``evaluate_model`` and print
    the report.

    Args:
        args: parsed command-line namespace; ``eps`` must already be on the
            [0, 1] scale.
    """
    # Model names that are created without `pretrained='openai'` and loaded
    # with `load_clip_model_conv`; the value is the identifier given to open_clip.
    hub_models = {
        "hf-hub:laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg":
            "hf-hub:laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg",
        "ViT-B-16-laion2B": 'hf-hub:laion/CLIP-ViT-B-16-laion2B-s34B-b88K',
        "hf-hub:laion/CLIP-ViT-B-32-laion2B-s34B-b79K": 'hf-hub:laion/CLIP-ViT-B-32-laion2B-s34B-b79K',
    }

    # Set device and print basic information
    cuda_available = torch.cuda.is_available()
    device = torch.device('cuda' if cuda_available else 'cpu')
    print(f"Using device: {device}")
    if cuda_available:
        print(f"CUDA device: {torch.cuda.get_device_name(0)}")

    # Reference (original) CLIP model and its image transforms
    print(f"Loading CLIP model: {args.clip_model_name}")
    hub_id = hub_models.get(args.clip_model_name)
    if hub_id is None:
        model_orig, _, image_processor = open_clip.create_model_and_transforms(
            args.clip_model_name, pretrained='openai', cache_dir=CHACHE_DIR)
    else:
        model_orig, _, image_processor = open_clip.create_model_and_transforms(
            hub_id, cache_dir=CHACHE_DIR)

    # The last transform is split off so it can run inside the model wrapper,
    # after any adversarial perturbation has been applied to the input.
    pipeline_steps = image_processor.transforms
    preprocessor_without_normalize = transforms.Compose(pipeline_steps[:-1])
    normalize = pipeline_steps[-1]

    # Dataset and loader
    print(f"Loading dataset: {args.dataset}")
    dataset_eval = get_dataset(args.dataset, preprocessor_without_normalize)
    print(f"Dataset size: {len(dataset_eval)} samples")

    dataloader_eval = DataLoader(
        dataset_eval,
        batch_size=args.per_device_batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
    )

    # Text embeddings: one prompt per class, encoded in two halves.
    print("Preparing text embeddings...")
    prompts = ['This is a photo of a {}'.format(label) for label in classes.values()]
    text_tokens = open_clip.tokenize(prompts)

    model_orig.to(device)
    split = len(classes) // 2
    with torch.no_grad():
        text_features = [
            model_orig.encode_text(chunk.to(device), normalize=True).detach()
            for chunk in (text_tokens[:split], text_tokens[split:])
        ]
        # Transposed so that `image_embedding @ matrix` yields per-class logits.
        embedding_text_labels_norm = torch.cat(text_features).T.to(device)

    # Model under evaluation
    print(f"Loading pretrained model from: {args.model_path}")
    print(f"@@@ clip_model_name: {args.clip_model_name},  model_path: {args.model_path}")
    if args.clip_model_name in hub_models:
        model, _, _ = load_clip_model_conv(args.clip_model_name, args.model_path)
    else:
        model, _, _ = load_clip_model(args.clip_model_name, args.model_path)

    num_gpus = torch.cuda.device_count()
    multi_gpu = num_gpus > 1
    print(f'Number of GPUs available: {num_gpus}' if multi_gpu else 'No multiple GPUs available.')

    def to_parallel_and_cuda(module):
        # NOTE: always moves to CUDA, regardless of `device` selected above.
        wrapped = torch.nn.DataParallel(module) if multi_gpu else module
        return wrapped.cuda()

    # Only the vision towers are needed from here on.
    model_orig.cpu()
    model_orig_wrapped = to_parallel_and_cuda(
        ClipVisionModel(model=model_orig.visual, normalize=normalize))
    model = to_parallel_and_cuda(
        ClipVisionModel(model=model.visual, normalize=normalize))

    # Evaluate
    print("\nStarting full benchmark evaluation...")
    start_time = time.time()
    results = evaluate_model(model, model_orig_wrapped, dataloader_eval, embedding_text_labels_norm, args)
    total_time = time.time() - start_time

    # Add total time to results
    results['Total Time'] = total_time

    # Print results
    print_results(results, args, classes)

    print(f"\nTotal benchmark time: {total_time:.2f} seconds ({total_time/60:.2f} minutes)")

# Script entry point: parse the CLI, rescale epsilon and run the benchmark.
if __name__ == '__main__':
    # NOTE: `args` becomes a module-level global here; get_dataset() reads it
    # for `imagenet_root` when no root is passed explicitly, so keep the name.
    args = parser.parse_args()
    # Scale epsilon to [0,1] range
    args.eps /= 255
    main(args)
