import os 
import sys
YOUR_PATH = '/YOUR_ROOT_PATH'
sys.path.append(f'{YOUR_PATH}/') 

import argparse
import torch
import torch.nn.functional as F
import json
from tqdm import tqdm
import numpy as np
from collections import defaultdict
from train.apgd_train import apgd_train as apgd
from train import ComputeLossWrapper, str2bool, get_dataset
from CLIP_eval.eval_utils import load_clip_model
from torchvision import transforms
from open_flamingo.eval.classification_utils import IMAGENET_1K_CLASS_ID_TO_LABEL, IMAGENET_100_CLASS_ID_TO_LABEL
from lambda_net import LambdaNetworkFactory
import open_clip
from open_flamingo.eval.models.utils import unwrap_model
from torch.utils.data import DataLoader


# Define the wrapper, same as in adversarial_train-main2.py
class ClipVisionModel(torch.nn.Module):
    """Wrapper that applies a normalization step before a vision encoder.

    The three constructor arguments are stored unchanged on the instance:
    ``normalize`` is applied to the raw input, ``model`` is called on the
    result, and ``args`` is kept for reference only (this class never reads it).
    """

    def __init__(self, model, args, normalize):
        super().__init__()
        self.model = model
        self.args = args
        self.normalize = normalize

    def forward(self, vision, output_normalize):
        """Return ``model(normalize(vision))``.

        When ``output_normalize`` is truthy the result is additionally
        normalized along its last dimension.
        """
        features = self.model(self.normalize(vision))
        if not output_normalize:
            return features
        return F.normalize(features, dim=-1)

def load_models(checkpoint_path, args):
    """
    Build the vision model under analysis and the lambda network, then load their checkpoints.

    Steps:
        1. Create the CLIP model with ``pretrained='openai'`` (``model_orig``) together
           with its image transforms.
        2. Load the model to analyze via ``load_clip_model(args.clip_model_name, args.pretrained)``.
        3. Take the last entry of the image transforms as the normalization step.
        4. Wrap both visual towers in ``ClipVisionModel`` (inside ``DataParallel`` when more
           than one CUDA device is visible) and move them to CUDA.
        5. Create the lambda network via ``LambdaNetworkFactory.create_network``.
        6. If ``checkpoint_path`` is non-empty, load ``final.pt`` into the main model and,
           if the file exists, ``lambda_net_final.pt`` into the lambda network's ``mlp``.

    Args:
        checkpoint_path (str): Directory that contains ``final.pt`` and optionally
            ``lambda_net_final.pt``. An empty value skips checkpoint loading.
        args: Parsed arguments; this function reads ``clip_model_name``, ``pretrained``
            and ``lambda_net``.

    Returns:
        tuple: ``(model, lambda_network)`` -- the wrapped vision model and the lambda
        network. ``model_orig`` is only handed to the lambda-network factory and is
        not returned.
    """
    def _wrap_and_move(module):
        # Use DataParallel only when several GPUs are visible; always end up on CUDA.
        if torch.cuda.device_count() > 1:
            module = torch.nn.DataParallel(module)
        return module.cuda()

    # 1. Original CLIP model and its image processing transforms.
    model_orig, _, image_processor = open_clip.create_model_and_transforms(
        args.clip_model_name, pretrained='openai'
    )

    # 2. Model to analyze.
    model, _, _ = load_clip_model(args.clip_model_name, args.pretrained)

    # 3. The normalization is the last transform; it is applied inside the wrapper instead.
    normalize = image_processor.transforms[-1]

    # 4. Wrap the visual towers.
    model_orig = _wrap_and_move(
        ClipVisionModel(model=model_orig.visual, args=args, normalize=normalize)
    )
    model = _wrap_and_move(
        ClipVisionModel(model=model.visual, args=args, normalize=normalize)
    )

    # 5. Lambda network.
    lambda_network = LambdaNetworkFactory.create_network(
        args.lambda_net, model_orig=model_orig, clip_model_name=args.clip_model_name
    )
    lambda_network = _wrap_and_move(lambda_network)

    # 6. Checkpoints (optional).
    if not checkpoint_path:
        return model, lambda_network

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): torch.load unpickles the file -- only point this at trusted checkpoints.
    model_checkpoint_path = os.path.join(checkpoint_path, 'final.pt')
    checkpoint = torch.load(model_checkpoint_path, map_location=device)
    # The checkpoint is either a bare state dict or a dict that nests it.
    state_dict = checkpoint['model_state_dict'] if 'model_state_dict' in checkpoint else checkpoint
    # Stored keys are prefixed with "model." so they match ClipVisionModel's `model` attribute.
    adjusted_state_dict = {f"model.{k}": v for k, v in state_dict.items()}
    unwrap_model(model).load_state_dict(adjusted_state_dict)

    lambda_checkpoint_path = os.path.join(checkpoint_path, 'lambda_net_final.pt')
    if os.path.exists(lambda_checkpoint_path):
        lambda_checkpoint = torch.load(lambda_checkpoint_path, map_location=device)
        if isinstance(lambda_checkpoint, dict) and 'lambda_state_dict' in lambda_checkpoint:
            lambda_state_dict = lambda_checkpoint['lambda_state_dict']
        else:
            lambda_state_dict = lambda_checkpoint
        # Only the `mlp` sub-module of the lambda network is restored from the checkpoint.
        unwrap_model(lambda_network).mlp.load_state_dict(lambda_state_dict)

    return model, lambda_network

def _mean_or_nan(values):
    """Return the mean of ``values`` as a float; NaN (without a NumPy warning) when empty."""
    values = np.asarray(values)
    if values.size == 0:
        return float('nan')
    return float(np.mean(values))


def analyze_lambda_network(model, lambda_network, dataloader, model_orig, args, n_iter=10):
    """
    Run the lambda network on clean and APGD-perturbed batches and collect statistics.

    For every batch the function computes ``model_orig`` embeddings, lambda values and
    zero-shot predictions on the clean images, generates adversarial images with
    ``apgd`` against ``model`` (cross-entropy loss), and repeats the same measurements
    on the adversarial images.

    Args:
        model: Model that is attacked and evaluated; called as
            ``model(images, output_normalize=...)``.
        lambda_network: Called as ``lambda_network(images, output_normalize, embedding_orig)``.
        dataloader: Iterable yielding ``(images, targets)`` batches.
        model_orig: Model providing the ``embedding_orig`` input of the lambda network.
        args: Namespace providing ``output_normalize``, ``embedding_text_labels_norm``,
            ``norm`` and ``eps``.
        n_iter (int): Number of APGD iterations (default 10).

    Returns:
        dict with keys ``'detailed_results'`` (per-sample lists), ``'summary_stats'``
        (floats; NaN when a group is empty) and ``'class_stats'`` (per true label).

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()
    lambda_network.eval()
    model_orig.eval()

    # The loss configuration does not depend on the batch, so build it once.
    loss_wrapper = ComputeLossWrapper(
        embedding_orig=None,
        embedding_text_labels_norm=args.embedding_text_labels_norm,
        reduction='none',
        loss='ce',
        logit_scale=100.
    )

    results = defaultdict(list)

    for data, targets in tqdm(dataloader, desc="Analyzing lambda network"):
        data = data.cuda()
        targets = targets.cuda()

        # Clean images: original embeddings, lambda values and predictions.
        with torch.no_grad():
            embedding_orig = model_orig(vision=data, output_normalize=args.output_normalize)
            lambda_values_clean = lambda_network(data, args.output_normalize, embedding_orig)

            embedding_clean = model(data, output_normalize=args.output_normalize)
            logits_clean = embedding_clean @ args.embedding_text_labels_norm
            preds_clean = logits_clean.max(dim=1)[1]

        # Adversarial examples (needs gradients, so outside no_grad).
        data_adv = apgd(
            model=model,
            loss_fn=loss_wrapper,
            x=data,
            y=targets,
            norm=args.norm,
            eps=args.eps,
            n_iter=n_iter,
            verbose=False
        )
        # NOTE(review): APGD helpers written for adversarial training may switch the
        # model back to train mode before returning -- confirm in train.apgd_train.
        # Re-asserting eval mode is a no-op if it does not.
        model.eval()

        # Adversarial images: same measurements as for the clean batch.
        with torch.no_grad():
            embedding_orig_adv = model_orig(vision=data_adv, output_normalize=args.output_normalize)
            lambda_values_adv = lambda_network(data_adv, args.output_normalize, embedding_orig_adv)

            embedding_adv = model(data_adv, output_normalize=args.output_normalize)
            logits_adv = embedding_adv @ args.embedding_text_labels_norm
            preds_adv = logits_adv.max(dim=1)[1]

            correct_clean = preds_clean.eq(targets)
            correct_adv = preds_adv.eq(targets)
            is_robust = correct_clean & correct_adv          # right before and after the attack
            is_vulnerable = correct_clean & ~correct_adv     # right before, wrong after

            # Cosine similarity between clean and adversarial embeddings.
            cos_sim = F.cosine_similarity(embedding_clean, embedding_adv, dim=1)

        batch_records = {
            'lambda_values_clean': lambda_values_clean,
            'lambda_values_adv': lambda_values_adv,
            'correct_clean': correct_clean,
            'correct_adv': correct_adv,
            'is_robust': is_robust,
            'is_vulnerable': is_vulnerable,
            'true_labels': targets,
            'pred_clean': preds_clean,
            'pred_adv': preds_adv,
            'embedding_similarity': cos_sim,
        }
        for key, tensor in batch_records.items():
            results[key].extend(tensor.cpu().tolist())

    if not results:
        raise ValueError("dataloader yielded no batches; nothing to analyze")

    # Convert lists to numpy arrays for masking and aggregation.
    arrays = {key: np.array(values) for key, values in results.items()}
    lam_clean = arrays['lambda_values_clean']
    lam_adv = arrays['lambda_values_adv']
    correct_adv_all = arrays['correct_adv']

    # Summary statistics; group means are NaN when the group has no samples.
    stats = {
        'mean_lambda_clean': _mean_or_nan(lam_clean),
        'std_lambda_clean': float(np.std(lam_clean)),
        'mean_lambda_adv': _mean_or_nan(lam_adv),
        'std_lambda_adv': float(np.std(lam_adv)),
        'mean_lambda_robust': _mean_or_nan(lam_clean[arrays['is_robust']]),
        'mean_lambda_vulnerable': _mean_or_nan(lam_clean[arrays['is_vulnerable']]),
        'mean_lambda_adv_correct': _mean_or_nan(lam_adv[correct_adv_all]),
        'mean_lambda_adv_incorrect': _mean_or_nan(lam_adv[~correct_adv_all]),
        'clean_accuracy': _mean_or_nan(arrays['correct_clean']),
        'robust_accuracy': _mean_or_nan(correct_adv_all),
        'mean_embedding_similarity': _mean_or_nan(arrays['embedding_similarity']),
    }

    # Class-wise statistics, keyed by the true label.
    class_stats = defaultdict(dict)
    true_labels = arrays['true_labels']
    for class_idx in np.unique(true_labels):
        class_mask = true_labels == class_idx
        class_stats[int(class_idx)] = {
            'mean_lambda_clean': _mean_or_nan(lam_clean[class_mask]),
            'mean_lambda_adv': _mean_or_nan(lam_adv[class_mask]),
            'robust_accuracy': _mean_or_nan(correct_adv_all[class_mask]),
            'num_samples': int(np.sum(class_mask))
        }

    return {
        'detailed_results': {key: values.tolist() for key, values in arrays.items()},
        'summary_stats': stats,
        'class_stats': class_stats
    }

def save_results(output, filepath):
    """Save the analysis results to ``filepath`` as JSON indented by two spaces.

    Missing parent directories are created first, so a finished analysis run is
    not lost because the output folder does not exist yet.

    Args:
        output: JSON-serializable object to write.
        filepath: Destination path of the JSON file.
    """
    parent_dir = os.path.dirname(filepath)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(filepath, 'w', encoding='utf-8') as f:
        json.dump(output, f, indent=2)

def _parse_args():
    """Parse the command-line options of the lambda-network analysis script."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', type=str, required=True, help='Path to model checkpoint')
    parser.add_argument('--output', type=str, default='lambda_analysis.json', help='Output file path')
    parser.add_argument('--pretrained', type=str, default='openai', help='Pretrained model')
    parser.add_argument('--clip_model_name', type=str, default='ViT-B-32', help='CLIP model architecture')
    parser.add_argument('--dataset', type=str, default='imagenet', help='Dataset to analyze')
    parser.add_argument('--imagenet_root', type=str, required=True, help='ImageNet dataset root directory')
    parser.add_argument('--output_normalize', type=str2bool, default=False, help='Whether to normalize output embeddings')
    parser.add_argument('--norm', type=str, default='linf', help='Norm for adversarial perturbation')
    parser.add_argument('--eps', type=float, default=4,
                        help='Epsilon for adversarial perturbation (in [0, 255] range, will be divided by 255)')
    parser.add_argument('--template', type=str, default='std', help='Template for text prompts')
    parser.add_argument('--lambda_net', type=str, default='linear_mlp', help='Lambda network architecture')
    parser.add_argument('--batch_size', type=int, default=488, help='Evaluation batch size')
    parser.add_argument('--num_workers', type=int, default=32, help='Number of DataLoader workers')
    return parser.parse_args()


def _build_text_label_embeddings(text_model, classes, template, clip_model_name):
    """Encode one prompt per class and return the embeddings as a (dim, num_classes) CUDA tensor.

    ``text_model`` must already be on CUDA. The prompts are encoded in two halves
    to limit peak memory, and the result is checked for unit-norm columns and for
    the embedding width expected for ``clip_model_name``.

    Raises:
        ValueError: If ``clip_model_name`` is not one of the known architectures.
    """
    texts = [template.format(c) for c in classes.values()]
    text_tokens = open_clip.tokenize(texts)
    half = len(classes) // 2

    with torch.no_grad():
        chunks = [
            text_model.encode_text(chunk.cuda(), normalize=True).detach().cpu()
            for chunk in (text_tokens[:half], text_tokens[half:])
        ]
        embeddings = torch.cat(chunks).T.cuda()

    # Sanity check: each column (one per class) must already be normalized.
    assert torch.allclose(F.normalize(embeddings, dim=0), embeddings)

    # Sanity check: embedding width for the supported architectures.
    if clip_model_name in ('ViT-B-32', 'ViT-B-32-quickgelu', 'ViT-B-16'):
        assert embeddings.shape == (512, len(classes)), embeddings.shape
    elif clip_model_name in ('ViT-L-14', 'ViT-L-14-336'):
        assert embeddings.shape == (768, len(classes)), embeddings.shape
    else:
        raise ValueError(f'Unknown model: {clip_model_name}')

    return embeddings


def main():
    """Load the models, attack the ImageNet validation set and save the lambda-network analysis."""
    args = _parse_args()

    # Fail early on a missing model checkpoint; a missing lambda checkpoint is only a warning.
    model_checkpoint = os.path.join(args.checkpoint, 'final.pt')
    lambda_checkpoint = os.path.join(args.checkpoint, 'lambda_net_final.pt')
    if not os.path.exists(model_checkpoint):
        raise FileNotFoundError(f"Model checkpoint not found at {model_checkpoint}")
    if not os.path.exists(lambda_checkpoint):
        print(f"Warning: Lambda network checkpoint not found at {lambda_checkpoint}")

    if args.dataset != 'imagenet':
        # The dataset and class names below are fixed to ImageNet-1k.
        print(f"Warning: --dataset={args.dataset} is ignored; only 'imagenet' is analyzed")

    args.eps /= 255  # Scale epsilon to [0, 1] range

    model, lambda_network = load_models(args.checkpoint, args)

    # One OpenAI CLIP instance provides both the image transforms and the text encoder.
    model_text, _, image_processor = open_clip.create_model_and_transforms(
        args.clip_model_name, pretrained='openai'
    )
    # Drop the final Normalize transform: normalization happens inside the model wrapper.
    preprocessor_without_normalize = transforms.Compose(image_processor.transforms[:-1])

    _, val_dataset = get_dataset(
        dataset_name='imagenet',
        transform=preprocessor_without_normalize,
        imagenet_root=args.imagenet_root
    )
    dataloader = DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers
    )
    classes = IMAGENET_1K_CLASS_ID_TO_LABEL

    # --- Compute and assign text label embeddings ---
    if args.template == 'std':
        template = 'This is a photo of a {}'
    elif args.template == 'blurry':
        template = 'This is a blurry photo of a {}'
    else:
        raise ValueError(f'Unknown template: {args.template}')
    print(f'template: {template}')

    model_text = model_text.cuda()
    embedding_text_labels_norm = _build_text_label_embeddings(
        model_text, classes, template, args.clip_model_name
    )

    # The text encoder is no longer needed; free its GPU memory before the analysis.
    del model_text
    torch.cuda.empty_cache()
    args.embedding_text_labels_norm = embedding_text_labels_norm

    # --- Run analysis (using the loaded model as both model and model_orig) ---
    # NOTE(review): load_models does not return the original CLIP model, so the
    # fine-tuned model also supplies `embedding_orig` here -- confirm this is intended.
    output = analyze_lambda_network(model, lambda_network, dataloader, model, args)

    save_results(output, args.output)
    print(f"Analysis complete. Results saved to {args.output}")


if __name__ == "__main__":
    main()