import os
import argparse
import numpy as np
import torch
import utils
import random
import torchvision
from torchinfo import summary
from utils import load_pretrain, MLP
from utils import load_data
from utils import classify
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from tqdm import trange
from torch.utils.data import DataLoader, Subset
from sklearn.cluster import KMeans
from sklearn.cluster import MiniBatchKMeans
from sklearn.cluster import SpectralClustering
from sklearn.mixture import GaussianMixture







def cluster_features(features, method='sc', n_clusters=None):
    """Cluster feature vectors and return one cluster id per sample.

    Args:
        features: 2-D array-like (n_samples, n_features) that is passed
            unchanged to the scikit-learn estimator's ``fit_predict``.
        method: Name of the estimator to use:
            ``'kmeans'`` -> ``KMeans``,
            ``'mbk'``    -> ``MiniBatchKMeans`` (batch_size=256),
            ``'sc'``     -> ``SpectralClustering`` (nearest-neighbour affinity,
            20 neighbours, k-means label assignment),
            ``'gmm'``    -> ``GaussianMixture`` (diagonal covariance).
        n_clusters: Optional number of clusters/components. When ``None``
            (the default) each estimator keeps its library default, which is
            the original behaviour. Note that scikit-learn's
            ``GaussianMixture`` defaults to a single component.

    Returns:
        The label array produced by the estimator's ``fit_predict``.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    supported = ('kmeans', 'mbk', 'sc', 'gmm')
    # Validate first: previously an unknown method fell through every branch
    # and crashed with UnboundLocalError on the return statement.
    if method not in supported:
        raise ValueError(
            f"Unknown clustering method {method!r}; expected one of {supported}."
        )

    # Only forward the cluster count when the caller asked for one, so the
    # estimators' own defaults stay in effect otherwise.
    size_kwargs = {} if n_clusters is None else {'n_clusters': n_clusters}

    if method == 'kmeans':
        model = KMeans(random_state=42, **size_kwargs)
    elif method == 'mbk':
        model = MiniBatchKMeans(batch_size=256, random_state=42, **size_kwargs)
    elif method == 'sc':
        model = SpectralClustering(
            affinity='nearest_neighbors',
            n_neighbors=20,
            assign_labels='kmeans',
            random_state=42,
            **size_kwargs,
        )
    else:  # 'gmm' -- GaussianMixture calls the cluster count n_components
        gmm_kwargs = {} if n_clusters is None else {'n_components': n_clusters}
        model = GaussianMixture(covariance_type='diag', random_state=42, **gmm_kwargs)

    return model.fit_predict(features)




def project_onto_l2_ball(delta, eps):
    """Project ``delta`` onto the L2 ball of radius ``eps`` centred at zero.

    The norm is computed over all elements of ``delta`` regardless of its
    shape.

    Args:
        delta: Perturbation tensor of any shape.
        eps: Radius of the L2 ball.

    Returns:
        ``delta`` itself (the same object) when its norm is already within
        ``eps`` or is zero; otherwise a new tensor of the same shape rescaled
        so that its L2 norm equals ``eps``.
    """
    # reshape (rather than view) so non-contiguous tensors are accepted too.
    norm = torch.norm(delta.reshape(-1), p=2)
    if norm <= eps or norm == 0:
        return delta
    # Scaling the original tensor keeps its shape without a flatten/unflatten.
    return delta * (eps / norm)


def _anchor_loss(feat, anchors, loss_type):
    """Sum over anchors of the (squared) L2 distance between feat and each anchor."""
    diff = feat - anchors
    if loss_type == 'squared':
        return torch.sum(diff ** 2)
    # Any other value uses the sum of (unsquared) L2 distances.
    return torch.sum(torch.norm(diff, dim=1))


def optimize_members(
    members,                 # torch.Tensor [N, C, H, W] or list
    encoder,                 # encoder mapping images -> features (output shape [1, D])
    anchor_feats,            # numpy array or torch.Tensor [M, D]
    device='cuda',           # device
    iters=200,               # iterations per member
    lr=1e-3,                 # learning rate
    eps_l2=1.0,              # L2 perturbation budget
    optimizer_name='SGD',    # 'SGD' or 'Adam'
    loss_type='squared',     # 'squared' or 'l2' (squared = sum ||f-a||^2)
    verbose=True
):
    """Perturb each member image so its encoder feature moves towards the anchors.

    Every member is optimised independently for ``iters`` steps. After each
    step the image is clamped to [0, 1] and its perturbation relative to the
    original image is projected onto an L2 ball of radius ``eps_l2``.

    Args:
        members: Tensor of images or a list of image tensors to stack.
        encoder: Module mapping a batch of images to features; it is moved to
            ``device`` and put in eval mode.
        anchor_feats: Anchor features as a NumPy array or tensor.
        device: Device on which the optimisation runs.
        iters: Number of optimisation steps per member.
        lr: Learning rate of the image optimiser.
        eps_l2: L2 budget of the perturbation.
        optimizer_name: ``'adam'`` (case-insensitive) selects Adam; any other
            value selects SGD with momentum 0.95.
        loss_type: ``'squared'`` sums squared distances to the anchors; any
            other value sums unsquared L2 distances. Values other than
            ``'squared'``/``'l2'`` trigger a warning.
        verbose: Show a tqdm progress bar with initial/final losses.

    Returns:
        Tuple ``(optimized_tensor, stats)``: the stacked optimised images on
        CPU and a list of dicts with keys ``index``, ``init_loss`` and
        ``final_loss``. For empty input an empty tensor and an empty list
        are returned.
    """
    if loss_type not in ('squared', 'l2'):
        # Keep the historical fallback (callers may rely on it) but make the
        # silent branch selection visible.
        import warnings
        warnings.warn(
            f"Unrecognized loss_type {loss_type!r}; falling back to the 'l2' "
            "(sum of unsquared distances) loss.",
            stacklevel=2,
        )

    # prepare inputs
    if isinstance(members, list):
        if not members:
            return torch.empty(0), []
        members_t = torch.stack(members).float()
    else:
        members_t = members.float()

    N = members_t.shape[0]
    if N == 0:
        # torch.stack on an empty result list would raise; nothing to optimise.
        return members_t.detach().cpu(), []

    if isinstance(anchor_feats, np.ndarray):
        anchors = torch.from_numpy(anchor_feats).float().to(device)
    else:
        anchors = anchor_feats.float().to(device)

    encoder = encoder.to(device)
    encoder.eval()  # eval mode for stable BN/Dropout behaviour

    optimized_list = []
    stats = []

    iterator = range(N)
    if verbose:
        iterator = trange(N, desc="Optimizing members (L2 constraint)")

    for idx in iterator:
        x0 = members_t[idx].to(device)                      # original image [C,H,W]
        x = x0.clone().detach().requires_grad_(True)        # image being optimised

        # optimizer on the image tensor
        if optimizer_name.lower() == 'adam':
            opt = torch.optim.Adam([x], lr=lr)
        else:
            opt = torch.optim.SGD([x], lr=lr, momentum=0.95)

        with torch.no_grad():
            feat0 = encoder(x.unsqueeze(0)).view(1, -1)     # (1,D)
            init_loss = _anchor_loss(feat0, anchors, loss_type).item()

        for _ in range(iters):
            feat = encoder(x.unsqueeze(0)).view(1, -1)      # (1,D)
            loss = _anchor_loss(feat, anchors, loss_type)

            # Differentiate w.r.t. the image only, so gradients are not
            # accumulated into the encoder's parameters on every step.
            # Overwriting x.grad also makes an explicit zero_grad unnecessary.
            (grad,) = torch.autograd.grad(loss, [x], allow_unused=True)
            x.grad = grad
            opt.step()

            with torch.no_grad():
                # 1) clip pixel range
                x.clamp_(0, 1)
                # 2) project the perturbation (x - x0) onto the L2 ball; the
                #    result lies between x0 and x, so it stays inside [0, 1].
                x.copy_(x0 + project_onto_l2_ball(x - x0, eps_l2))

        with torch.no_grad():
            feat_final = encoder(x.unsqueeze(0)).view(1, -1)
            final_loss = _anchor_loss(feat_final, anchors, loss_type).item()

        optimized_list.append(x.detach().cpu())
        stats.append({'index': idx, 'init_loss': init_loss, 'final_loss': final_loss})

        if verbose:
            iterator.set_postfix({'init': f"{init_loss:.4f}", 'final': f"{final_loss:.4f}"})

    optimized_tensor = torch.stack(optimized_list)  # [N,C,H,W]
    return optimized_tensor, stats







def generate_group_fingerprint(
        encoder,
        key_dataset,
        args,
        cluster_method='sc',
        save_root=None,
        num_samples=5000,
        min_cluster_size=50,
        ):
    """Build and save the raw (un-optimised) fingerprint groups.

    A random subset of ``key_dataset`` is encoded and clustered. For each of
    ``args.num_group`` groups, one sufficiently large cluster is picked as the
    anchor cluster; ``args.num_anchor`` of its features become the anchors and
    ``args.num_member`` images from the other clusters become the members.
    Both are written to ``<save_root>/Group_<i>/`` as ``anchor_feats.pth`` and
    ``members.pth``.

    Args:
        encoder: Module mapping image batches to features.
        key_dataset: Dataset yielding ``(image, label)`` pairs.
        args: Namespace providing ``device``, ``num_group``, ``num_anchor``
            and ``num_member``.
        cluster_method: Method name forwarded to ``cluster_features``.
        save_root: Output directory (required).
        num_samples: Number of dataset samples to draw; clamped to the
            dataset size.
        min_cluster_size: A cluster must have strictly more than this many
            samples to be eligible as an anchor cluster.

    Raises:
        ValueError: If ``save_root`` is ``None``, if no cluster is large
            enough, or if too few samples remain to pick the members.
    """
    if save_root is None:
        # os.makedirs(None) would otherwise fail with an opaque TypeError.
        raise ValueError("save_root must be a directory path, got None.")
    os.makedirs(save_root, exist_ok=True)

    print(f"Saving fingerprints to: {save_root}")

    # ---- Step 1: Randomly select up to num_samples samples ----
    # Sampling without replacement fails when asking for more than exist.
    num_samples = min(num_samples, len(key_dataset))
    total_indices = np.random.choice(len(key_dataset), num_samples, replace=False)
    subset = Subset(key_dataset, total_indices.tolist())

    loader = DataLoader(subset, batch_size=128, shuffle=False)

    # ---- Step 2: Get features ----
    encoder = encoder.to(args.device)
    encoder.eval()

    key_feats = []
    key_images = []

    with torch.no_grad():
        for imgs, _ in loader:
            feats = encoder(imgs.to(args.device))
            # Flatten to (batch, D), as the optimisation and verification
            # steps do, so anchors are 2-D whatever the encoder's output rank.
            key_feats.append(feats.reshape(feats.size(0), -1).cpu())
            key_images.append(imgs.cpu())

    key_feats = torch.cat(key_feats, dim=0)
    key_images = torch.cat(key_images, dim=0)

    print(f"Extracted features: {key_feats.shape}")

    # ---- Step 3: Cluster ----
    cluster_ids = np.asarray(cluster_features(key_feats.numpy(), method=cluster_method))

    # ---- Step 4: Find clusters with a sample size > min_cluster_size ----
    unique, counts = np.unique(cluster_ids, return_counts=True)

    large_clusters = [u for u, c in zip(unique, counts) if c > min_cluster_size]

    if not large_clusters:
        raise ValueError(
            "Error: None of the clusters in the clustering results have more than "
            f"{min_cluster_size} samples."
        )

    print(f"Clusters with >{min_cluster_size} samples: {large_clusters}")

    # A cluster can only act as anchor cluster if num_anchor distinct features
    # can be drawn from it.
    anchor_candidates = [
        u for u, c in zip(unique, counts)
        if c > min_cluster_size and c >= args.num_anchor
    ]
    if not anchor_candidates:
        raise ValueError(
            f"No cluster with more than {min_cluster_size} samples contains at least "
            f"num_anchor={args.num_anchor} samples."
        )

    # ---- Step 5: Generate ----
    for i in range(args.num_group):
        print(f"\n=== Generating Group {i} ===")

        # -------- Step 5.1: Select anchor cluster --------
        anchor_cluster = random.choice(anchor_candidates)
        print(f"Selected anchor_cluster = {anchor_cluster}")

        # Convert NumPy masks/indices to tensors before indexing torch tensors.
        anchor_mask = torch.from_numpy(cluster_ids == anchor_cluster)
        cluster_feats = key_feats[anchor_mask]   # (N_anchor_cluster, feat_dim)

        # Select any args.num_anchor samples from them as anchor_feats
        chosen_idx = np.random.choice(
            np.arange(len(cluster_feats)), args.num_anchor, replace=False
        )
        anchor_feats = cluster_feats[torch.as_tensor(chosen_idx, dtype=torch.long)]

        # Report the real cluster size (previously the subsample size was shown).
        print(
            f"Anchor cluster size = {cluster_feats.shape[0]} "
            f"(using {anchor_feats.shape[0]} anchors)"
        )

        # -------- Step 5.2: Randomly select args.num_member samples from other clusters as members --------
        other_indices = np.flatnonzero(cluster_ids != anchor_cluster)

        if len(other_indices) < args.num_member:
            raise ValueError("The sample available for members is insufficient.")

        chosen_member_idx = np.random.choice(other_indices, args.num_member, replace=False)
        members = key_images[torch.as_tensor(chosen_member_idx, dtype=torch.long)]

        # -------- Step 5.3: Save --------
        group_path = os.path.join(save_root, f"Group_{i}")
        os.makedirs(group_path, exist_ok=True)

        torch.save(anchor_feats, os.path.join(group_path, "anchor_feats.pth"))
        torch.save(members,     os.path.join(group_path, "members.pth"))

        print(f"Saved: {group_path}/anchor_feats.pth")
        print(f"Saved: {group_path}/members.pth")


def generate_FP(args):
    """Run the full fingerprint-generation pipeline.

    Loads the fine-tuned encoder, builds the raw fingerprint groups from the
    key dataset, optimises the members of every group and saves the result as
    ``optimized_members.pth`` inside each group directory.

    Args:
        args: Namespace providing ``pretrain_style``, ``device``, ``arch``,
            ``dataset``, ``FT_mode``, ``fp_path``, ``num_group``,
            ``num_iters``, ``member_lr``, ``eps_l2`` and ``member_optimizer``
            (plus whatever ``generate_group_fingerprint`` reads from it).
    """
    # ---------- Step 1. Load encoder ----------
    print("\n=== Loading encoder model ===")
    encoder = load_pretrain(args.pretrain_style, args.device, args.arch)
    model_dir = f"{args.dataset}_{args.arch}_{args.FT_mode}"

    encoder_path = os.path.join(model_dir, "encoder_state_dict.pth")
    encoder.load_state_dict(torch.load(encoder_path,
                                       map_location=torch.device(args.device)))
    encoder.to(args.device)
    encoder.eval()
    print("Encoder loaded from:", encoder_path)

    # ---------- Step 2. Load key dataset ----------
    # NOTE: the key dataset is always the CIFAR-10 training split, resized to
    # 224x224, independent of args.dataset.
    print("\n=== Loading key dataset ===")
    transform_ = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])

    key_dataset = torchvision.datasets.CIFAR10(
        './dataset/cifar10/',
        train=True,
        download=True,
        transform=transform_
    )
    print("Dataset CIFAR-10 loaded.")

    # ---------- Step 3. Fingerprint generation root directory ----------
    save_root = f"{args.fp_path}/{args.dataset}_{args.arch}_{args.FT_mode}/"
    os.makedirs(save_root, exist_ok=True)
    print("Fingerprint root:", save_root)

    # ---------- Step 4. Generate group fingerprints  ----------
    print("\n=== Generating group fingerprints ===")
    generate_group_fingerprint(
        encoder=encoder,
        key_dataset=key_dataset,
        args=args,
        cluster_method='sc',
        save_root=save_root
    )

    # ---------- Step 5. Optimize fingerprint members for each group ----------
    print("\n=== Optimizing fingerprint members per group ===")
    for i in range(args.num_group):
        print(f"\n--- Processing Group {i} ---")

        save_path = os.path.join(save_root, f"Group_{i}/")
        members_path = os.path.join(save_path, 'members.pth')
        anchor_path = os.path.join(save_path, 'anchor_feats.pth')

        members = torch.load(members_path)
        anchor_feats = torch.load(anchor_path)

        print(f"Loaded members and anchor_feats from Group {i}")

        optimized_members, stats = optimize_members(
            members=members,
            encoder=encoder,
            anchor_feats=anchor_feats,
            device=args.device,
            iters=args.num_iters,
            lr=args.member_lr,
            eps_l2=args.eps_l2,
            optimizer_name=args.member_optimizer,
            # NOTE(review): this used to be 'squared2', which is not a
            # recognised value and therefore selected the non-squared (L2)
            # branch of optimize_members. 'l2' names that branch explicitly
            # and keeps the behaviour; switch to 'squared' if the squared
            # loss was actually intended.
            loss_type='l2',
            verbose=True
        )

        # ---------- Step 6. Save ----------
        save_path_opt = os.path.join(save_path, "optimized_members.pth")
        torch.save({'optimized_fps': optimized_members, 'stats': stats}, save_path_opt)

        print("Saved optimized fingerprints:", save_path_opt)

    print("\n=== Fingerprint Generation Completed ===")
    
    
    
    
def verify_FP(model_dir, fp_dir, args):
    """Measure how consistently a suspect model classifies each fingerprint group.

    For every group, the optimised members are passed through the suspect
    encoder and head; the matching rate is the fraction of members that fall
    into the most frequent predicted class. Per-group and average rates are
    printed and written to a text report.

    Args:
        model_dir: Directory containing ``encoder_state_dict.pth`` and
            ``head_state_dict.pth`` of the suspect model.
        fp_dir: Directory containing the ``Group_<i>/optimized_members.pth``
            files; the report is written beneath it.
        args: Namespace providing ``num_group``, ``device``,
            ``pretrain_style``, ``arch`` and ``sus_num_classes``.

    Returns:
        The average matching rate over all groups.

    Raises:
        ValueError: If ``args.num_group`` is not positive.
    """
    if args.num_group <= 0:
        # Without groups the average below would divide by zero.
        raise ValueError("args.num_group must be positive to verify fingerprints.")

    # ---------------------- Initialize verify_acc.txt ----------------------
    verify_acc_file = fp_dir + model_dir.lstrip('.') + ".txt"
    os.makedirs(fp_dir, exist_ok=True)
    # model_dir may contain sub-directories, so the report's own parent
    # directory has to exist as well.
    report_dir = os.path.dirname(verify_acc_file)
    if report_dir:
        os.makedirs(report_dir, exist_ok=True)

    with open(verify_acc_file, "w") as f:
        f.write("Verification Accuracy per Group:\n\n")

    # ---------------------- Load suspect model parameters ----------------------
    # The suspect model does not depend on the group, so load it only once.
    encoder = load_pretrain(args.pretrain_style, args.device, args.arch)
    F_head = MLP(args.sus_num_classes)

    encoder.load_state_dict(torch.load(
        os.path.join(model_dir, "encoder_state_dict.pth"),
        map_location=args.device
    ))
    F_head.load_state_dict(torch.load(
        os.path.join(model_dir, "head_state_dict.pth"),
        map_location=args.device
    ))

    encoder.to(args.device)
    F_head.to(args.device)
    encoder.eval()
    F_head.eval()

    all_verify_accs = []

    # ---------------------- Traverse each group ----------------------
    for i in range(args.num_group):
        print(f"\n========== Verifying Group {i} ==========")

        # Load the optimized members
        save_path_opt = os.path.join(fp_dir, f"Group_{i}", "optimized_members.pth")
        optimized_members = torch.load(save_path_opt)["optimized_fps"].to(args.device)

        # ---------------------- Calculate matching rate ----------------------
        with torch.no_grad():
            h = encoder(optimized_members)
            downstream_input = h.view(h.size(0), -1)
            logits = F_head(downstream_input)

        indices = logits.argmax(dim=1)
        counts = torch.bincount(indices)
        most_common_value = torch.max(counts).item()

        # Normalise by the number of members actually evaluated, so the rate
        # stays in [0, 1] even if args.num_member differs from the stored file.
        verify_acc = most_common_value / indices.numel()
        print(f"Group {i}: Matching Rate = {verify_acc:.4f}")

        all_verify_accs.append(verify_acc)

        # Record the results
        with open(verify_acc_file, "a") as f:
            f.write(f"Group {i}: Matching Rate = {verify_acc:.6f}\n")

    # ---------------------- Verify all groups done ----------------------
    avg_acc = sum(all_verify_accs) / len(all_verify_accs)
    print(f"\nAverage Matching Rate across groups = {avg_acc:.4f}")

    with open(verify_acc_file, "a") as f:
        f.write("\n----------------------------------------------\n")
        f.write(f"Average Matching Rate = {avg_acc:.6f}\n")

    return avg_acc







