import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os
import numpy as np
from models.resnet import ResNet18
import matplotlib.pyplot as plt
from tqdm import tqdm

# Per-channel normalization statistics (CIFAR-10), laid out as (1, 3, 1, 1)
# so they broadcast against (N, 3, H, W) image batches.
MEAN = torch.tensor([0.4914, 0.4822, 0.4465]).reshape(1, 3, 1, 1)
STD = torch.tensor([0.2023, 0.1994, 0.2010]).reshape(1, 3, 1, 1)

def tv_loss(img):
    """Return the total variation of a batch of images as a scalar tensor.

    Adds up the absolute differences between neighbouring entries along the
    last axis and along the second-to-last axis of a 4-D tensor, summed over
    the whole batch. Smoother images produce a smaller value.
    """
    # Differences between neighbours along the last axis (width).
    width_diff = img[:, :, :, :-1] - img[:, :, :, 1:]
    # Differences between neighbours along the second-to-last axis (height).
    height_diff = img[:, :, :-1, :] - img[:, :, 1:, :]
    return width_diff.abs().sum() + height_diff.abs().sum()

def psnr(img1, img2):
    """Compute the PSNR (in dB) between two images with values in [0, 1].

    Args:
        img1: First image tensor.
        img2: Second image tensor, same shape as ``img1``.

    Returns:
        A 0-dim tensor holding the PSNR. When the inputs are identical
        (zero mean-squared error) the tensor holds ``+inf``.
    """
    mse = F.mse_loss(img1, img2)
    if mse == 0:
        # Return a tensor (not a Python float) so that callers can treat the
        # result uniformly, e.g. call ``.item()`` on it.
        return torch.full_like(mse, float('inf'))
    return 20 * torch.log10(1.0 / torch.sqrt(mse))

def ssim(img1, img2, window_size=11, size_average=True):
    """Compute SSIM between two image batches with values in [0, 1].

    Local statistics are estimated with a normalized 2-D Gaussian window
    (sigma 1.5) applied per channel through a grouped convolution with
    ``window_size // 2`` zero padding.

    Args:
        img1: First image batch; indexed as a 4-D tensor whose dim 1 is the
            channel axis.
        img2: Second image batch, same shape (and dtype) as ``img1``.
        window_size: Side length of the Gaussian window.
        size_average: If True, return the mean over the whole SSIM map;
            otherwise return one value per batch element.

    Returns:
        A 0-dim tensor when ``size_average`` is True, else a 1-D tensor with
        one entry per batch element.
    """
    # Stabilizing constants for the luminance and contrast terms.
    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    # Create a 1D Gaussian window
    def gaussian(window_size, sigma):
        gauss = torch.Tensor([np.exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
        return gauss/gauss.sum()

    def create_window(window_size, channel):
        # Outer product of the 1-D kernel with itself, replicated once per
        # channel so it can be used as a depthwise (grouped) filter.
        _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
        _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
        window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
        return window

    channel = img1.size(1)
    window = create_window(window_size, channel).to(img1.device)
    if img1.is_floating_point():
        # conv2d requires the filter dtype to match the input; the window is
        # built in float32, so follow float64/float16 inputs explicitly.
        window = window.to(img1.dtype)

    pad = window_size // 2

    def local_mean(x):
        # Gaussian-weighted local average, computed independently per channel.
        return F.conv2d(x, window, padding=pad, groups=channel)

    mu1 = local_mean(img1)
    mu2 = local_mean(img2)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    # Local (co)variances via E[xy] - E[x]E[y].
    sigma1_sq = local_mean(img1 * img1) - mu1_sq
    sigma2_sq = local_mean(img2 * img2) - mu2_sq
    sigma12 = local_mean(img1 * img2) - mu1_mu2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        # Successively average away channel, height and width.
        return ssim_map.mean(1).mean(1).mean(1)

def reconstruct_batch(model, target_features, img_shape=(3, 32, 32), iterations=3000, lr=0.05):
    """Reconstruct a batch of images whose features match ``target_features``.

    A latent tensor ``z`` (initialized to zeros) is optimized with Adam under
    a cosine-annealed learning rate. At every step ``sigmoid(z)`` is taken as
    the current image in [0, 1], circularly shifted by a random offset in
    [-3, 3] pixels along each spatial axis (one offset for the whole batch),
    normalized with the module-level MEAN/STD, and passed through
    ``model(..., return_features=True)``. The loss is the MSE between those
    features and ``target_features`` plus a small total-variation penalty on
    the unshifted image.

    Args:
        model: Network called as ``model(x, return_features=True)``. It is
            put in eval mode; its weights are frozen for the duration of the
            optimization and their ``requires_grad`` flags restored on exit.
        target_features: Features to match; dim 0 is the batch size and the
            tensor's device is used for all computation.
        img_shape: Shape of one reconstructed image (without the batch dim).
        iterations: Number of optimization steps.
        lr: Initial Adam learning rate.

    Returns:
        Detached tensor of shape ``(batch_size, *img_shape)`` with values in
        [0, 1].
    """
    model.eval()
    device = target_features.device
    batch_size = target_features.shape[0]

    mean = MEAN.to(device)
    std = STD.to(device)

    # Only ``z`` is optimized. Freezing the weights keeps backward() from
    # computing gradients for them and accumulating those gradients in the
    # caller's model on every iteration.
    params = list(model.parameters())
    grad_flags = [p.requires_grad for p in params]
    for p in params:
        p.requires_grad_(False)

    try:
        # 1. Initialization (sigmoid(0) = 0.5, i.e. a uniform gray image)
        z = torch.zeros((batch_size, *img_shape), requires_grad=True, device=device)

        # 2. Optimizer
        optimizer = optim.Adam([z], lr=lr)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=iterations)

        pbar = tqdm(range(iterations), desc="Optimizing Batch")
        for i in pbar:
            optimizer.zero_grad()

            # Map to [0, 1]
            img = torch.sigmoid(z)

            # 3. Random Jitter (Batch version): circular shift of the image
            off_x = np.random.randint(-3, 4)
            off_y = np.random.randint(-3, 4)
            img_shifted = torch.roll(img, shifts=(off_x, off_y), dims=(2, 3))

            # 4. Forward
            img_norm = (img_shifted - mean) / std
            features = model(img_norm, return_features=True)

            # Loss: Feature MSE + TV (TV is summed over the batch, so it is
            # divided by the batch size to get a per-image penalty)
            dist_loss = F.mse_loss(features, target_features)
            reg_loss = 1e-4 * tv_loss(img) / batch_size

            total_loss = dist_loss + reg_loss

            total_loss.backward()
            optimizer.step()
            scheduler.step()

            if (i+1) % 100 == 0:
                pbar.set_postfix(loss=f"{total_loss.item():.6f}")
    finally:
        # Leave the model's trainability exactly as the caller had it.
        for p, flag in zip(params, grad_flags):
            p.requires_grad_(flag)

    return torch.sigmoid(z).detach()

def run_large_scale_attack(task_id, num_samples=256, batch_size=128):
    """Run the feature-inversion attack for one task and save the results.

    Loads ``exp_data/task_{task_id}_privacy_exp.pt`` and
    ``checkpoints/task_{task_id}_models.pt``, reconstructs images from each
    stored feature variant ('raw_features', 'mixup_only', 'dp_only',
    'full_lifil') in batches, and writes to ``fi_results_task_{task_id}``:

    * ``individual/sample_XXX.png`` - ground truth next to every
      reconstruction, one figure per sample;
    * ``summary.txt`` - PSNR, SSIM and attack accuracy, computed for the
      'raw_features' reconstructions only;
    * ``paper_comparison_top8.png`` - the same comparison for the first
      (up to) 8 samples.

    Args:
        task_id: Identifier used to build the input and output paths.
        num_samples: Number of samples to attack; clamped to the number of
            stored ground-truth images.
        batch_size: Number of samples reconstructed together.

    Returns:
        None. Prints an error and returns early if the input files are
        missing or there are no samples to process.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    data_path = f'exp_data/task_{task_id}_privacy_exp.pt'
    ckpt_path = f'checkpoints/task_{task_id}_models.pt'

    if not os.path.exists(data_path) or not os.path.exists(ckpt_path):
        print(f"Error: Task {task_id} data or checkpoint not found.")
        return

    # NOTE(security): torch.load unpickles its input; only load files from a
    # trusted source.
    # map_location lets files saved on a GPU load on a CPU-only machine; the
    # data stays on the CPU and is moved to the device one batch at a time.
    data = torch.load(data_path, map_location='cpu')
    ckpt = torch.load(ckpt_path, map_location=device)

    # Never index past the stored samples (assumes 'raw_images' is indexed
    # by sample along dim 0, as the slicing below does).
    num_samples = min(num_samples, len(data['raw_images']))
    if num_samples <= 0:
        print(f"Error: Task {task_id} has no samples to attack.")
        return

    # Create output directory
    output_dir = f'fi_results_task_{task_id}'
    individual_dir = os.path.join(output_dir, 'individual')
    os.makedirs(individual_dir, exist_ok=True)

    # Initialize model
    model_args = {'latent_dim': 512, 'num_classes': 10, 'input_channels': 3}
    model = ResNet18(**model_args).to(device)
    model.load_state_dict(ckpt['resnet'])

    # Preparation
    baselines = ['raw_features', 'mixup_only', 'dp_only', 'full_lifil']
    titles = ['Ground Truth', 'Raw Feature', 'Mixup Only', 'DP Only', 'Full Li-FIL']
    n_cols = len(titles)

    # Normalization constants on the working device (hoisted out of the loops)
    mean = MEAN.to(device)
    std = STD.to(device)

    # Metrics storage - only for raw_features
    raw_psnr_list = []
    raw_ssim_list = []
    raw_attack_acc_list = []

    # Images for the paper figure: one list of chunks per column, filled from
    # the FIRST samples processed until max_paper_rows have been collected.
    max_paper_rows = 8
    paper_columns = [[] for _ in titles]
    paper_rows = 0

    for start_idx in range(0, num_samples, batch_size):
        end_idx = min(start_idx + batch_size, num_samples)
        current_bs = end_idx - start_idx
        print(f"\n>>> Processing samples {start_idx} to {end_idx}...")

        # Get Ground Truth (denormalized for metric calculation and display):
        # img = normalized_img * std + mean
        gt_imgs_norm = data['raw_images'][start_idx:end_idx].to(device)
        gt_imgs = torch.clamp(gt_imgs_norm * std + mean, 0, 1)

        # Get ground truth labels
        gt_labels = data['raw_labels'][start_idx:end_idx].to(device)

        batch_results = [gt_imgs.cpu()]

        for key in baselines:
            print(f"Attacking {key}...")
            target_f = data[key][start_idx:end_idx].to(device)
            recon = reconstruct_batch(model, target_f, iterations=3000)

            # Only calculate metrics for raw_features
            if key == 'raw_features':
                # Normalize reconstructed images for model prediction
                recon_norm = (recon - mean) / std

                # Get model predictions
                model.eval()
                with torch.no_grad():
                    outputs = model(recon_norm)
                    _, predicted = torch.max(outputs, 1)

                # Calculate metrics for each sample in batch
                for i in range(current_bs):
                    recon_i = recon[i:i+1]
                    gt_i = gt_imgs[i:i+1]

                    # float() accepts both a 0-dim tensor and a plain float,
                    # so a perfect reconstruction (infinite PSNR) cannot crash.
                    raw_psnr_list.append(float(psnr(recon_i, gt_i)))
                    raw_ssim_list.append(ssim(recon_i, gt_i).item())

                    # Attack Accuracy (whether predicted label matches ground truth)
                    is_correct = (predicted[i] == gt_labels[i]).item()
                    raw_attack_acc_list.append(is_correct)

            batch_results.append(recon.cpu())

        # Keep the leading samples for the paper figure.
        if paper_rows < max_paper_rows:
            take = min(max_paper_rows - paper_rows, current_bs)
            for column, res_batch in zip(paper_columns, batch_results):
                column.append(res_batch[:take])
            paper_rows += take

        # Save individual images
        print("Saving individual results...")
        for i in range(current_bs):
            sample_idx = start_idx + i
            plt.figure(figsize=(15, 3))
            for b_idx, res_batch in enumerate(batch_results):
                plt.subplot(1, n_cols, b_idx + 1)
                img_np = res_batch[i].permute(1, 2, 0).numpy()
                plt.imshow(np.clip(img_np, 0, 1))
                plt.title(titles[b_idx])
                plt.axis('off')

            plt.tight_layout()
            plt.savefig(os.path.join(individual_dir, f'sample_{sample_idx:03d}.png'))
            plt.close()

    # --- Quantified Results Analysis ---
    print("\n" + "="*50)
    print("FINAL QUANTITATIVE RESULTS (Raw Features Only)")
    print("="*50)

    avg_psnr = np.mean(raw_psnr_list)
    std_psnr = np.std(raw_psnr_list)
    avg_ssim = np.mean(raw_ssim_list)
    std_ssim = np.std(raw_ssim_list)
    avg_attack_acc = np.mean(raw_attack_acc_list) * 100  # Convert to percentage

    summary_lines = [
        f"PSNR:        {avg_psnr:.2f} dB (±{std_psnr:.2f})",
        f"SSIM:        {avg_ssim:.4f} (±{std_ssim:.4f})",
        f"Attack Acc.: {avg_attack_acc:.2f}%"
    ]
    for line in summary_lines:
        print(line)

    # Explicit encoding: the summary contains the non-ASCII character '±'.
    with open(os.path.join(output_dir, 'summary.txt'), 'w', encoding='utf-8') as f:
        f.write("\n".join(summary_lines))

    # Save a comparison plot for the paper (first 8 samples, or fewer if
    # fewer were processed)
    paper_imgs = [torch.cat(column) for column in paper_columns]
    plt.figure(figsize=(15, 2.5 * paper_rows))
    for i in range(paper_rows):
        for b_idx in range(n_cols):
            plt.subplot(paper_rows, n_cols, i * n_cols + b_idx + 1)
            img = paper_imgs[b_idx][i].permute(1, 2, 0).numpy()
            plt.imshow(np.clip(img, 0, 1))
            if i == 0:
                plt.title(titles[b_idx])
            plt.axis('off')
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'paper_comparison_top8.png'))
    plt.close()
    print(f"\nResults saved in {output_dir}")

# Script entry point: run the attack for task 0 on 256 samples, 128 per batch.
if __name__ == '__main__':
    run_large_scale_attack(0, num_samples=256, batch_size=128)
