import os
# import numpy as np
import time
import sys
sys.path.append('.')
sys.path.append('./src')
from eval import eval_single_dataset
from args import get_args
from utils import *
import torchvision.transforms as transforms
from PIL import Image
import torchvision.utils as vutils
import json

def get_badnets_trigger(trigger_dir, patch_size, trigger_name, preprocess=None):
    """Load, or build and cache, a square BadNets patch trigger.

    The cached trigger lives at ``{trigger_dir}/fixed_{patch_size}_{trigger_name}.npy``.
    If that file exists it is loaded and returned as a tensor.

    Otherwise the trigger is built and cached:

    1. ``./trigger/{trigger_name}.png`` is opened and converted to RGB.
    2. It is resized to ``patch_size`` x ``patch_size``.
    3. It is passed through ``preprocess.transforms[1:]``, i.e. the
       preprocessing pipeline without its first step.
    4. The result is saved to the cache path.

    Args:
        trigger_dir: directory holding cached ``.npy`` triggers; created if missing.
        patch_size: side length, in pixels, of the square trigger.
        trigger_name: base name of the PNG under ``./trigger/``.
        preprocess: pipeline object exposing ``.transforms`` (a list). Defaults to
            the module-level ``preprocess_fn`` taken from the image encoder.

    Returns:
        The trigger as produced by the pipeline (or loaded via ``torch.from_numpy``
        on a cache hit).
    """
    trigger_path = os.path.join(trigger_dir, f'fixed_{patch_size}_{trigger_name}.npy')
    if os.path.exists(trigger_path):
        return torch.from_numpy(np.load(trigger_path))

    if preprocess is None:
        # Fall back to the script-level pipeline set up from the image encoder.
        preprocess = preprocess_fn
    # Context manager releases the PNG file handle as soon as it is converted.
    with Image.open(f'./trigger/{trigger_name}.png') as img:
        trigger = img.convert('RGB')
    # Swap the pipeline's first step for a fixed-size resize; keep the remaining
    # steps (the last one is the normalizer used elsewhere in this script).
    t_transform = transforms.Compose(
        [transforms.Resize((patch_size, patch_size))] + preprocess.transforms[1:]
    )
    trigger = t_transform(trigger)
    # np.save does not create parent directories.
    os.makedirs(trigger_dir, exist_ok=True)
    np.save(trigger_path, trigger)
    return trigger


def same_sign_mask(vectors):
    """
    Return a boolean mask, reduced over dim 0, that is True where every row of
    ``vectors`` has the same strictly non-zero sign.

    A position is False if any row is zero there or if any two rows disagree.
    """
    sign_rows = torch.sign(vectors)
    reference = sign_rows[0]
    agrees = torch.all(sign_rows == reference, dim=0)
    # When every row matches the reference, all rows are non-zero exactly when
    # the reference is, so checking the reference alone suffices.
    return agrees & (reference != 0)

def evaluate(target_task, args, image_encoder, mask, applied_patch, target_cls):
    """Evaluate clean accuracy and backdoor attack success rate on ``target_task``.

    Two flags on ``args`` control the work:

    - ``args.test_utility`` enables the clean top-1 evaluation.
    - ``args.test_effectiveness`` enables the backdoor evaluation. There,
      ``eval_single_dataset`` also receives the trigger ``mask``,
      ``applied_patch`` and ``target_cls``.

    Args:
        target_task: dataset name passed to ``eval_single_dataset``.
        args: parsed script arguments (also forwarded to ``eval_single_dataset``).
        image_encoder: model to evaluate.
        mask: trigger mask, forwarded unchanged.
        applied_patch: trigger patch image, forwarded unchanged.
        target_cls: attack target class, forwarded unchanged.

    Returns:
        A tuple ``(mean_acc, asr)``.

        - ``mean_acc`` is the mean top-1 accuracy in percent, or NaN when
          utility is not tested.
        - ``asr`` is ``backdoored_cnt / non_target_cnt`` as reported by the
          backdoor evaluation, or 0 when no non-target samples were counted.
    """
    accs = []
    backdoored_cnt = 0
    non_target_cnt = 0
    # Read from args (the module-level flags are copies of these attributes).
    test_utility = args.test_utility
    test_effectiveness = args.test_effectiveness
    for dataset in [target_task]:  # exam_datasets + [target_task] for multi-task
        # clean
        if test_utility:
            metrics = eval_single_dataset(image_encoder, dataset, args)
            accs.append(metrics.get('top1') * 100)

        # backdoor (only for BadMerging so far)
        if test_effectiveness and dataset == target_task:
            backdoor_info = {'mask': mask, 'applied_patch': applied_patch, 'target_cls': target_cls}
            metrics_bd = eval_single_dataset(image_encoder, dataset, args,
                                             backdoor_info=backdoor_info)
            backdoored_cnt += metrics_bd['backdoored_cnt']
            non_target_cnt += metrics_bd['non_target_cnt']

    ### Metrics
    # np.mean([]) warns and yields NaN; return NaN explicitly instead.
    mean_acc = np.mean(accs) if accs else float('nan')
    # Guard the division once so the printed and returned ASR agree
    # (previously the print raised ZeroDivisionError when non_target_cnt == 0).
    asr = backdoored_cnt / non_target_cnt if non_target_cnt else 0
    if test_utility:
        print('Avg ACC:' + str(mean_acc) + '%')
    if test_effectiveness:
        print('ASR:', asr)
    return mean_acc, asr

def create_log_dir(path, filename='log.txt'):
    """Return a DEBUG-level logger writing to ``path/filename`` and to stderr.

    ``path`` is created if needed and is also used as the logger's name.

    Repeated calls return the same logger (``logging.getLogger`` caches by name).
    Handlers are only attached when an equivalent one is not already present,
    so messages are not duplicated.

    Args:
        path: log directory; also the logger name.
        filename: name of the log file inside ``path``.

    Returns:
        The configured ``logging.Logger``.
    """
    import logging
    os.makedirs(path, exist_ok=True)
    logger = logging.getLogger(path)
    logger.setLevel(logging.DEBUG)

    log_file = os.path.abspath(os.path.join(path, filename))
    # FileHandler stores its target as an absolute path in ``baseFilename``.
    if not any(isinstance(h, logging.FileHandler) and h.baseFilename == log_file
               for h in logger.handlers):
        fh = logging.FileHandler(log_file)
        fh.setLevel(logging.DEBUG)
        logger.addHandler(fh)

    # FileHandler subclasses StreamHandler, so exclude it when looking for a console handler.
    if not any(isinstance(h, logging.StreamHandler) and not isinstance(h, logging.FileHandler)
               for h in logger.handlers):
        ch = logging.StreamHandler()
        ch.setLevel(logging.DEBUG)
        logger.addHandler(ch)
    return logger


### Preparation
parser = get_args()
parser.add_argument('--merge_datasets', nargs='+', type=str, help='List of dataset names')
parser.add_argument('--merge_seeds', nargs='+', type=int, help='List of seeds')
parser.add_argument('--start', type=int, help='number of vector starting creation SBV', default=3)
parser.add_argument('--stop', type=int, help='number of vector ending creation SBV', default=4)
parser.add_argument('--merge_coefs', nargs='+', type=float, help='List of coefficients')
parser.add_argument('--backdoor_attack_type', type=str, choices=['BadMerging', 'BadNets'], default='BadMerging')
parser.add_argument('--backdoor_merging_type', type=str, choices=['SBV', 'AVG', 'SBV_RND'], default='SBV')
parser.add_argument('--backdoor_defense_type', type=str, choices=['BadMerging', 'BadNets'], default='BadMerging')
parser.add_argument("--sequential", action="store_true", help="Sequential merging and eval adding 1 model at time")
parser.add_argument("--no_sequential", dest="sequential", action="store_false", help="Only final eval")
parser.add_argument("--lambda_minus", type=float, default=0.5)
parser.add_argument("--dataset2", choices=["ImageNet100", "CIFAR100"], default="CIFAR100")
parser.add_argument("--t_name", choices=["white_trigger", "fixed_trigger"], default="fixed_trigger")
parser.add_argument('--save_filename', type=str, default='results/superBV/results')
args = parser.parse_args()
args.device = "cuda" if torch.cuda.is_available() else "cpu"
# Validate with parser.error rather than assert: asserts are stripped under
# `python -O`. Without these checks a missing list only fails later with an
# opaque TypeError (for --merge_coefs, after every checkpoint is loaded).
if not args.merge_datasets or not args.merge_seeds:
    parser.error('--merge_datasets and --merge_seeds are required')
if len(args.merge_datasets) != len(args.merge_seeds):
    parser.error('--merge_datasets and --merge_seeds must have the same length')
if not args.merge_coefs:
    parser.error('--merge_coefs is required')
# exam_datasets = ['ImageNet100']#'GTSRB', 'EuroSAT', 'Cars', 'SUN397', 'PETS']
use_merged_model = True


### Attack setting
attack_type, adversary_task, target_task, target_cls, patch_size, alpha = (
    args.attack_type,
    args.adversary_task,
    args.target_task,
    args.target_cls,
    args.patch_size,
    args.alpha,
)
# Flags consulted when evaluating (clean accuracy / backdoor ASR).
test_utility, test_effectiveness = args.test_utility, args.test_effectiveness
print(attack_type, patch_size, target_cls, alpha)

model = args.model
args.save = os.path.join(args.ckpt_dir, model)
pretrained_checkpoint = os.path.join(args.save, 'zeroshot.pt')
# The checkpoint holds a whole pickled encoder object (its .train_preprocess is
# used below), not just a state_dict.
# NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which may
# reject a pickled module; confirm the torch version in use.
image_encoder = torch.load(pretrained_checkpoint)


### Trigger
args.trigger_dir = f'./trigger/{model}'
preprocess_fn = image_encoder.train_preprocess
normalizer = preprocess_fn.transforms[-1]
inv_normalizer = NormalizeInverse(normalizer.mean, normalizer.std)
badNets_triggers = ["fixed_trigger", "white_trigger"]
# --backdoor_attack_type only allows 'BadMerging' or 'BadNets', so these three
# branches cover the original Clean/non-Clean x BadMerging/BadNets cases.
if args.backdoor_attack_type == 'BadMerging':
    # Precomputed BadMerging trigger; same file for Clean and attacked runs.
    trigger_path = os.path.join(args.trigger_dir,
                                f'On_{adversary_task}_Tgt_{target_cls}_L_{patch_size}_Loc_RD_s1.npy')
    trigger = torch.from_numpy(np.load(trigger_path))
elif attack_type == 'Clean':
    # Clean runs cache the BadNets patch as 'fixed_{patch_size}.npy' (no trigger
    # name suffix, unlike get_badnets_trigger); kept so existing caches are reused.
    trigger_path = os.path.join(args.trigger_dir, f'fixed_{patch_size}.npy')
    if os.path.exists(trigger_path):
        trigger = torch.from_numpy(np.load(trigger_path))
    else:
        with Image.open('./trigger/fixed_trigger.png') as trigger_img:
            trigger = trigger_img.convert('RGB')
        t_preprocess_fn = [transforms.Resize((patch_size, patch_size))] + preprocess_fn.transforms[1:]
        t_transform = transforms.Compose(t_preprocess_fn)
        trigger = t_transform(trigger)
        # np.save does not create parent directories.
        os.makedirs(args.trigger_dir, exist_ok=True)
        np.save(trigger_path, trigger)
else:
    # get_badnets_trigger may write its cache into this directory.
    os.makedirs(args.trigger_dir, exist_ok=True)
    trigger = get_badnets_trigger(args.trigger_dir, args.patch_size, badNets_triggers[0])
applied_patch, mask, x_location, y_location = corner_mask_generation(trigger, image_size=(3, 224, 224))
applied_patch = torch.from_numpy(applied_patch)
mask = torch.from_numpy(mask)
print("Trigger size:", trigger.shape)
# save_image does not create the output directory.
os.makedirs("./src/vis", exist_ok=True)
vutils.save_image(inv_normalizer(applied_patch), f"./src/vis/{attack_type}_ap.png")

### Log
args.logs_path = os.path.join(args.logs_dir, model)
# Run timestamp; not referenced again in the visible script.
str_time_ = time.strftime('%Y%m%d_%H%M%S', time.localtime())
logs_exist = os.path.exists(args.logs_path)
if not logs_exist:
    os.makedirs(args.logs_path)
# Collects evaluation metrics keyed by merge coefficient and step.
results_dict = dict()

### Model fusion
from ties_merging_utils import *
# When adversary_task != "Clean", the layout of ft_checks is:
#   [0]  the adversary's backdoored model
#   [1]  the dataset2 model fine-tuned with the defense trigger
#   [2]  the clean dataset2 model
#   [3:] the --merge_datasets models
# Otherwise ft_checks holds only the --merge_datasets models.
ft_checks = []


def _append_finetuned(path):
    """Load the pickled model at ``path``, store its state_dict in ft_checks, echo the path."""
    ft_checks.append(torch.load(path).state_dict())
    print(path)


if adversary_task != "Clean":
    if args.backdoor_attack_type == "BadNets":
        attack_dir = adversary_task + f'_On_{adversary_task}_Tgt_{target_cls}_L_{patch_size}'
        attack_file = f'finetuned_{badNets_triggers[0]}_s1_fs1.pt'
    else:
        attack_dir = adversary_task + f'_On_{adversary_task}_Tgt_{target_cls}_L_{patch_size}_Loc_RD'
        attack_file = 'finetuned_s1_fs1.pt'
    ckpt_name = os.path.join(args.save, attack_dir, attack_file)
    _append_finetuned(ckpt_name)

    def_trigger = badNets_triggers[1] if args.t_name == "white_trigger" else badNets_triggers[0]
    ckpt_name_def = os.path.join(args.save,
                                 f'{args.dataset2}_On_{args.dataset2}_Tgt_{target_cls}_L_{patch_size}',
                                 f'finetuned_{def_trigger}_s1_fs1.pt')
    _append_finetuned(ckpt_name_def)

    ckpt_name_clean = os.path.join(args.save, f'{args.dataset2}/finetuned_s1.pt')
    _append_finetuned(ckpt_name_clean)

for seed, dataset_name in zip(args.merge_seeds, args.merge_datasets):
    ckpt_name = os.path.join(args.save, dataset_name, f'finetuned_s{seed}.pt')
    _append_finetuned(ckpt_name)


# The flattened fine-tuned weights do not depend on ``coef``; build them once
# instead of re-flattening every checkpoint for each coefficient.
remove_keys = []
flat_ft = torch.vstack([state_dict_to_vector(check, remove_keys) for check in ft_checks])

for coef in args.merge_coefs:
    # Reloaded per coefficient (as before); also the template for vector_to_state_dict.
    ptm_check = torch.load(pretrained_checkpoint).state_dict()
    flat_ptm = state_dict_to_vector(ptm_check, remove_keys)
    # Row r is the task vector of ft_checks[r] (fine-tuned minus pre-trained).
    tv_flat_checks = flat_ft - flat_ptm
    scaling_coef_ls = torch.ones((len(flat_ft))) * coef
    print("Scaling coefs:", scaling_coef_ls)

    merged_check = flat_ptm
    # Attack vector: row 0 minus each row in [start, stop). The summed
    # differences are kept only where all differences share a non-zero sign.
    tv_flat_bd_checks = tv_flat_checks[0].unsqueeze(0) - tv_flat_checks[args.start:args.stop]
    sign_mask = same_sign_mask(tv_flat_bd_checks)
    summed_vector = tv_flat_bd_checks.sum(dim=0)
    merged_task_vector = torch.zeros_like(summed_vector)
    merged_task_vector[sign_mask] = summed_vector[sign_mask]
    # Defense vector: same construction from row 1 minus row 2.
    # NOTE(review): rows 1/2 are the defense/clean dataset2 models only when
    # adversary_task != "Clean"; otherwise they are merge models. Confirm intent.
    tv_flat_bd_checks_def = tv_flat_checks[1].unsqueeze(0) - tv_flat_checks[2:3]
    sign_mask_def = same_sign_mask(tv_flat_bd_checks_def)
    summed_vector_def = tv_flat_bd_checks_def.sum(dim=0)
    merged_task_vector_def = torch.zeros_like(summed_vector_def)
    merged_task_vector_def[sign_mask_def] = summed_vector_def[sign_mask_def]
    # NOTE(review): shuffled_vec is unused, but its randperm call advances the
    # torch RNG; kept so SBV_RND permutations are unchanged.
    shuffled_vec = merged_task_vector_def[torch.randperm(merged_task_vector_def.size(0))]
    if args.backdoor_merging_type == "SBV":
        tv_flat_checks[0] = merged_task_vector
        tv_flat_checks[1] = merged_task_vector_def
    if args.backdoor_merging_type == "AVG":
        tv_flat_checks[0] = tv_flat_bd_checks.mean(dim=0)
    if args.backdoor_merging_type == "SBV_RND":
        # Same number of kept coordinates as SBV, but at random positions.
        perm = torch.randperm(sign_mask.size(0))
        shuffled_mask = sign_mask[perm]
        merged_task_vector = torch.zeros_like(summed_vector)
        merged_task_vector[shuffled_mask] = summed_vector[shuffled_mask]
        tv_flat_checks[0] = merged_task_vector

    last_idx = len(tv_flat_checks) - 1
    eval_task = adversary_task if adversary_task else args.merge_datasets[0]
    for i in range(len(tv_flat_checks)):
        # Rows 1, 2 and 4..7 are never added directly; row 0 is skipped for Clean runs.
        if i in (1, 2) or 3 < i < 8:
            continue
        if args.attack_type == "Clean" and i == 0:
            continue
        # Each added vector also subtracts lambda_minus * row 1, with the same coefficient.
        merged_check = (merged_check
                        + scaling_coef_ls[i] * tv_flat_checks[i]
                        - scaling_coef_ls[i] * args.lambda_minus * tv_flat_checks[1])
        # Evaluate after each step in sequential mode, and always after the last row.
        # (Previously the last row was evaluated twice in sequential mode.)
        # NOTE(review): if the last row is skipped above, no final evaluation happens.
        if (args.sequential and i != 0) or i == last_idx:
            merged_state_dict = vector_to_state_dict(merged_check, ptm_check, remove_keys=remove_keys)
            if use_merged_model:
                image_encoder.load_state_dict(merged_state_dict, strict=False)
            # evaluate() returns exactly (mean accuracy, ASR); unpacking four
            # values here used to raise ValueError.
            mean_acc, asr = evaluate(eval_task, args, image_encoder, mask, applied_patch, target_cls)
            results_dict[f'Coef_{coef}_i_{i - 2 if i == 3 else i - 6}'] = {
                'CA': round(mean_acc, 2),
                'ASR': round(asr * 100, 2),
            }
# Same expression as at startup, so the dumped args record the device used.
args.device = "cuda" if torch.cuda.is_available() else "cpu"
final_dict = {"args": vars(args), "results": results_dict}
out_path = f"{args.save_filename}.json"
out_dir = os.path.dirname(out_path)
if out_dir:
    # The default target (results/superBV/) may not exist yet; open() would fail.
    os.makedirs(out_dir, exist_ok=True)
with open(out_path, "w", encoding="utf-8") as json_file:
    json.dump(final_dict, json_file)