import os
import sys
import argparse
import json
import random
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
from torchvision.datasets import ImageFolder
import clip
from torch.utils.data import DataLoader, Subset

# ==========================================
# Make the sibling 'model-soups' directory importable
# ==========================================
# Directory layout assumed by the path arithmetic below:
#   <project_root>/<directory of this script>/<this script>
#   <project_root>/model-soups/   (provides the 'utils' module)
current_file_path = os.path.abspath(__file__)
current_dir = os.path.dirname(current_file_path)      # e.g. .../dbViz/sharpness
project_root = os.path.dirname(current_dir)           # e.g. .../dbViz
model_soups_dir = os.path.join(project_root, 'model-soups')

# Prepend (rather than append) so this 'utils' takes precedence over any
# other importable module with the same name.
sys.path.insert(0, model_soups_dir)

try:
    from utils import get_model_from_sd
except ImportError as e:
    # Report where we looked before propagating the original error.
    print(f"Error: Could not import utils from {model_soups_dir}")
    print(f"Current sys.path: {sys.path}")
    raise e
else:
    print(f"Successfully imported 'get_model_from_sd' from {model_soups_dir}")

# ==========================================
# 1. Sharpness computation (SAM and ASAM)
# ==========================================

def compute_sam_sharpness(net, dataloader, criterion, rho=0.05, device='cuda'):
    """Compute standard SAM sharpness, averaged over batches.

    For each batch: compute the loss L(w) and its gradient g at the current
    weights, shift the trainable weights by e = rho * g / ||g|| (an L2 step of
    length rho along the gradient), re-evaluate the loss, and record
    max(0, L(w + e) - L(w)). The original weights are restored after every
    batch, even if the perturbed forward pass raises.

    Args:
        net: Model to evaluate. Parameters are reached through
            ``named_parameters()``, so a ``DataParallel`` wrapper works too.
        dataloader: Iterable of ``(inputs, targets)`` batches.
        criterion: Loss called as ``criterion(outputs, targets)``.
        rho: L2 radius of the perturbation.
        device: Device each batch is moved to.

    Returns:
        float: Mean of the per-batch sharpness values.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    net.eval()
    # Only trainable parameters are perturbed, so only they are snapshotted.
    params = {name: p for name, p in net.named_parameters() if p.requires_grad}
    # detach() keeps the snapshot out of the autograd graph.
    original_params = {name: p.detach().clone() for name, p in params.items()}

    total_sharpness = 0.0
    batch_sharpness = []

    pbar = tqdm(dataloader, desc=f"SAM (rho={rho})")
    for inputs, targets in pbar:
        inputs, targets = inputs.to(device), targets.to(device)

        # First pass: loss and gradients at the unperturbed weights.
        net.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        loss_value = loss.item()

        # Global L2 norm of the gradient over every parameter that received one.
        # The gradients are only read, so no copies are needed.
        grads = {name: p.grad for name, p in params.items() if p.grad is not None}
        if grads:
            grad_norm = torch.sqrt(sum(torch.norm(g) ** 2 for g in grads.values())).item()
        else:
            grad_norm = 0.0

        try:
            # Perturb in place: w <- w + rho * g / ||g||.
            if grad_norm > 1e-8:
                scale = rho / (grad_norm + 1e-8)
                with torch.no_grad():
                    for name, g in grads.items():
                        params[name].add_(g, alpha=scale)

            # Second pass: loss at the perturbed weights. DataParallel
            # re-broadcasts parameters on each forward, so all replicas see them.
            with torch.no_grad():
                loss_perturbed = criterion(net(inputs), targets).item()
        finally:
            # Copy the snapshot straight back into each parameter.
            with torch.no_grad():
                for name, p in params.items():
                    p.copy_(original_params[name])

        sharpness_batch = max(0.0, loss_perturbed - loss_value)
        total_sharpness += sharpness_batch
        batch_sharpness.append(sharpness_batch)

        pbar.set_postfix({'Avg Sharpness': total_sharpness / len(batch_sharpness)})

    # Release the last batch's gradients so they do not hold memory afterwards.
    net.zero_grad(set_to_none=True)

    if not batch_sharpness:
        raise ValueError("compute_sam_sharpness: dataloader yielded no batches")
    return total_sharpness / len(batch_sharpness)

def compute_adaptive_sam_sharpness(net, dataloader, criterion, rho=0.05, device='cuda', lambda_reg=1e-8):
    """Compute Adaptive SAM (ASAM) sharpness, averaged over batches.

    With the element-wise scaling T_w = |w| + lambda_reg, each batch's
    perturbation is e = rho * T_w^2 * g / ||T_w * g||, where g is the loss
    gradient at the current weights. The batch sharpness is
    max(0, L(w + e) - L(w)). The original weights are restored after every
    batch, even if the perturbed forward pass raises.

    Args:
        net: Model to evaluate. Parameters are reached through
            ``named_parameters()``, so a ``DataParallel`` wrapper works too.
        dataloader: Iterable of ``(inputs, targets)`` batches.
        criterion: Loss called as ``criterion(outputs, targets)``.
        rho: Radius of the perturbation in the T_w-rescaled space.
        device: Device each batch is moved to.
        lambda_reg: Constant added to |w| so zero-valued weights still get a
            non-zero scaling.

    Returns:
        float: Mean of the per-batch ASAM sharpness values.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    net.eval()
    # Build the name -> parameter map once, instead of once per parameter per batch.
    params = {name: p for name, p in net.named_parameters() if p.requires_grad}
    # detach() keeps the snapshot out of the autograd graph.
    original_params = {name: p.detach().clone() for name, p in params.items()}

    total_sharpness = 0.0
    batch_sharpness = []

    pbar = tqdm(dataloader, desc=f"ASAM (rho={rho})")
    for inputs, targets in pbar:
        inputs, targets = inputs.to(device), targets.to(device)

        # First pass: loss and gradients at the unperturbed weights.
        net.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        loss_value = loss.item()

        grads = {name: p.grad for name, p in params.items() if p.grad is not None}

        # ||T_w * g||. T_w is recomputed later rather than stored, which avoids
        # holding extra full-model copies of |w| and of the gradients.
        if grads:
            with torch.no_grad():
                sq_norm = sum(
                    torch.sum(((params[name].abs() + lambda_reg) * g) ** 2)
                    for name, g in grads.items()
                )
            adaptive_grad_norm = torch.sqrt(sq_norm).item()
        else:
            adaptive_grad_norm = 0.0

        try:
            if adaptive_grad_norm > 1e-8:
                scale = rho / (adaptive_grad_norm + 1e-8)
                with torch.no_grad():
                    for name, g in grads.items():
                        p = params[name]
                        # T_w is taken from p before p itself is shifted, so it
                        # equals the value used in the norm above.
                        t_w_sq = (p.abs() + lambda_reg) ** 2
                        p.add_(t_w_sq * g, alpha=scale)

            # Second pass: loss at the perturbed weights.
            with torch.no_grad():
                loss_perturbed = criterion(net(inputs), targets).item()
        finally:
            # Copy the snapshot straight back into each parameter.
            with torch.no_grad():
                for name, p in params.items():
                    p.copy_(original_params[name])

        sharpness_batch = max(0.0, loss_perturbed - loss_value)
        total_sharpness += sharpness_batch
        batch_sharpness.append(sharpness_batch)

        pbar.set_postfix({'Avg ASAM': total_sharpness / len(batch_sharpness)})

    # Release the last batch's gradients so they do not hold memory afterwards.
    net.zero_grad(set_to_none=True)

    if not batch_sharpness:
        raise ValueError("compute_adaptive_sam_sharpness: dataloader yielded no batches")
    return total_sharpness / len(batch_sharpness)

# ==========================================
# 2. Main Program
# ==========================================

def parse_arguments():
    """Parse command-line options for the sharpness evaluation.

    Returns:
        argparse.Namespace holding the path/model options (data_location,
        model_path, clip_backbone, output_dir) followed by the evaluation
        settings (rho, subset_size, batch_size, seed, workers).
    """
    cli = argparse.ArgumentParser(description='Sharpness Evaluation for CLIP on ImageNet')

    # (flag, value type, extra add_argument kwargs, help text), registered in order.
    option_specs = [
        # Paths and model
        ('--data_location', str, {'default': '/newdata_nvme/datasets/xxx/imagenet'}, 'ImageNet root'),
        ('--model_path', str, {'required': True}, 'Path to fine-tuned model .pt'),
        ('--clip_backbone', str, {'default': 'ViT-B/32'}, 'CLIP backbone'),
        ('--output_dir', str, {'default': './sharpness/results'}, 'Output dir'),
        # Evaluation settings
        ('--rho', float, {'default': 0.05}, 'Perturbation radius'),
        ('--subset_size', int, {'default': 2048}, 'Number of samples to evaluate'),
        ('--batch_size', int, {'default': 256}, 'Batch size'),
        ('--seed', int, {'default': 0}, 'Random seed'),
        ('--workers', int, {'default': 8}, 'Num workers'),
    ]
    for flag, value_type, extra, help_text in option_specs:
        cli.add_argument(flag, type=value_type, help=help_text, **extra)

    return cli.parse_args()

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs (plus CUDA ones when present)."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def main():
    """Evaluate SAM and ASAM sharpness for one fine-tuned CLIP checkpoint.

    Steps: parse CLI args, load the CLIP backbone and its preprocessing, draw a
    random subset of the ImageNet training folder, rebuild the fine-tuned model
    with ``get_model_from_sd``, compute both sharpness scores, and append one
    JSON line to ``<output_dir>/sharpness_results_<rho>.jsonl``.
    """
    args = parse_arguments()
    print(json.dumps(vars(args), indent=2))

    device = "cuda" if torch.cuda.is_available() else "cpu"
    set_seed(args.seed)

    # 1. Prepare data
    # -------------------------------------------
    print(f"Loading CLIP backbone: {args.clip_backbone}")
    base_model, preprocess = clip.load(args.clip_backbone, device='cpu', jit=False)

    train_dir = _resolve_train_dir(args.data_location)
    print(f"Loading data from: {train_dir}")

    full_dataset = ImageFolder(root=train_dir, transform=preprocess)

    # Sample args.subset_size images without replacement. A dedicated generator
    # makes the subset depend only on --seed, not on how many global RNG draws
    # earlier code (e.g. clip.load) may have made.
    # NOTE: for a given seed this selects a different subset than versions of
    # this script that drew from the global RNG; do not mix their results.
    subset_rng = torch.Generator().manual_seed(args.seed)
    indices = torch.randperm(len(full_dataset), generator=subset_rng)[:args.subset_size].tolist()
    subset_dataset = Subset(full_dataset, indices)

    dataloader = DataLoader(
        subset_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        # Pinned host memory only helps host->GPU copies.
        pin_memory=(device == "cuda"),
    )

    # 2. Load fine-tuned weights
    # -------------------------------------------
    print(f"Loading weights from {args.model_path}")
    state_dict = torch.load(args.model_path, map_location='cpu')

    # Some checkpoints nest the weights under a 'state_dict' key.
    if 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']

    # get_model_from_sd comes from the model-soups 'utils' module. Presumably it
    # attaches the classification head and places the model on the GPU(s);
    # TODO confirm against model-soups/utils.py, since the batches are moved to `device`.
    model = get_model_from_sd(state_dict, base_model)

    # 3. Evaluate sharpness
    # -------------------------------------------
    criterion = nn.CrossEntropyLoss()

    results = {
        "model_name": os.path.basename(args.model_path),
        # Actual number of evaluated samples (fewer than requested if the dataset is smaller).
        "subset_size": len(subset_dataset),
        "rho": args.rho,
        # Recorded so the evaluated subset and model can be reproduced.
        "seed": args.seed,
        "clip_backbone": args.clip_backbone,
    }

    # SAM
    print("\n>>> Computing Standard SAM Sharpness...")
    sam_score = compute_sam_sharpness(model, dataloader, criterion, rho=args.rho, device=device)
    results["sharpness_sam"] = sam_score
    print(f"SAM Sharpness: {sam_score:.6f}")

    # Adaptive SAM
    print("\n>>> Computing Adaptive SAM Sharpness...")
    asam_score = compute_adaptive_sam_sharpness(model, dataloader, criterion, rho=args.rho, device=device)
    results["sharpness_adaptive"] = asam_score
    print(f"Adaptive SAM Sharpness: {asam_score:.6f}")

    # 4. Save: append one JSON object per run.
    os.makedirs(args.output_dir, exist_ok=True)
    output_path = os.path.join(args.output_dir, f"sharpness_results_{args.rho}.jsonl")

    with open(output_path, 'a', encoding='utf-8') as f:
        f.write(json.dumps(results) + "\n")
    print(f"Saved to {output_path}")


def _resolve_train_dir(data_location):
    """Return the ImageNet training folder for ``data_location``.

    Tries ``<data_location>/train``, then ``<data_location>/ILSVRC2012_img_train``,
    and otherwise treats ``data_location`` itself as the training folder.
    """
    for candidate in ('train', 'ILSVRC2012_img_train'):
        path = os.path.join(data_location, candidate)
        if os.path.exists(path):
            return path
    return data_location

if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not on import.
    main()