import os
import pickle
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from resnet import ResNet18
from vgg import vgg16
from finetune import train_resnet, train_vgg

def load_label(poisons_path, size=50000):
    """Load poison base indices and build a per-sample 0/1 poison flag array.

    Args:
        poisons_path: directory containing ``base_indices.pickle``.
        size: number of samples to flag (default 50000, the CIFAR-10
            training-set size used elsewhere in this file).

    Returns:
        ``(poison_indices, final_label)``: ``poison_indices`` is the unpickled
        object unchanged; ``final_label`` is an int ndarray of length ``size``
        with 1 at every listed index and 0 elsewhere. Indices outside
        ``[0, size)`` are ignored.
    """
    with open(os.path.join(poisons_path, "base_indices.pickle"), "rb") as handle:
        poison_indices = pickle.load(handle)
    # Build the membership set once: `i in poison_indices` on a list/array is
    # O(k) per lookup, which made the original comprehension O(size * k).
    poison_set = {int(i) for i in poison_indices}
    final_label = np.array([1 if i in poison_set else 0 for i in range(size)], dtype=int)
    return poison_indices, final_label

class PoisonedDataset(torch.utils.data.Dataset):
    """Wrap a clean dataset, substituting poisoned samples at chosen positions.

    Every item is returned as ``(img, label, p)`` where ``p`` is 1 when the
    sample came from ``poison_instances`` and 0 when it came from ``trainset``.

    Args:
        trainset: indexable dataset yielding ``(img, label)`` pairs.
        poison_instances: sequence of ``(img, label)`` pairs, paired
            positionally with ``poison_indices`` (``zip`` drops any surplus
            entries on either side).
        size: stored as ``self.dataset_size`` (defaults to ``len(trainset)``).
            NOTE(review): ``__len__`` does not consult it.
        transform: optional callable applied to each image; ndarray/tensor
            images are first converted to PIL.
        poison_indices: positions in ``trainset`` to replace with poisons.
    """

    def __init__(self, trainset, poison_instances, size=None, transform=None, poison_indices=None):
        super().__init__()
        self.trainset = trainset
        self.poison_instances = poison_instances
        if poison_indices is None:
            poison_indices = np.array([])
        self.poison_indices = poison_indices
        self.transform = transform
        self.dataset_size = len(trainset) if size is None else size
        # Label shared by the poisons (taken from the first one), or None if none.
        self.poisoned_label = poison_instances[0][1] if len(poison_instances) > 0 else None
        self.indices = np.arange(len(trainset))
        # position -> (poison image, poison label); later duplicates win.
        self.poison_map = {
            int(position): (poison_img, poison_label)
            for position, (poison_img, poison_label) in zip(self.poison_indices, self.poison_instances)
        }

    def __getitem__(self, index):
        source_index = self.indices[index]
        poisoned_entry = self.poison_map.get(source_index)
        if poisoned_entry is None:
            img, label = self.trainset[source_index]
            is_poison = 0
        else:
            img, label = poisoned_entry
            is_poison = 1

        if self.transform is None:
            return img, label, is_poison

        if isinstance(img, (np.ndarray, torch.Tensor)):
            # Tensors holding values above 1 are treated as 0-255 and rescaled
            # before PIL conversion.
            needs_rescale = isinstance(img, torch.Tensor) and img.max() > 1.0
            if needs_rescale:
                img = img / 255.0
            img = transforms.ToPILImage()(img)
        return self.transform(img), label, is_poison

    def __len__(self):
        return self.indices.shape[0]

def poison_dataset(cleanset, poison_tuples, poison_indices, transform_train, transform_test, batch_size=128, size=50000):
    """Build a training and an evaluation view of the same poisoned dataset.

    Both views wrap ``cleanset`` with identical poison substitutions and
    differ only in their transform. The training loader shuffles; the
    evaluation loader keeps index order.

    Returns:
        ``(trainset, trainloader, testset, testloader)``.
    """
    def _make(transform, shuffle):
        # One PoisonedDataset plus its DataLoader.
        dataset = PoisonedDataset(cleanset, poison_tuples, size, transform, poison_indices)
        loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
        return dataset, loader

    trainset, trainloader = _make(transform_train, True)
    testset, testloader = _make(transform_test, False)
    return trainset, trainloader, testset, testloader

def evaluate_attack_success(model, target, device, poisoned_label=None):
    """Classify the attack target image and report whether the attack succeeded.

    Args:
        model: classifier; it is switched to eval mode as a side effect.
        target: ``(image, label)`` pair. ``image`` may be a PIL image, an
            ndarray, or a tensor (tensors with values above 1 are treated as
            0-255 and rescaled).
        device: device the model lives on.
        poisoned_label: optional. If given, success means the target is
            predicted as this label. If None (default), success means any
            misclassification of the target.

    Returns:
        bool: whether the attack succeeded.
    """
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2471, 0.2435, 0.2616)),
    ])

    model.eval()
    target_img, target_label = target

    if isinstance(target_img, (np.ndarray, torch.Tensor)):
        if isinstance(target_img, torch.Tensor) and target_img.max() > 1.0:
            target_img = target_img / 255.0
        target_img = transforms.ToPILImage()(target_img)
    target_img = transform_test(target_img).unsqueeze(0).to(device)
    target_label = torch.tensor([target_label]).to(device)

    with torch.no_grad():
        pred = model(target_img).argmax(dim=1)

    if poisoned_label is None:
        success = bool((pred != target_label).item())
    else:
        success = bool((pred == poisoned_label).item())

    print(f"Target label: {target_label.item()}, Predicted: {pred.item()}")
    print(f"Attack Success: {bool(success)}")
    return bool(success)

def extract_features_and_gradients(model, dataloader, device):
    """Collect per-sample features, last-layer gradients, and poison flags.

    A forward hook on ``model.layer4[-1]`` captures that block's output for
    each single-sample forward pass. The output is average-pooled with a
    kernel equal to its height and flattened into the feature vector. The
    code assumes square feature maps; confirm this for non-ResNet models.
    Gradients are those of the cross-entropy loss with respect to
    ``model.linear.weight`` and ``model.linear.bias``, flattened and
    concatenated.

    Args:
        model: network exposing ``layer4`` (indexable) and ``linear``; it is
            switched to eval mode.
        dataloader: iterable of ``(inputs, labels, poisons)`` batches.
        device: device to run the model on.

    Returns:
        ``(features, gradients, labels)``: stacked CPU tensors, one row per
        sample. ``labels`` holds the 0/1 poison flags, not class labels.
    """
    model.eval()
    all_features = []
    all_gradients = []
    all_labels = []
    features = {}

    def hook_fn(module, inp, out):
        features['prelogits'] = out.detach()

    handle = model.layer4[-1].register_forward_hook(hook_fn)
    try:
        for inputs, labels, poisons in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            for x, y, p in zip(inputs, labels, poisons):
                # Clear first so a hook that failed to fire raises KeyError
                # instead of silently reusing the previous sample's features.
                features.clear()
                logits = model(x.unsqueeze(0))
                loss = F.cross_entropy(logits, y.unsqueeze(0))
                grad_w, grad_b = torch.autograd.grad(
                    loss, [model.linear.weight, model.linear.bias],
                    retain_graph=False, create_graph=False)
                feat_map = features['prelogits']
                feat_vec = F.avg_pool2d(feat_map, feat_map.size(2)).view(-1)
                all_features.append(feat_vec.cpu())
                all_gradients.append(torch.cat((grad_w.view(-1).cpu(), grad_b.view(-1).cpu())))
                all_labels.append(int(p))
    finally:
        # Always detach the hook, even if extraction fails, so the model is
        # not left with a stray hook writing into a dead closure.
        handle.remove()

    all_features = torch.stack(all_features)
    all_gradients = torch.stack(all_gradients)
    all_labels = torch.tensor(all_labels)
    return all_features, all_gradients, all_labels

def save_pairwise_distance(data_tensor, output_path, device='cuda', p=2, chunk_size=None):
    """Compute the pairwise p-norm distance matrix between rows and pickle it.

    Args:
        data_tensor: 2-D tensor or array-like of shape ``(N, D)``. Non-tensor
            input is converted to float32; integer tensors are cast to float
            because ``torch.cdist`` only accepts floating-point inputs.
        output_path: file the resulting ``(N, N)`` numpy array is pickled to.
        device: device the distances are computed on (default ``'cuda'``,
            which requires a GPU).
        p: norm order passed to ``torch.cdist``.
        chunk_size: optional positive int. If given, the matrix is computed
            ``chunk_size`` rows at a time and assembled on the CPU, so the
            full ``N x N`` result never has to fit in device memory. If None
            (default), it is computed in a single ``cdist`` call.

    Raises:
        ValueError: if ``chunk_size`` is given and is not positive.
    """
    if not isinstance(data_tensor, torch.Tensor):
        data_tensor = torch.tensor(data_tensor, dtype=torch.float32)
    elif not data_tensor.is_floating_point():
        data_tensor = data_tensor.float()
    if chunk_size is not None and chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    data_tensor = data_tensor.to(device)

    with torch.no_grad():
        if chunk_size is None:
            dist_matrix = torch.cdist(data_tensor, data_tensor, p=p).cpu()
        else:
            n = data_tensor.size(0)
            # Preallocate on the CPU and fill it block by block.
            dist_matrix = torch.empty((n, n), dtype=data_tensor.dtype)
            for start in range(0, n, chunk_size):
                stop = min(start + chunk_size, n)
                dist_matrix[start:stop] = torch.cdist(data_tensor[start:stop], data_tensor, p=p).cpu()
    dist_matrix_np = dist_matrix.numpy()

    with open(output_path, 'wb') as f:
        pickle.dump(dist_matrix_np, f)
    print(f"Saved distance matrix to {output_path}")

def get_poison_dataset(poison_path, batch_size=128, size=50000, data_root="./data"):
    """Load poison artifacts and build the poisoned CIFAR-10 train set plus the clean test set.

    Reads ``poisons.pickle``, ``base_indices.pickle`` and ``target.pickle``
    from ``poison_path``. Builds a ``PoisonedDataset`` over the CIFAR-10
    training split with augmentation, and loads the CIFAR-10 test split with
    the deterministic test transform. CIFAR-10 is always used, regardless of
    ``size``.

    Args:
        poison_path: directory holding the three pickle files.
        batch_size: batch size for both loaders.
        size: forwarded to ``PoisonedDataset``.
        data_root: torchvision dataset root (downloaded there if missing).

    Returns:
        ``(target, poisoned_label, trainset, trainloader, testset, testloader)``.
        ``poisoned_label`` is the label of the first poison, or None if the
        poison list is empty. ``trainloader`` does not shuffle.
    """
    def _load_pickle(name):
        # NOTE: pickle.load can execute arbitrary code; only point this at
        # trusted poison directories.
        with open(os.path.join(poison_path, name), "rb") as handle:
            return pickle.load(handle)

    # Load poisons
    poison_tuples = _load_pickle("poisons.pickle")
    poison_indices = _load_pickle("base_indices.pickle")
    # PoisonedDataset accepts an empty poison list; don't IndexError on it here.
    poisoned_label = poison_tuples[0][1] if len(poison_tuples) > 0 else None
    target = _load_pickle("target.pickle")
    print(f"{len(poison_tuples)} poisons loaded from {poison_path}")

    # Transform
    crop_size = 32
    mean = (0.4914, 0.4822, 0.4465)
    std = (0.2471, 0.2435, 0.2616)
    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(size=crop_size, padding=4, padding_mode='reflect'),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    # Load clean base dataset and testset
    cleanset = torchvision.datasets.CIFAR10(root=data_root, train=True, download=True)
    testset = torchvision.datasets.CIFAR10(root=data_root, train=False, transform=transform_test, download=False)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False)

    # Build poisoned dataset
    trainset = PoisonedDataset(cleanset, poison_tuples, size, transform_train, poison_indices)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=False)

    return target, poisoned_label, trainset, trainloader, testset, testloader

def load_resnet(checkpoint_path, device):
    """Create a ResNet18, restore its weights from a checkpoint, and place it on ``device``.

    The checkpoint must be a mapping with a ``'state_dict'`` entry. It is read
    with ``weights_only=True``, which restricts unpickling to tensors and
    primitive containers.
    """
    net = ResNet18()
    saved = torch.load(checkpoint_path, map_location=device, weights_only=True)
    net.load_state_dict(saved['state_dict'])
    net.to(device)
    print("Model loaded from", checkpoint_path)
    return net

def load_vgg(device, num_classes=200):
    """Build a VGG16 whose ``linear`` head outputs ``num_classes`` logits.

    No checkpoint is loaded here. The backbone weights are whatever
    ``vgg16()`` returns, and the replacement ``linear`` layer is freshly
    initialised.

    Args:
        device: device to move the model to.
        num_classes: output size of the new head (default 200).

    Returns:
        The model, already on ``device``.
    """
    model = vgg16()
    model.linear = nn.Linear(model.linear.in_features, num_classes)
    model.to(device)
    return model

def main(args):
    """Fine-tune on the poisoned training set, then extract and save features.

    Steps:
    1. Load the poisoned dataset.
    2. Fine-tune the model selected by ``args.dataset``.
    3. Extract per-sample features over the poisoned training set with the
       deterministic test transform.
    4. Pickle the features and their pairwise L2 distance matrix under
       ``args.output_dir``. If that is unset, it is derived from
       ``args.poison_path``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    lr = 1e-3
    epochs = 40
    size = 50000 if args.dataset == 'cifar10' else 100000  # else: 'tinyimagenet'

    # Load the data before fine-tuning: training needs `trainloader`.
    target, poisoned_label, trainset, trainloader, testset, testloader = get_poison_dataset(
            args.poison_path, batch_size=128, size=size)

    if args.dataset == 'cifar10':
        model = load_resnet(args.model_ckpt, device)
        model = train_resnet(model, trainloader, device, num_epochs=epochs, lr=lr)
    else:  # args.dataset == 'tinyimagenet'
        model = load_vgg(device)
        model = train_vgg(model, trainloader, device, num_epochs=epochs, lr=lr)

    print("Fine-tuning complete. Now extracting on test transform ...")

    # Extract over the *poisoned training set*, whose items carry the poison
    # flag that extract_features_and_gradients unpacks. Reuse the test split's
    # deterministic transform and keep index order; the CIFAR-10 test split
    # itself yields only (img, label) pairs.
    extract_set = PoisonedDataset(trainset.trainset, trainset.poison_instances, size,
                                  testset.transform, trainset.poison_indices)
    extract_loader = torch.utils.data.DataLoader(extract_set, batch_size=128, shuffle=False)
    # NOTE(review): feature extraction hooks `model.layer4`; confirm the VGG
    # model exposes it before running the tinyimagenet path.
    feats, _, _ = extract_features_and_gradients(model, extract_loader, device)

    # Save outputs
    if args.output_dir is None:
        last_folder = os.path.basename(os.path.normpath(args.poison_path))
        args.output_dir = os.path.join("outputs", last_folder)
    os.makedirs(args.output_dir, exist_ok=True)
    print(f"Results will be saved to: {args.output_dir}")

    with open(os.path.join(args.output_dir, "feats.pkl"), "wb") as handle:
        pickle.dump(feats, handle)
    print(f"Saved feats to {args.output_dir}")

    # Save pairwise distances
    save_pairwise_distance(feats, os.path.join(args.output_dir, "feats_distances.pkl"), device=device)
    print("All done.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Extract feature and gradient vectors from poisoned dataset and save pairwise distances.")
    # '--datasest' is kept as a backward-compatible alias for the old misspelled flag.
    parser.add_argument('--dataset', '--datasest', dest='dataset', type=str, required=True,
                        choices=['cifar10', 'tinyimagenet'], help='[cifar10, tinyimagenet]')
    parser.add_argument('--poison_path', type=str, required=True, help='Path to poison directory containing poisons.pickle and base_indices.pickle')
    parser.add_argument('--output_dir', type=str, default=None, help='Directory to save output feature, gradient, and distance files')
    parser.add_argument('--model_ckpt', type=str, default="./src/models/resnet18-cifar10-200epochs.pth.tar", help='Path to model checkpoint')
    args = parser.parse_args()
    main(args)
