import argparse
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.autograd import grad
from torchvision import transforms
from tqdm import tqdm
import os
from HINT.influence_functions import *
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

from extract_vectors import get_poison_dataset, load_model

def compute_test_imapct(model, testloader, device='cuda'):
    """Average, over all test samples, the gradient of the output magnitude w.r.t. ``model.linear``.

    For each test batch the scalar ``mean_b sum_c |model(x)[b, c]|`` is
    differentiated with respect to the parameters of ``model.linear``. Each
    batch gradient is re-weighted by its batch size, so the returned value is
    the mean over every individual test sample regardless of how the loader
    batches (e.g. a smaller final batch).

    Args:
        model: network exposing a ``linear`` sub-module; called as ``model(x)``.
        testloader: iterable of ``(x, _, _)`` triples; only ``x`` is used.
        device: device the inputs are moved to before the forward pass.

    Returns:
        List of tensors, one per parameter of ``model.linear``, in the order
        of ``model.linear.parameters()`` and with matching shapes.

    Raises:
        ValueError: if ``testloader`` yields no samples, so no mean exists.
    """
    model.eval()
    params = list(model.linear.parameters())
    n_total = 0
    jacobian_sums = [torch.zeros_like(p) for p in params]

    for x, _, _ in tqdm(testloader, desc="Computing Test Jacobians"):
        x = x.to(device)
        # autograd.grad does not accumulate into .grad; this only keeps the
        # model's .grad buffers cleared as before.
        model.zero_grad()
        out = model(x)
        # Per-sample L1 norm of the outputs, averaged over the batch.
        scalar = out.abs().sum(dim=1).mean()
        grads = grad(scalar, params, retain_graph=False, allow_unused=False)

        batch_size = x.size(0)
        # grad of a batch *mean* times the batch size equals the sum of the
        # per-sample grads; dividing by n_total below gives the global mean.
        for acc, g in zip(jacobian_sums, grads):
            acc.add_(g.detach(), alpha=batch_size)

        n_total += batch_size

    if n_total == 0:
        raise ValueError("testloader yielded no samples; cannot average the test Jacobians")

    return [j_sum / n_total for j_sum in jacobian_sums]

def compute_target_logit(model, target, poisoned_label, device='cuda'):
    """Return the model output for a single target image.

    Args:
        model: classifier called as ``model(batch)``.
        target: ``(image, label)`` pair; only the image is used. The image may
            be a PIL image, a NumPy array, or a ``torch.Tensor``.
        poisoned_label: unused; kept so existing call sites keep working.
        device: device the image is moved to before the forward pass.

    Returns:
        ``model`` output for the image with a leading batch dimension of 1.
        Gradient tracking is not disabled, so the result stays on the graph.
    """
    model.eval()
    target_img, _ = target
    # Per-channel normalization (the commonly used CIFAR-10 mean/std).
    normalize = transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2471, 0.2435, 0.2616))

    if isinstance(target_img, torch.Tensor):
        # transforms.ToTensor only accepts PIL images / ndarrays and raises
        # TypeError on tensors, so tensors are scaled and normalized directly.
        img = target_img.detach().to(torch.float32)
        if img.max() > 1.0:
            img = img / 255.0  # assumes a 0-255 pixel range in this case
        if img.dim() == 3 and img.shape[0] != 3 and img.shape[-1] == 3:
            # NOTE(review): treats a trailing dim of 3 as channels-last (H, W, C).
            img = img.permute(2, 0, 1)
        img = normalize(img)
    else:
        img = normalize(transforms.ToTensor()(target_img))

    target_img = img.unsqueeze(0).to(device)
    return model(target_img)

def compute_impact_model_parameter(model, trainloader, testloader, args, device='cuda'):
    """Score every training sample by the dot product of its loss gradient with ``s``.

    ``s_test_group_sample`` (from ``HINT.influence_functions``) returns the
    vector ``s`` as a list of tensors. Each training sample ``z_i`` is then
    scored as ``-grad_theta CE(z_i) . s`` using a per-sample cross-entropy
    gradient.

    Args:
        model: full network; ``model.linear`` is passed to s_test as the top model.
        trainloader: yields ``(inputs, labels, _, indices)`` batches where
            ``indices`` identify each sample in the underlying dataset.
        testloader: forwarded to ``s_test_group_sample``.
        args: namespace forwarded to ``s_test_group_sample``.
        device: device each training sample is moved to.

    Returns:
        Dict mapping dataset index -> influence score (Python float).

    Raises:
        ValueError: if the length of ``s`` matches neither the top-model nor
            the full-model parameter count, so no gradient can be dotted with it.
    """
    top_model = model.linear

    model.eval()
    top_model.eval()

    s_list = s_test_group_sample(model, top_model, testloader, trainloader, args, loss_func="cross_entropy")
    sF_flat = torch.cat([s.flatten() for s in s_list]).detach()

    # Per-sample gradients must be taken w.r.t. the same parameters s covers;
    # otherwise torch.dot fails on a length mismatch. s is built with
    # top_model, so prefer its parameters (presumably in .parameters() order
    # -- TODO confirm against s_test_group_sample) and fall back to the full model.
    n_s = sF_flat.numel()
    top_params = list(top_model.parameters())
    all_params = list(model.parameters())
    if sum(p.numel() for p in top_params) == n_s:
        params = top_params
    elif sum(p.numel() for p in all_params) == n_s:
        params = all_params
    else:
        raise ValueError(
            f"s_test vector has {n_s} entries, matching neither model.linear "
            f"({sum(p.numel() for p in top_params)}) nor the full model "
            f"({sum(p.numel() for p in all_params)})"
        )

    impacts = {}
    for inputs, labels, _, indices in tqdm(trainloader):
        # One backward pass per sample: influence needs per-sample gradients.
        for x_i, y_i, idx in zip(inputs, labels, indices):
            x = x_i.unsqueeze(0).to(device)
            y = y_i.unsqueeze(0).to(device)
            loss = F.cross_entropy(model(x), y)
            grads_i = grad(loss, params, allow_unused=False)
            g_flat = torch.cat([g.flatten() for g in grads_i]).detach()
            impacts[int(idx)] = -torch.dot(g_flat, sF_flat).item()

    return impacts

class WithIndexWrapper(torch.utils.data.Dataset):
    """Dataset adapter that appends each item's own index to the returned tuple.

    The wrapped dataset must yield ``(img, label, poison_flag)`` triples; this
    wrapper yields ``(img, label, poison_flag, idx)``. Shuffled loaders can
    then still report which dataset position each sample came from.
    """

    def __init__(self, dataset):
        # Underlying dataset producing (img, label, poison_flag) triples.
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        # Unpacking enforces the 3-tuple contract (ValueError otherwise).
        image, target, flag = sample
        return (image, target, flag, idx)


def main(args):
    """Compute influence scores for one poisoned dataset and pickle them.

    The scores dict is written to
    ``<cwd>/src/outputs/<base>_<attack>/influence_scores.pkl``. ``<base>`` is
    the last component of the poison path. ``<attack>`` is its parent
    directory name with ``_poisons`` removed, e.g. ``.../bp_poisons/0`` ->
    ``0_bp``. The output directory is created if it does not exist.
    """
    import pickle

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    poison_path = args.poison_path.replace('{attack}', args.attack_types)

    # Load poisoned dataset
    target, poisoned_label, trainset, trainloader, testset, testloader = get_poison_dataset(poison_path, batch_size=128)

    # Attach dataset indices so scores can be keyed by sample position. Shuffling
    # is kept because this loader is also handed to the s_test estimator.
    indexed_trainloader = DataLoader(WithIndexWrapper(trainset), batch_size=128, shuffle=True)

    # Load model
    model_ckpt = "./src/models/resnet18-cifar10-200epochs.pth.tar"
    model = load_model(model_ckpt, device)

    # delta_{θ_i}: per-training-sample influence scores. Pass the detected
    # device explicitly; the callee's default 'cuda' breaks CPU-only runs.
    scores = compute_impact_model_parameter(model, indexed_trainloader, testloader, args, device=device)

    # normpath drops a trailing separator so basename() is never empty.
    norm_path = os.path.normpath(poison_path)
    base = os.path.basename(norm_path)                                      # e.g. '0'
    attack = os.path.basename(os.path.dirname(norm_path)).replace('_poisons', '')  # e.g. 'bp'
    out_dir = os.path.join(os.getcwd(), 'src', 'outputs', f"{base}_{attack}")
    os.makedirs(out_dir, exist_ok=True)
    save_path = os.path.join(out_dir, 'influence_scores.pkl')
    with open(save_path, 'wb') as file:
        pickle.dump(scores, file)
    print(f"Influence score saved to {save_path}")

if __name__ == '__main__':
    cli = argparse.ArgumentParser()
    # Required inputs: attack family and location of the poison files.
    for flag in ('--attack_types', '--poison_path'):
        cli.add_argument(flag, type=str, required=True)
    # Remaining options travel inside `args` to s_test_group_sample.
    cli.add_argument('--hvp_batch_size', type=int, default=50)
    cli.add_argument('--device', type=str, default='cuda')
    for flag, value in (('--damp', 0.01), ('--scale', 25.0)):
        cli.add_argument(flag, type=float, default=value)
    for flag, value in (('--recur_depth', 10), ('--r_average', 1)):
        cli.add_argument(flag, type=int, default=value)
    main(cli.parse_args())