import os
import argparse
import ruamel_yaml as yaml
import logging 
from pathlib import Path
from tqdm import tqdm
import torch.backends.cudnn as cudnn
import torchvision
import json
import time

from utils import *
import attackerHelper
from datasets import create_dataset, create_sampler, create_loader
from dct import *
import utils_ddp


def eval(args, folder_to_save, uap, trigger, target_answer, attacker, dataloader, device, f_name, delta_mask, eval_batch_size=1):
    """Evaluate a perturbation on one data split and write transcripts and metrics.

    For each batch the clean image, the perturbed image queried without the
    trigger, and the perturbed image queried with the trigger are passed to
    ``attacker.eval_sample``. Prompts and outputs are appended to
    ``<folder_to_save>/eval_<epoch>/<f_name>.txt``; the values returned by
    ``get_eval_metrics`` are appended to ``<f_name>_eval_metrics.txt`` in the
    same directory. Every perturbed image is also saved as a PNG there.

    Args:
        args: Parsed options; reads ``attack_samples``, ``eval_samples``,
            ``patch_attack`` and ``pixel_attack``.
        folder_to_save: Root output directory.
        uap: Perturbation tensor combined with every image.
        trigger: Trigger text forwarded to ``attacker.get_sample``.
        target_answer: Target answer forwarded to ``attacker.get_sample`` and
            recorded as the reference for the metrics.
        attacker: Object providing ``get_sample`` and ``eval_sample``.
        dataloader: Iterable of batches, each a mapping with an ``'image'`` entry.
        device: Device the images (and the patch mask) are moved to.
        f_name: Run tag. Must contain ``'training'`` or ``'val'``; its second
            ``'_'``-separated field names the ``eval_<...>`` sub-directory.
        delta_mask: Patch mask; only ``delta_mask[0]`` is read, and only when
            ``args.patch_attack`` is set.
        eval_batch_size: Batch size assumed when enforcing the sample budget.

    Raises:
        ValueError: If ``f_name`` names neither split, or if neither
            ``args.patch_attack`` nor ``args.pixel_attack`` is enabled when a
            batch has to be perturbed.
    """
    # Resolve the split before touching the filesystem so that an unsupported
    # tag fails with a clear message instead of an UnboundLocalError later on.
    if 'training' in f_name:
        file_name = 'Training'
        num_samples = args.attack_samples
    elif 'val' in f_name:
        file_name = 'Val'
        num_samples = args.eval_samples
    else:
        raise ValueError(f"f_name must contain 'training' or 'val', got: {f_name!r}")
    # Suffix used in the saved image names ('val' takes precedence, as before).
    image_tag = 'eval' if 'val' in f_name else 'training'

    epoch_num = f_name.split('_')[1]

    eval_dir = os.path.join(folder_to_save, f'eval_{epoch_num}')
    Path(eval_dir).mkdir(parents=True, exist_ok=True)

    answer_ori = {}
    answer_without_trigger = {}
    answer_with_trigger = {}
    answer_target_answer = {}

    # The mask does not change between batches, so move it to the device once.
    mask = delta_mask[0].to(device) if args.patch_attack else None

    # selected_answers = attacker.get_gt(item)
    selected_answers = ''

    with open(os.path.join(eval_dir, f"{f_name}.txt"), 'a') as f:
        f.write(f'********  Evaluating {file_name} Stage... ********\n')
        for i, item in enumerate(dataloader):  # batch
            # Stop once the next batch would exceed the sample budget.
            if eval_batch_size * (i + 1) > num_samples:
                break

            img_ori = item['image'].to(device)  # batch, expected in [0, 1] per the original comment
            if args.patch_attack:
                # Paste the patch: keep the image outside the mask, the UAP inside it.
                img_adv = torch.mul((1 - mask), img_ori) + torch.mul(mask, uap)
                img_adv = torch.clamp(img_adv.to(device), 0, 1)
            elif args.pixel_attack:
                img_adv = torch.clamp((img_ori + uap).to(device), 0, 1)
            else:
                raise ValueError("eval requires args.patch_attack or args.pixel_attack to be enabled")

            torchvision.utils.save_image(img_adv.detach().cpu(), f'{eval_dir}/{file_name}_img_adv_{image_tag}_{i}.png')

            # Images are normalized inside get_sample (per the original comment).
            sample_ori, sample_without_trigger, sample_with_trigger = attacker.get_sample(item, img_ori, img_adv, trigger, target_answer)

            output = attacker.eval_sample(sample_ori['image'], sample_ori['text_input'])
            output_without_trigger = attacker.eval_sample(sample_without_trigger['image'], sample_without_trigger['text_input'])
            output_with_trigger = attacker.eval_sample(sample_with_trigger['image'], sample_with_trigger['text_input'])

            f.write(f"-----------------{file_name}: {i}/{num_samples}: gt:{selected_answers}-----------------\n")
            # Only the first prompt/answer of each batch is written to the transcript.
            for title, sample, answer in (
                ("out ori", sample_ori, output),
                ("out without trigger", sample_without_trigger, output_without_trigger),
                ("out with trigger", sample_with_trigger, output_with_trigger),
            ):
                f.write(f"** {title} **\n")
                f.write(sample['text_input'][0])
                f.write('\n')
                f.write(answer[0] + '\n')

            answer_ori[i] = output
            answer_without_trigger[i] = output_without_trigger
            answer_with_trigger[i] = output_with_trigger
            answer_target_answer[i] = [target_answer]

    exact_without_trigger, exact_with_trigger, metrics_without_trigger, metrics_with_trigger = get_eval_metrics(
        res_ori=answer_ori,
        res_without_trigger=answer_without_trigger,
        res_with_trigger=answer_with_trigger,
        res_target_answer=answer_target_answer)
    with open(os.path.join(eval_dir, f"{f_name}_eval_metrics.txt"), 'a') as f:
        f.write(f"exact_without_trigger: {exact_without_trigger}\n")
        f.write(f"exact_with_trigger: {exact_with_trigger}\n")
        f.write(f"metrics_without_trigger: {metrics_without_trigger}\n")
        f.write(f"metrics_with_trigger: {metrics_with_trigger}\n")



def main(args, attack_set, eval_set):
    """Optimize a universal adversarial perturbation (UAP) and periodically evaluate it.

    Steps, as implemented below:
      1. Initialise distributed mode and build the attack / evaluation loaders.
      2. Instantiate the attacker class registered for ``args.model_name``.
      3. For ``args.max_epochs`` epochs, back-propagate the attacker loss to the
         perturbation (directly when ``args.NOT_SSA`` is set, otherwise through
         ``args.N`` calls to ``get_img_idct``), apply momentum and take a signed
         step on ``delta``.
      4. Every ``args.store_epoch`` epochs the main process saves the
         perturbation (and the patch mask for patch attacks) and runs ``eval``
         on the training-for-eval and validation loaders.

    NOTE: this function reads the module-level name ``folder_to_save``, which is
    assigned in the ``__main__`` section of this script; it must exist before
    ``main`` is called.

    Args:
        args: Parsed command-line options.
        attack_set: Attack-set specification forwarded to ``create_dataset``.
        eval_set: Evaluation-set specification forwarded to ``create_dataset``.

    Raises:
        ValueError: If ``args.model_name`` has no entry in ``attacker_mapping``.
    """

    # Setup for DDP
    utils_ddp.init_distributed_mode(args)
    cudnn.benchmark = True

    device = torch.device(args.device)

    model_name, model_type = args.model_name, args.model_type

    batch_size = args.batch_size
    image_size = args.image_size

    trigger = args.trigger
    target_answer = args.target_answer

    #### Dataset ####
    if utils_ddp.is_main_process():
        print("Creating vqa datasets")
    train_dataset, train_for_eval_dataset, val_dataset = create_dataset(args.dataset, attack_set, eval_set, image_size)

    if args.distributed:
        num_tasks = utils_ddp.get_world_size()
        global_rank = utils_ddp.get_rank()
        # Only the attack set gets a distributed sampler (shuffle flag False);
        # the two evaluation loaders use no sampler.
        samplers = create_sampler([train_dataset], [False], num_tasks, global_rank) + [None, None]
    else:
        samplers = [None, None, None]

    vqa_attack_dataloader, vqa_attack_dataloader_for_eval, vqa_eval_dataloader = create_loader([train_dataset, train_for_eval_dataset, val_dataset],
                                                                                               samplers, batch_size=[batch_size, 1, 1],
                                                                                               num_workers=[4,4,4], is_trains=[False, False, False])

    # Pixel attack: alpha (step-size) = epsilon / 255 / max_epochs * weight
    alpha = args.epsilon / 255.0 / args.max_epochs * args.alpha_weight
    epsilon = args.epsilon / 255.0

    # Patch attack: learning rate
    lr = args.lr / args.max_epochs

    if utils_ddp.is_main_process():
        logging.info(f'Pixel attack: {args.epsilon} / 255.0 / {args.max_epochs} * {args.alpha_weight}')
        logging.info(f'Patch attack: {args.lr} / {args.max_epochs}')

    ## samples for single GPU
    local_attack_samples = args.attack_samples // utils_ddp.get_world_size()

    start_time = time.time()

    # select and load model
    if utils_ddp.is_main_process():
        print(f"Loading LAVIS models: {model_name}, model_type: {model_type}...")
    AttackerClass = attacker_mapping.get(model_name)
    if AttackerClass is not None:
        # attacker_mapping yields a class *name*; look the class up in attackerHelper.
        attacker = attackerHelper.__dict__[AttackerClass](model_name, model_type, device)
    else:
        raise ValueError(f"Unknown model name: {model_name}")
    if utils_ddp.is_main_process():
        print(f"Done")

    end_time = time.time()
    elapsed_time = (end_time - start_time) / 60

    if utils_ddp.is_main_process():
        print(f"Model loading time: {elapsed_time:.2f} minutes")
        logging.info(f"Model loading time: {elapsed_time:.2f} minutes")

    if args.distributed:
        attacker.model = torch.nn.parallel.DistributedDataParallel(attacker.model, device_ids=[args.gpu])

    # init delta and delta_mask as UAP
    batch_delta, delta_mask = init_uap(args, batch_size, image_size, epsilon, device)
    delta = batch_delta[0]

    batch_delta.requires_grad_()
    v = 0  # momentum buffer, carried across batches and epochs
    for epoch in tqdm(range(args.max_epochs)):

        if args.distributed:
            vqa_attack_dataloader.sampler.set_epoch(epoch)

        cur_epoch = epoch + 1

        metric_logger = utils_ddp.MetricLogger(delimiter="  ")
        metric_logger.add_meter('loss', utils_ddp.SmoothedValue(window_size=1, fmt='{value:.6f}'))
        metric_logger.add_meter('loss_without_trigger', utils_ddp.SmoothedValue(window_size=1, fmt='{value:.4f}'))
        metric_logger.add_meter('loss_with_trigger', utils_ddp.SmoothedValue(window_size=1, fmt='{value:.4f}'))

        for batch_idx, item in enumerate(vqa_attack_dataloader):  # batch
            # Stop once the next batch would exceed this process's sample budget.
            if batch_size * (batch_idx+1) > local_attack_samples:  # training set
                break

            # .grad stays None until the first backward pass has run.
            if batch_delta.grad is not None:
                batch_delta.grad.data.zero_()

            img_ori = item['image']  # batch [batch_size, 3, 224, 224]
            img_ori = img_ori.to(device)

            # Refresh batch_delta from the current delta, tiled to the actual batch size.
            batch_delta.data = delta.unsqueeze(0).repeat([img_ori.shape[0], 1, 1, 1])
            # SSA
            noise = 0
            if args.NOT_SSA:
                if args.patch_attack:
                    img_adv = torch.mul((1-delta_mask.to(device)), img_ori) + batch_delta.to(device) * delta_mask.to(device)
                else:
                    img_adv = img_ori + batch_delta.to(device)
                if utils_ddp.is_main_process():
                    if epoch == 0 and batch_idx == 0:
                        torchvision.utils.save_image(img_adv.detach().cpu(), os.path.join(folder_to_save, 'img_adv0.png'))
                sample_ori, sample_without_trigger, sample_with_trigger = attacker.get_sample(item, img_ori, img_adv, trigger, target_answer)
                # CE loss (without trigger) + CE loss (with trigger)
                loss, loss_without_trigger, loss_with_trigger = attacker.get_loss(sample_ori, sample_without_trigger, sample_with_trigger,
                                                                args.loss_without_trigger_weight, args.loss_with_trigger_weight,
                                                                args.loss_type)
                loss.backward()

                metric_logger.update(loss=loss.item())
                metric_logger.update(loss_without_trigger=loss_without_trigger.item())
                metric_logger.update(loss_with_trigger=loss_with_trigger.item())

                if args.distributed:
                    # Average the perturbation gradient over all processes.
                    dist.all_reduce(batch_delta.grad, op=dist.ReduceOp.SUM)
                    batch_delta.grad /= dist.get_world_size()

                noise = batch_delta.grad.data.mean(dim=0)

            else:
                # NOTE(review): batch_delta.grad is not zeroed between the N backward
                # passes below, so each pass appears to add to the previous gradient
                # before it is summed into `noise` — confirm this is intended.
                for n in range(args.N):  # ensemble
                    ## SSA get idct
                    img_adv = get_img_idct(img_ori, batch_delta, image_size, args.rho, args.sigma, device, patch_attack=args.patch_attack, delta_mask=delta_mask)

                    if utils_ddp.is_main_process():
                        if n == 0 and epoch == 0 and batch_idx == 0:
                            torchvision.utils.save_image(img_adv.detach().cpu(), os.path.join(folder_to_save, 'img_adv0.png'))
                    # Return normalized images, normalized in 'get_sample'
                    sample_ori, sample_without_trigger, sample_with_trigger = attacker.get_sample(item, img_ori, img_adv, trigger, target_answer)
                    # CE loss (without trigger) + CE loss (with trigger)
                    loss, loss_without_trigger, loss_with_trigger = attacker.get_loss(sample_ori, sample_without_trigger, sample_with_trigger,
                                                                    args.loss_without_trigger_weight, args.loss_with_trigger_weight,
                                                                    args.loss_type)
                    loss.backward()

                    metric_logger.update(loss=loss.item())
                    metric_logger.update(loss_without_trigger=loss_without_trigger.item())
                    metric_logger.update(loss_with_trigger=loss_with_trigger.item())

                    if args.distributed:
                        dist.all_reduce(batch_delta.grad, op=dist.ReduceOp.SUM)
                        batch_delta.grad /= dist.get_world_size()

                    # Accumulate
                    noise += batch_delta.grad.data.mean(dim=0)  # [3, 224, 224]

                noise = noise / args.N

            ## Momentum
            # Normalise by the L1 norm unless the gradient is exactly zero.
            grad_l1 = torch.norm(noise, p=1)
            batch_delta_grad = noise if grad_l1 == 0 else noise / grad_l1
            v = args.mu * v + batch_delta_grad
            grad = v
            # Without Momentum: grad_sign = noise.sign()

            if args.pixel_attack:
                delta = delta + alpha * grad.sign()
                delta = torch.clamp(delta, -epsilon, epsilon)
            elif args.patch_attack:
                delta = delta + lr * grad.sign()
                delta = torch.clamp(delta, 0, 1)
            batch_delta.grad.data.zero_()

        # gather the stats from all processes
        metric_logger.synchronize_between_processes()
        if utils_ddp.is_main_process():
            print("Averaged stats:", metric_logger.global_avg())
        train_stats = {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}

        if utils_ddp.is_main_process():
            # Log statistics
            log_stats = {**{f'train_{k}': val for k, val in train_stats.items()},
                         'epoch': epoch,
                         }
            with open(os.path.join(folder_to_save, "log.txt"), "a") as f:
                f.write(json.dumps(log_stats) + "\n")

            # Save uap and delta_mask at specific epochs
            if cur_epoch % args.store_epoch == 0:
                uap = delta.detach().cpu()
                uap_path = os.path.join(folder_to_save, f"uap_sample{args.attack_samples}_step{cur_epoch}.pth")
                if args.patch_attack:
                    # Store only the part of the perturbation inside the patch mask.
                    uap = torch.mul(delta_mask[0].cpu(), uap)
                    mask_path = os.path.join(folder_to_save, f"delta_mask_sample{args.attack_samples}_step{cur_epoch}.pth")
                    torch.save(delta_mask[0].cpu(), mask_path)
                torch.save(uap, uap_path)

                # Evaluation (reloads the file that was just written)
                uap_for_eval = torch.load(uap_path).unsqueeze(0).repeat([1, 1, 1, 1]).to(device)
                eval(args, folder_to_save, uap_for_eval, trigger, target_answer, attacker, vqa_attack_dataloader_for_eval, device, f'uap_epoch{cur_epoch}_output_training_set', delta_mask=delta_mask)
                eval(args, folder_to_save, uap_for_eval, trigger, target_answer, attacker, vqa_eval_dataloader, device, f'uap_epoch{cur_epoch}_output_val_set', delta_mask=delta_mask)

        # Ensure all processes have finished logging and saving before proceeding.
        # Only meaningful (and only valid) when a process group is in use, so it is
        # gated on args.distributed like every other collective in this function.
        if args.distributed:
            dist.barrier()


if __name__ == '__main__':
    # Script entry point: parse options, derive the output directory from the
    # attack configuration, configure logging, dump the options and run main().

    def _str2bool(value):
        """Parse a command-line boolean.

        argparse's ``type=bool`` calls ``bool()`` on the raw string, so every
        non-empty value (including ``'False'``) would be parsed as True.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1', 'on'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0', 'off'):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

    seedEverything()
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', default='cuda')

    ## Output Config
    parser.add_argument('--dir_path', default='./Anydoor')
    parser.add_argument('--output', default='output')

    ## Model Config
    parser.add_argument('--model_name', default='blip2_vicuna_instruct')
    parser.add_argument('--model_type', default='vicuna7b')

    # Dataset Config
    parser.add_argument('--dataset', default='coco_vqa', help='coco_vqa or svit')
    parser.add_argument('--attack_set', default='attack_set', help='attack_set json')
    parser.add_argument('--eval_set', default='eval_set', help='eval_set json')

    # Data Config
    parser.add_argument('--batch_size', default=1, type=int)
    parser.add_argument("--attack_samples", default=1, type=int)
    parser.add_argument("--eval_samples", default=50, type=int)
    parser.add_argument("--image_size", default=224, type=int)

    ## Attack Config
    parser.add_argument("--max_epochs", default=5, type=int)
    parser.add_argument("--store_epoch", default=1, type=int)

    parser.add_argument('--trigger', default='abaaba')
    parser.add_argument('--target_answer', default='error code')

    # Pixel attack Config
    parser.add_argument('--pixel_attack', action='store_true', help='pixel attack')
    parser.add_argument("--alpha_weight", default=5, type=int)
    parser.add_argument("--epsilon", default=32, type=int)

    # Patch attack Config
    parser.add_argument('--patch_attack', action='store_true', help='patch attack')
    parser.add_argument('--patch_mode', help='one_corner, four_corner, border')
    parser.add_argument("--patch_size", default=64, type=int)
    parser.add_argument('--patch_position', default=None, help='top_left, top_right, bottom_left, bottom_right')
    parser.add_argument("--lr", default=10, type=float)

    ## SSA Config
    parser.add_argument("--N", type=int, default=20, help="The number of Spectrum Transformations")
    parser.add_argument("--sigma", type=float, default=16.0, help="Std of random noise")
    parser.add_argument("--rho", type=float, default=0.5, help="Tuning factor")

    ## MI Config
    parser.add_argument("--mu", default=0.9, type=float)

    # Loss Config
    # CE loss (without trigger) + CE loss (with trigger)
    parser.add_argument("--loss_without_trigger_weight", default=1.0, type=float)
    parser.add_argument("--loss_with_trigger_weight", default=1.0, type=float)
    parser.add_argument('--loss_type', default=2, type=int,
                        help='1=without trigger, 2=with trigger, 3=both')

    parser.add_argument('--save_adv', action='store_true', help='saving adv images or not')
    parser.add_argument('--clamp', action='store_true', help='')
    parser.add_argument('--NOT_SSA', action='store_true', help='')

    ## For DDP
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    parser.add_argument('--distributed', default=True, type=_str2bool)

    args = parser.parse_args()

    # The output directory below is derived from the attack type, so one must be
    # chosen; otherwise output_path would never be assigned.
    # If both flags are given, the pixel-attack directory layout is used.
    if not (args.pixel_attack or args.patch_attack):
        parser.error('one of --pixel_attack or --patch_attack is required')

    # attack_set = f'{args.dir_path}/s_datasets/{args.dataset}_attack_set_{args.attack_samples}.json'
    # eval_set = f'{args.dir_path}/s_datasets/{args.dataset}_eval_set_{args.eval_samples}.json'

    attack_set = args.attack_set
    eval_set = args.eval_set

    # output dir: args.output -> sub-dir
    base_path = Path(args.dir_path) / args.output / args.model_name / args.dataset

    if args.pixel_attack:
        output_path = base_path / f'loss{args.loss_type}/pixel_attack/ep{args.epsilon}/sample{args.attack_samples}/a{args.alpha_weight}/mu{args.mu}/iter{args.max_epochs}/wo{args.loss_without_trigger_weight}/w{args.loss_with_trigger_weight}'
    elif args.patch_mode == 'one_corner':
        output_path = base_path / f'loss{args.loss_type}/patch_attack/{args.patch_mode}_{args.patch_position}/ps{args.patch_size}/sample{args.attack_samples}/mu{args.mu}/iter{args.max_epochs}/wo{args.loss_without_trigger_weight}/w{args.loss_with_trigger_weight}'
    else:
        output_path = base_path / f'loss{args.loss_type}/patch_attack/{args.patch_mode}/sample{args.attack_samples}/ps{args.patch_size}/mu{args.mu}/iter{args.max_epochs}/wo{args.loss_without_trigger_weight}/w{args.loss_with_trigger_weight}'
    # main() reads this module-level name when saving images, logs and checkpoints.
    folder_to_save = os.path.join(output_path, "output_uap")

    Path(output_path).mkdir(parents=True, exist_ok=True)
    Path(folder_to_save).mkdir(parents=True, exist_ok=True)

    log_file = os.path.join(output_path, "log.log")
    logging.Formatter.converter = customTime
    logging.basicConfig(filename=log_file,
                        filemode='a',
                        format='%(asctime)s - %(levelname)s - \n %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S',
                        level=logging.INFO)

    # Close the file deterministically instead of leaking the handle.
    with open(os.path.join(output_path, 'args.yaml'), 'w') as args_file:
        yaml.dump(args, args_file, indent=4)
    # NOTE(review): distributed mode is only initialised inside main(), so this
    # guard presumably passes on every process at this point — confirm.
    if utils_ddp.is_main_process():
        logging.info(args)
        logging.info(f'folder_to_save: {folder_to_save}')
        logging.info(f'attack_set:{attack_set}')
        logging.info(f'eval_set:{eval_set}')
    main(args, attack_set, eval_set)

    logging.info('Done...')