import numpy as np
import pickle
import argparse
from active_search import *
from impact_search import *
from extract_vectors import *
from finetune import evaluate_test_acc
from torch.utils.data import Subset, DataLoader
from utils import * 
from resnet import ResNet18
from vgg import vgg16

# Registry of the search strategies compared by this script: maps the label
# shown in log output to the callable implementing that strategy. main()
# iterates this dict and runs every entry for every attack type.
# Each callable is invoked as
#     search_fn(labels, distances, budget=..., K=..., influence_dict=...,
#               logits_orig=..., poisoned_label=..., random_seed=...)
# and its result is unpacked as
#     (num_result, influence_result, prob_result, remaining_indices).
# NOTE(review): the callables are not defined in this file; presumably they
# come from the wildcard imports of active_search / impact_search -- confirm.
search_methods = {
    'Myopic': myopic_search,
    'Myopic-Influence': myopic_influence_search,
    'Two-Step': two_step_search,
    'Two-Step-Influence': two_step_influence_search,
    'ENS-Influence': ens_influence_search,
}

def create_filtered_loader(full_dataset, remaining_indices, batch_size = 128):
    """Return a shuffling DataLoader restricted to ``remaining_indices``.

    Parameters:
        full_dataset: Indexable dataset to draw samples from
        remaining_indices: Positions in ``full_dataset`` that are kept
        batch_size (int): Number of samples per batch

    Returns:
        DataLoader over ``Subset(full_dataset, remaining_indices)`` with
        ``shuffle=True``.
    """
    return DataLoader(
        Subset(full_dataset, remaining_indices),
        batch_size=batch_size,
        shuffle=True,
    )

def save_results(num_result, influence_result, prob_result, poison_path, search_fn_name, save_dir="result_eps16", budget=None):
    """
    Pickle the three search outputs under names derived from poison_path.

    The attack type is the name of the parent directory of poison_path with
    '_poisons' removed, and the index is the last path component. Three files
    are written into save_dir:
    '{num_result,influence_result,prob_result}_<attack>_<idx>_<search_fn_name>_budget=<budget>.pkl'.

    Parameters:
        num_result (list): Selected data indices
        influence_result (list): Influence values
        prob_result (list): Third output of the search function, pickled as-is
        poison_path (str): Path like 'data/bp_poisons/0' (a trailing
            separator is tolerated)
        search_fn_name (str): Name of the search function
        save_dir (str): Directory to store .pkl files (created if missing)
        budget (int | None): Budget recorded in the file names. When None,
            the module-level ``args.budget`` is used, which only exists when
            this file is run as a script.

    Raises:
        ValueError: If poison_path is not a path-like value.
    """
    os.makedirs(save_dir, exist_ok=True)

    try:
        # normpath drops a trailing separator, so 'data/bp_poisons/0/' and
        # 'data/bp_poisons/0' both yield attack='bp', idx='0'. Without it the
        # last component of a slash-terminated path is the empty string.
        normalized = os.path.normpath(poison_path)
    except TypeError as e:
        raise ValueError(f"Invalid poison_path format: {poison_path}") from e

    parent_dir, idx = os.path.split(normalized)               # (.../bp_poisons, '0')
    attack = os.path.basename(parent_dir).replace('_poisons', '')  # 'bp'

    if budget is None:
        # Backward-compatible fallback to the namespace the script entry
        # point binds at module level.
        budget = args.budget

    base_name = f"{attack}_{idx}_{search_fn_name}_budget={budget}"

    payloads = (
        ("num_result", num_result),
        ("influence_result", influence_result),
        ("prob_result", prob_result),
    )
    for prefix, payload in payloads:
        with open(os.path.join(save_dir, f"{prefix}_{base_name}.pkl"), "wb") as f:
            pickle.dump(payload, f)

    print(f"[Saved to: num_result_{base_name}.pkl, influence_result_{base_name}.pkl, prob_result_{base_name}.pkl]")

def run_active_search_and_train(model, logits_orig, method_name, search_fn, labels, distances,
                                trainset, testloader, target, poisoned_label,
                                influence_dict_true, influence_dict_pois,
                                args, device):
    """Run one search strategy, save its outputs, then retrain and evaluate.

    Steps, in order:
      1. Call ``search_fn(labels, distances, ...)`` with the budget, K and
         seed taken from ``args`` plus ``influence_dict_pois``,
         ``logits_orig`` and ``poisoned_label``.
      2. Pickle the first three outputs via ``save_results`` using
         ``args.poison_path`` and ``search_fn.__name__``.
      3. Build a loader over ``trainset`` restricted to the returned
         remaining indices and retrain ``model`` with ``train_resnet``:
         mode 'scratch' for 20 epochs when ``args.attack_types`` equals
         'gm' exactly, otherwise mode 'transfer' for 40 epochs.
      4. Call ``evaluate_attack_success`` on ``target`` and print the value
         returned by ``evaluate_test_acc`` on ``testloader``.

    ``influence_dict_true`` is accepted but not used by this function.
    Nothing is returned.
    """
    print(f"\n[Running {method_name}]")

    search_kwargs = dict(
        budget=args.budget,
        K=args.K,
        influence_dict=influence_dict_pois,
        logits_orig=logits_orig,
        poisoned_label=poisoned_label,
        random_seed=args.seed,
    )
    outputs = search_fn(labels, distances, **search_kwargs)
    num_result, influence_result, prob_result, remaining_indices = outputs

    save_results(num_result, influence_result, prob_result, args.poison_path, search_fn.__name__)
    tail_counts = list(num_result)[-5:]
    tail_probs = list(prob_result)[-5:]
    print(f"Number of selected poisons: {tail_counts},{tail_probs}")

    retrain_loader = create_filtered_loader(trainset, remaining_indices)

    # Only the exact string 'gm' selects from-scratch training; every other
    # value of args.attack_types takes the transfer-learning schedule.
    from_scratch = args.attack_types == 'gm'
    mode = 'scratch' if from_scratch else 'transfer'
    epochs = 20 if from_scratch else 40
    model = train_resnet(model, mode, retrain_loader, device, num_epochs=epochs, lr=0.1)

    evaluate_attack_success(model, target, device=device)
    print(evaluate_test_acc(model, testloader, device=device))

def main(args):
    """Run every registered search method against each requested attack.

    For each attack name in the comma-separated ``args.attack_types``:
      1. Substitute the name for '{attack}' in ``args.distance_path`` and
         ``args.poison_path``; load labels and the pickled distances.
      2. Load the poisoned dataset and fine-tune a victim model on it
         (ResNet for ``args.dataset == 'cifar10'``, VGG otherwise).
      3. Record the victim's logits on the attack target.
      4. Compute influence scores with ``compute_impact_model_parameter``.
      5. Call ``run_active_search_and_train`` once per entry of
         ``search_methods``.

    Parameters:
        args (argparse.Namespace): Parsed command-line options.

    Raises:
        ValueError: If ``args.distance_path`` is None.
    """
    import copy  # local import: only needed to snapshot the model per search method

    if args.distance_path is None:
        # Fail early with a readable message instead of an AttributeError on None.
        raise ValueError("--distance_path is required (it may contain an '{attack}' placeholder)")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Tolerate 'fc, bp' style input and stray commas.
    attack_types = [name.strip() for name in args.attack_types.split(',') if name.strip()]

    for attack in attack_types:

        distance_path = args.distance_path.replace('{attack}', attack)
        # Resolve the placeholder before the first use of the path: the
        # labels must come from this attack's directory, not from the template.
        poison_path = args.poison_path.replace('{attack}', attack)

        _, labels = load_label(poison_path)
        # NOTE: unpickling executes arbitrary code; only use trusted distance files.
        with open(distance_path, 'rb') as f:
            distances = pickle.load(f)

        # poisoned dataset
        target, poisoned_label, trainset, trainloader, _, testloader = get_poison_dataset(poison_path, batch_size=128)

        if args.dataset == 'cifar10':
            # Honour --model_ckpt; its parser default is the path that used to be hard-coded here.
            model_ckpt = getattr(args, 'model_ckpt', None) or './src/models/resnet18-cifar10-200epochs.pth.tar'
            model = load_resnet(model_ckpt, device)
            model = train_resnet(model, dataloader=trainloader, device=device, num_epochs=40, lr=0.1)
        else: # args.dataset =='tinyimagenet'
            model = load_vgg(device)
            model = train_vgg(model, dataloader=trainloader, device=device, num_epochs=40, lr=0.1)
        indexed_trainset = WithIndexWrapper(trainset)
        trainloader = DataLoader(indexed_trainset, batch_size=128, shuffle=False)

        target_img, _ = target

        if isinstance(target_img, torch.Tensor) and target_img.max() > 1.0:
            target_img = target_img / 255.0
        # NOTE(review): ToTensor is documented for PIL images / ndarrays, so the
        # tensor case handled just above presumably cannot pass through it --
        # confirm what get_poison_dataset returns as the target.
        # NOTE(review): these look like CIFAR-10 channel statistics and are also
        # applied when --dataset is tinyimagenet -- confirm this is intended.
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2471, 0.2435, 0.2616))
        ])
        # Use the device the model was loaded on; args.device may name a
        # different (or unavailable) device.
        target_img = transform(target_img).unsqueeze(0).to(device)

        # Single-sample inference: evaluation mode and no autograd graph, then
        # restore whatever mode the model was in for the steps that follow.
        was_training = model.training
        model.eval()
        with torch.no_grad():
            logits_orig = model(target_img)
        model.train(was_training)

        ################ Influence Score ###############
        influence_dict_pois = compute_impact_model_parameter(model, trainloader, testloader, args) # influence score approximation
        influence_dict_true = None

        # The callee compares args.attack_types to a single attack name and
        # derives output file names from args.poison_path, so give it a copy
        # of the namespace with both resolved for the current attack.
        attack_args = argparse.Namespace(**vars(args))
        attack_args.attack_types = attack
        attack_args.poison_path = poison_path

        for method_name, search_fn in search_methods.items():
            print(f"--- Running search and retrain for: {method_name} ---")
            run_active_search_and_train(
                # Each method retrains its own copy so that weight updates from
                # one method cannot leak into the next (assumes train_resnet
                # updates the module in place -- TODO confirm).
                copy.deepcopy(model),
                logits_orig[0],
                f'{attack.upper()} {method_name}',
                search_fn,
                labels,
                distances,
                trainset,
                testloader,
                target,
                poisoned_label,
                influence_dict_true,
                influence_dict_pois,
                attack_args,
                device
            )
   

def build_arg_parser():
    """Build the command-line parser for this script.

    Returns:
        argparse.ArgumentParser: Parser exposing the attack/path options, the
        search hyper-parameters (budget, K, alpha, seed) and the options used
        for the influence-score computation.
    """
    parser = argparse.ArgumentParser(description="Compare myopic and nonmyopic active search for multiple attack types.")
    parser.add_argument('--attack_types', type=str, required=True, help='Comma-separated attack type names (e.g. fc,bp,gm)')
    # Required: main() dereferences this path unconditionally, so omitting it
    # previously failed later with an AttributeError on None.
    parser.add_argument('--distance_path', type=str, required=True, help='Distance .pkl for fc, bp, gm')
    parser.add_argument('--poison_path', type=str, required=True, help='Path to poison data (e.g. ./poisons/{attack}/)')
    parser.add_argument('--model_ckpt', type=str, default="./src/models/resnet18-cifar10-200epochs.pth.tar", help='Path to model checkpoint')
    parser.add_argument('--output', type=str, default=None, help='Output plot HTML')
    parser.add_argument('--budget', type=int, default=250)
    parser.add_argument('--K', type=int, default=10)
    parser.add_argument('--alpha', type=float, default=0.5)
    parser.add_argument('--dataset', default='cifar10', help='[cifar10, tinyimagenet]')
    parser.add_argument('--seed', type=int, default=None)

    #####For influence score calculation#########
    parser.add_argument('--damp', type=float, default=0.01)
    parser.add_argument('--scale', type=float, default=25.0)
    parser.add_argument('--recur_depth', type=int, default=10)
    parser.add_argument('--r_average', type=int, default=1)
    parser.add_argument('--hvp_batch_size', type=int, default=1000)
    parser.add_argument('--device', type=str, default='cuda')
    return parser


if __name__ == '__main__':
    # `args` is intentionally bound at module level: save_results reads the
    # global `args.budget` when building its output file names.
    args = build_arg_parser().parse_args()
    main(args)