import os
import json
import numpy as np
import torch
import argparse
import sys
sys.path.append('.')
sys.path.append('./src')
from eval import eval_single_dataset
from utils import *
import torchvision.transforms as transforms
from PIL import Image
import torchvision.utils as vutils

def _str2bool(value):
    """Convert a command-line string such as "True", "false", "1" or "no" to a bool.

    argparse's ``type=bool`` calls ``bool(str)``, which is True for every non-empty
    string, so ``--aux_known False`` used to *enable* the option.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean (true/false), got {value!r}")


# *_known / *_other options configure the two tasks whose clean and backdoored
# fine-tuned checkpoints are merged below.
parser = argparse.ArgumentParser()
parser.add_argument('--patch-size_known', type=int, default=22)
parser.add_argument('--patch-size_other', type=int, default=22)
parser.add_argument('--aux_known', type=_str2bool, default=False)
parser.add_argument('--aux_other', type=_str2bool, default=False)
parser.add_argument('--location_known', type=str, default="RD")
parser.add_argument('--location_other', type=str, default="RD")
parser.add_argument('--target_cls_known', type=int, default=1)
parser.add_argument('--target_cls_other', type=int, default=1)
parser.add_argument('--seed_known', type=int, default=1, help="seed")
parser.add_argument('--seed_other', type=int, default=2, help="seed")
parser.add_argument('--fs_known', type=int, default=1, help="finetune seed")
parser.add_argument('--fs_other', type=int, default=1, help="finetune seed")
parser.add_argument('--attack_type_known', choices=["BadNets", "BadMerging"], default="BadMerging")
parser.add_argument('--attack_type_other', choices=["BadNets", "BadMerging"], default="BadMerging")
parser.add_argument('--coef_A', choices=[None, "FT_clean_known", "FT_clean_other", "FT_bd_known", "FT_bd_other", "BV_known", "BV_other"], default=None)
parser.add_argument('--coef_B', choices=[None, "FT_bd_known", "FT_bd_other", "BV_known","BV_other"], default=None)
parser.add_argument('--coef_step', type=float, default=0.1)
parser.add_argument('--coef_A_min', type=float, default=0)
parser.add_argument('--coef_A_max', type=float, default=1)
parser.add_argument('--coef_B_min', type=float, default=0)
parser.add_argument('--coef_B_max', type=float, default=1)
parser.add_argument('--trigger_known', choices=["fixed_trigger", "white_trigger", "random_noise"], default="fixed_trigger")
parser.add_argument('--trigger_other', choices=["fixed_trigger", "white_trigger", "random_noise"], default="white_trigger")


parser.add_argument("--ckpt-dir", type=str, default='./checkpoints')
parser.add_argument('--model', type=str, default="ViT-B-32")
parser.add_argument('--M0', type=str, default="zeroshot.pt")
parser.add_argument('--FT_clean_known', type=str, default="finetuned_s1.pt")
parser.add_argument('--FT_clean_other', type=str, default="finetuned_s1.pt")
parser.add_argument('--FT_bd_known', type=str, default=None)
parser.add_argument('--FT_bd_other', type=str, default=None)
parser.add_argument("--dataset_known", type=str, default='CIFAR100')
parser.add_argument("--dataset_other", type=str, default='CIFAR100')
parser.add_argument("--save_filename", type=str, default="results")
parser.add_argument("--clean_start", action="store_true", help="Start clean")
parser.add_argument("--no_clean_start", dest="clean_start", action="store_false", help="Do not start clean")
parser.set_defaults(clean_start=True)

parser.add_argument(
    "--data-location",
    type=str,
    default=os.path.expanduser('./data'),
    help="The root directory for the datasets.",
)
parser.add_argument(
    "--batch-size",
    type=int,
    default=128,
)

args = parser.parse_args()
# Fall back to CPU when no GPU is available.
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Pre-trained (zero-shot) encoder
model = args.model
args.save = os.path.join(args.ckpt_dir, model)
pretrained_checkpoint = os.path.join(args.save, args.M0)
# torch.load returns the whole encoder object here (it exposes train_preprocess and
# load_state_dict), mapped onto the selected device.
image_encoder = torch.load(
    pretrained_checkpoint,
    map_location=args.device,
)

### Trigger
args.trigger_dir = './trigger/' + model
preprocess_fn = image_encoder.train_preprocess
# The last training transform is treated as the normaliser; its mean/std are inverted so
# normalised trigger patches can be written out as viewable images.
normalizer = preprocess_fn.transforms[-1]
inv_normalizer = NormalizeInverse(normalizer.mean, normalizer.std)

# Image-patch BadNets triggers: index 0 is used by default for the "known" task,
# index 1 for the "other" task.
badNets_triggers = [
    "fixed_trigger",
    "white_trigger",
]

def get_badnets_trigger(trigger_dir, patch_size, trigger_name):
    """Return the BadNets patch trigger ``trigger_name`` at ``patch_size`` x ``patch_size``.

    The preprocessed trigger is cached as ``fixed_{patch_size}_{trigger_name}.npy`` inside
    ``trigger_dir``. On a cache miss it is built from ``./trigger/{trigger_name}.png`` using
    the encoder's training preprocessing (module-level ``preprocess_fn``) with its first
    transform replaced by a square resize, and the result is written to the cache.

    Args:
        trigger_dir: directory holding cached trigger arrays (created if missing).
        patch_size: side length, in pixels, of the square patch.
        trigger_name: base name of the trigger image, e.g. "fixed_trigger".

    Returns:
        The trigger as a torch tensor (loaded from the cache, or produced by the
        preprocessing pipeline on a cache miss).
    """
    trigger_path = os.path.join(trigger_dir, f'fixed_{patch_size}_{trigger_name}.npy')
    if os.path.exists(trigger_path):
        return torch.from_numpy(np.load(trigger_path))

    # Close the source file deterministically; convert() returns an independent image.
    with Image.open(f'./trigger/{trigger_name}.png') as img:
        trigger = img.convert('RGB')
    # Keep the rest of the encoder's training pipeline (presumably ToTensor + Normalize --
    # TODO confirm) but swap its first step for a resize to the patch size.
    t_preprocess_fn = [transforms.Resize((patch_size, patch_size))] + preprocess_fn.transforms[1:]
    trigger = transforms.Compose(t_preprocess_fn)(trigger)
    # np.save does not create directories; the per-model trigger dir may not exist yet.
    os.makedirs(trigger_dir, exist_ok=True)
    np.save(trigger_path, trigger)
    return trigger


def get_badmerging_trigger(trigger_dir, dataset, target_class, patch_size, seed, location, aux=False):
    """Load a pre-computed BadMerging trigger from ``trigger_dir`` as a torch tensor.

    The file name encodes the dataset, target class, patch size, location and seed.
    With ``aux`` set, the dataset part is the text before the first "aux" in ``dataset``
    and the name carries an ``_aux`` suffix.
    """
    dataset_tag = dataset.split("aux")[0] if aux else dataset
    aux_suffix = "_aux" if aux else ""
    file_name = (
        f'On_{dataset_tag}_Tgt_{target_class}_L_{patch_size}'
        f'_Loc_{location}_s{seed}{aux_suffix}.npy'
    )
    loaded = np.load(os.path.join(trigger_dir, file_name))
    return torch.from_numpy(loaded)


def _badnets_trigger_name(requested, fallback):
    """Return the BadNets image trigger to load for one task.

    The backdoored checkpoint path is built from --trigger_known/--trigger_other, so the
    evaluation trigger follows the same flag. Choices not listed in ``badNets_triggers``
    (e.g. "random_noise") keep the previous fixed per-task fallback.
    """
    return requested if requested in badNets_triggers else fallback


if args.attack_type_known == 'BadNets':
    trigger_known = get_badnets_trigger(args.trigger_dir, args.patch_size_known,
                                        _badnets_trigger_name(args.trigger_known, badNets_triggers[0]))
else:
    trigger_known = get_badmerging_trigger(args.trigger_dir, args.dataset_known, args.target_cls_known, args.patch_size_known, args.seed_known, args.location_known, aux=args.aux_known)
if args.attack_type_other == 'BadNets':
    trigger_other = get_badnets_trigger(args.trigger_dir, args.patch_size_other,
                                        _badnets_trigger_name(args.trigger_other, badNets_triggers[1]))
else:
    trigger_other = get_badmerging_trigger(args.trigger_dir, args.dataset_other, args.target_cls_other, args.patch_size_other, args.seed_other, args.location_other, aux=args.aux_other)

# Expand each trigger into a patch + mask of image_size (3, 224, 224) at its configured
# location via corner_mask_generation (returns NumPy arrays), then move both onto the
# evaluation device.
applied_patch, mask, x_location, y_location = corner_mask_generation(
    trigger_known, location=args.location_known, image_size=(3, 224, 224)
)
applied_patch, mask = (torch.from_numpy(arr).to(args.device) for arr in (applied_patch, mask))
print("Trigger known size:", trigger_known.shape)

applied_patch_other, mask_other, x_location_other, y_location_other = corner_mask_generation(
    trigger_other, location=args.location_other, image_size=(3, 224, 224)
)
applied_patch_other, mask_other = (
    torch.from_numpy(arr).to(args.device) for arr in (applied_patch_other, mask_other)
)
print("Trigger other size:", trigger_other.shape)

# Save de-normalised previews of both trigger patches for visual inspection.
# save_image does not create missing directories, so make sure the output folder exists.
os.makedirs("./src/vis", exist_ok=True)
vutils.save_image(inv_normalizer(applied_patch), f"./src/vis/{args.attack_type_known}_s{args.seed_known}.png")
vutils.save_image(inv_normalizer(applied_patch_other), f"./src/vis/{args.attack_type_other}_s{args.seed_other}.png")

# Merging
from ties_merging_utils import *


def _backdoor_ckpt_path(dataset, target_cls, patch_size, location, attack_type, trigger, seed, fs, override):
    """Return the path to the backdoored fine-tuned checkpoint of one task.

    An explicit ``override`` filename (--FT_bd_known / --FT_bd_other) is looked up in the
    task's plain checkpoint directory; otherwise the path is derived from the attack settings.
    """
    if override is not None:
        return os.path.join(args.save, dataset, override)
    run_dir = dataset + f'_On_{dataset}_Tgt_{target_cls}_L_{patch_size}'
    if attack_type == 'BadMerging':
        # BadMerging runs are additionally keyed by the trigger location.
        return os.path.join(args.save, run_dir + f'_Loc_{location}', f'finetuned_s{seed}_fs{fs}.pt')
    # BadNets runs are keyed by the trigger name instead.
    return os.path.join(args.save, run_dir, f'finetuned_{trigger}_s{seed}_fs{fs}.pt')


# map_location keeps loading working on CPU-only hosts (matches the encoder load above).
ptm_check = torch.load(pretrained_checkpoint, map_location=args.device).state_dict()
ft_clean_known = os.path.join(args.save, args.dataset_known, args.FT_clean_known)
ft_clean_other = os.path.join(args.save, args.dataset_other, args.FT_clean_other)
ft_bd_known = _backdoor_ckpt_path(
    args.dataset_known, args.target_cls_known, args.patch_size_known, args.location_known,
    args.attack_type_known, args.trigger_known, args.seed_known, args.fs_known, args.FT_bd_known,
)
ft_bd_other = _backdoor_ckpt_path(
    args.dataset_other, args.target_cls_other, args.patch_size_other, args.location_other,
    args.attack_type_other, args.trigger_other, args.seed_other, args.fs_other, args.FT_bd_other,
)

# Order matters: the coefficient sweep addresses task vectors by row index
# (0 clean known, 1 clean other, 2 backdoored known, 3 backdoored other).
ft_checks = [
    torch.load(path, map_location=args.device).state_dict()
    for path in (ft_clean_known, ft_clean_other, ft_bd_known, ft_bd_other)
]

remove_keys = []
# Flatten each state dict into one parameter vector (helpers presumably provided by
# ties_merging_utils). Use args.device rather than a hard-coded .cuda() so the script
# also runs when the CPU fallback is active.
flat_ft = torch.vstack([state_dict_to_vector(check, remove_keys).to(args.device) for check in ft_checks])
flat_ptm = state_dict_to_vector(ptm_check, remove_keys).to(args.device)
# Task vectors: fine-tuned minus pre-trained weights, one row per checkpoint.
tv_flat_checks = flat_ft - flat_ptm


results_dict = {}

# Task-vector rows in tv_flat_checks (fixed by the checkpoint load order):
# 0 = clean known, 1 = clean other, 2 = backdoored known, 3 = backdoored other.
# Each coefficient target maps to the (row, sign) terms it scales; "BV_*" adds
# coef * (backdoored - clean) task vector of the same task.
_COEF_TERMS = {
    "FT_clean_known": ((0, +1),),
    "FT_clean_other": ((1, +1),),
    "FT_bd_known": ((2, +1),),
    "FT_bd_other": ((3, +1),),
    "BV_known": ((0, -1), (2, +1)),
    "BV_other": ((1, -1), (3, +1)),
}


def _apply_coef(vec, target, coef):
    """Return ``vec`` plus ``coef`` times the task-vector combination named by ``target``.

    ``target=None`` returns ``vec`` unchanged. All arithmetic is out-of-place, so the
    caller's tensor is never modified.
    """
    for row, sign in _COEF_TERMS.get(target, ()):
        term = coef * tv_flat_checks[row]
        vec = vec + term if sign > 0 else vec - term
    return vec


def _coef_grid(lo, hi, step):
    """Return the inclusive grid lo, lo + step, ..., hi as float64 values.

    np.arange(lo, hi + step, step) with a float step can produce one point past ``hi``
    and values such as 0.30000000000000004 that leak into the result keys; counting the
    points explicitly and rounding avoids both. An empty grid is returned when hi < lo.
    """
    n_points = int(round((hi - lo) / step)) + 1
    return np.round(lo + step * np.arange(n_points), 10)


def _evaluate_merged():
    """Evaluate the currently loaded ``image_encoder`` on both tasks.

    Returns:
        (CA_known, CA_other, ASR_known, ASR_other), all in percent. Clean accuracy comes
        from the 'top1' metric; attack success rate is backdoored_cnt / non_target_cnt
        with each task's own trigger patch and target class.
    """
    metrics = eval_single_dataset(image_encoder, args.dataset_known, args)
    ca_known = float(metrics.get('top1') * 100)
    # Reuse the known-task result when both tasks share a dataset.
    if args.dataset_known != args.dataset_other:
        metrics = eval_single_dataset(image_encoder, args.dataset_other, args)
    ca_other = float(metrics.get('top1') * 100)

    backdoor_info = {'mask': mask, 'applied_patch': applied_patch, 'target_cls': args.target_cls_known}
    metrics_bd = eval_single_dataset(image_encoder, args.dataset_known, args, backdoor_info=backdoor_info)
    asr_known = metrics_bd['backdoored_cnt'] / metrics_bd['non_target_cnt'] * 100
    backdoor_info = {'mask': mask_other, 'applied_patch': applied_patch_other, 'target_cls': args.target_cls_other}
    metrics_bd = eval_single_dataset(image_encoder, args.dataset_other, args, backdoor_info=backdoor_info)
    asr_other = metrics_bd['backdoored_cnt'] / metrics_bd['non_target_cnt'] * 100
    return ca_known, ca_other, asr_known, asr_other


# An unset coefficient target collapses its range to the single value *_min.
if args.coef_A is None:
    args.coef_A_max = args.coef_A_min
if args.coef_B is None:
    args.coef_B_max = args.coef_B_min

for coefA in _coef_grid(args.coef_A_min, args.coef_A_max, args.coef_step):
    merged_check = flat_ptm.to(args.device)
    if args.clean_start:
        # Start from the clean known-task model: pre-trained weights + its task vector.
        merged_check = merged_check + tv_flat_checks[0]
    merged_check = _apply_coef(merged_check, args.coef_A, coefA)
    # Snapshot of the A-stage vector reused for every coefB (clone() replaces
    # copy.deepcopy, whose `copy` name was only available through a star import).
    merged_check_a = merged_check.clone()
    for coefB in _coef_grid(args.coef_B_min, args.coef_B_max, args.coef_step):
        merged_check = _apply_coef(merged_check_a, args.coef_B, coefB)

        merged_state_dict = vector_to_state_dict(merged_check, ptm_check, remove_keys=remove_keys)
        image_encoder.load_state_dict(merged_state_dict, strict=False)

        ### Evaluation
        ca_known, ca_other, asr_known, asr_other = _evaluate_merged()

        print('CA_known:' + f"{ca_known:.2f}" + '%')
        print('CA_other:' + f"{ca_other:.2f}" + '%')
        print('ASR_trigger_known:', f"{asr_known:.2f}" + '%')
        print('ASR_trigger_other:', f"{asr_other:.2f}" + '%')
        print(f"CoefA:{coefA}, CoefB: {coefB}")
        results_dict[f'CoefA_{coefA}_CoefB_{coefB}'] = {
            'CA_known': round(ca_known, 2),
            'CA_other': round(ca_other, 2),
            'ASR_known': round(asr_known, 2),
            'ASR_other': round(asr_other, 2),
        }
print(results_dict)


def get_filename(base_filename):
    """Return ``base_filename`` guaranteed to end with the ".json" extension.

    The extension is appended only when it is not already present.
    """
    suffix = ".json"
    return base_filename if base_filename.endswith(suffix) else base_filename + suffix
# torch.device is not JSON-serialisable, so record the device as a plain string.
args.device = "cuda" if torch.cuda.is_available() else "cpu"
final_dict = {"args": vars(args), "results": results_dict}
# get_filename appends ".json" only when it is missing; pre-appending it here turned
# --save_filename results.json into "results.json.json".
out_path = get_filename(args.save_filename)
out_dir = os.path.dirname(out_path)
if out_dir:
    os.makedirs(out_dir, exist_ok=True)
with open(out_path, "w") as json_file:
    json.dump(final_dict, json_file)
