# https://github.com/openai/CLIP/issues/83

import argparse
import os
import sys

import ruamel.yaml as yaml
import numpy as np
import random
import time
import datetime
import json
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torch.utils.data import DataLoader
import torch.optim as optim

from torchvision import transforms
from PIL import Image

import utils

from easydict import EasyDict as edict
from models.get_model import load_model

# Codes imported from https://github.com/salesforce/ALBEF/blob/main/Retrieval.py
from dataset_with_aug import create_dataset, create_sampler, create_loader
from scheduler import create_scheduler
from optim import create_optimizer


# SGA
sys.path.append("SGA")
from attacker import (
    SGAttacker as SGAttacker,
    ImageAttacker as SGA_ImageAttacker,
    TextAttacker as SGA_TextAttacker,
)

# ACMMM2022: https://github.com/adversarial-for-goodness/Co-Attack/tree/main
# attack image_embed and text_embed
from co_attack_modified import ImageAttacker as PGD_ImageAttacker
from co_attack_modified import BertAttack, BertAttackFusion
from co_attack_modified import MultiModalAttacker as CoAttacker


def get_attacker(args, model, ref_model, tokenizer):
    """Build the adversarial attacker selected by ``args.attack``.

    Args:
        args: Parsed command-line namespace. Reads ``attack`` and ``epsilon``;
            for the non-SGA attacks it also reads ``cls`` and
            ``attack_fused_emb``.
        model: Victim model handed to the multimodal attacker.
        ref_model: Reference language model handed to the text attacker.
        tokenizer: Tokenizer shared by the text and multimodal attackers.

    Returns:
        An ``SGAttacker`` when ``args.attack == "SGA"``, otherwise a
        ``CoAttacker`` (``MultiModalAttacker``).

    Raises:
        ValueError: If ``args.attack`` is not one of the supported attack
            names. Previously this fell through both branches and failed with
            an ``UnboundLocalError`` on ``return``.
    """
    pgd_family = ("BERT", "PGD", "Sep-Attack", "Co-Attack", "Clean")
    # Validate first so an unsupported name fails with a clear message before
    # any attacker object is constructed.
    if args.attack != "SGA" and args.attack not in pgd_family:
        raise ValueError(f"Invalid attack mode: {args.attack}")

    # Per-channel mean / std applied to images inside the attackers
    # (presumably the CLIP preprocessing statistics -- confirm against the
    # model's own preprocessing).
    images_normalize = transforms.Normalize(
        (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)
    )

    if args.attack == "SGA":
        # epsilon is given on the 0-255 scale; the attacker takes it on 0-1.
        img_attacker = SGA_ImageAttacker(
            images_normalize, eps=args.epsilon / 255, steps=10, step_size=0.5 / 255
        )
        txt_attacker = SGA_TextAttacker(
            ref_model,
            tokenizer,
            cls=False,
            max_length=30,
            number_perturbation=1,
            topk=10,
            threshold_pred_score=0.3,
        )
        return SGAttacker(model, img_attacker, txt_attacker)

    image_attacker = PGD_ImageAttacker(
        args.epsilon / 255.0,
        preprocess=images_normalize,
        bounding=(0, 1),
        cls=args.cls,
    )
    # The fused-embedding variant uses a different text attacker class.
    text_attacker_cls = BertAttackFusion if args.attack_fused_emb else BertAttack
    text_attacker = text_attacker_cls(ref_model, tokenizer, cls=args.cls)
    return CoAttacker(model, image_attacker, text_attacker, tokenizer, cls=args.cls)


def attack_batch(attacker, images, texts, device):
    """Generate adversarial images and texts for one batch.

    The attack type and its hyper-parameters are read from the module-level
    ``args`` and ``config`` objects, which are created in the ``__main__``
    section of this script.

    Args:
        attacker: Attacker object; must expose ``attack`` for SGA, or
            ``run`` / ``run_before_fusion`` for the other attacks.
        images: Batch of images.
        texts: Captions paired with ``images``.
        device: Device on which the SGA text-to-image index tensor is created.

    Returns:
        Tuple ``(adv_images, adv_texts)`` as produced by the attacker.

    Raises:
        ValueError: If ``args.attack`` is not a supported attack name.
    """
    if args.attack == "SGA":
        # Caption i is paired with image i.
        txt2img = torch.arange(len(images)).to(device)
        scales = [0.5, 0.75, 1.25, 1.5]
        # NOTE(review): the keyword is spelled ``max_lemgth`` on purpose here;
        # it is passed through unchanged and presumably matches the spelling in
        # the SGA attacker's signature -- confirm before "fixing" it.
        adv_images, adv_texts = attacker.attack(
            images, texts, txt2img, device=device, max_lemgth=30, scales=scales
        )
        return adv_images, adv_texts

    # Integer mode understood by the Co-Attack style attacker. "Clean" (mode 0)
    # is accepted by the CLI and by get_attacker, so it must be reachable here
    # too; previously it always ended in ValueError.
    adv_modes = {
        "Clean": 0,
        "BERT": 1,
        "PGD": 2,
        "Sep-Attack": 3,
        "Co-Attack": 4,
    }
    if args.attack not in adv_modes:
        raise ValueError(f"Invalid attack mode: {args.attack}")
    adv_mode = adv_modes[args.attack]

    if args.attack_fused_emb:
        # Note: CoAttack for fused embeddings requires paired input (image, text).
        adv_images, adv_texts = attacker.run(
            images,
            texts,
            adv=adv_mode,
            num_iters=config["num_iters"],
            alpha=args.alpha,
        )
    else:
        # CLIP-style models get a 77-token cap; other models get a cap of 1e3.
        max_length = 77 if "CLIP" in args.model else 1e3
        adv_images, adv_texts = attacker.run_before_fusion(
            images,
            texts,
            adv=adv_mode,
            num_iters=config["num_iters"],
            alpha=args.alpha,
            max_length=max_length,
        )
    return adv_images, adv_texts


def train(model, data_loader, optimizer, tokenizer, epoch, warmup_steps, device, scheduler, config, attacker=None):
    """Run one training epoch and return the averaged meters as strings.

    Args:
        model: Model called as ``model(image, image_aug, text_input, alpha=..., idx=...)``
            and returning ``(loss_ita, loss_itm)``.
        data_loader: Yields ``(image, image_aug, text, idx)`` batches.
        optimizer: Optimizer stepped once per batch.
        tokenizer: Callable that turns the caption list into model inputs.
        epoch: Zero-based epoch index.
        warmup_steps: Number of warm-up scheduler steps (each spans 100 iterations).
        device: Target device for the batch tensors.
        scheduler: LR scheduler, stepped only during warm-up in epoch 0.
        config: Reads ``warm_up`` and ``alpha``.
        attacker: Attacker forwarded to ``attack_batch`` when the module-level
            ``args.attack`` is set.

    Returns:
        Dict mapping each meter name to its global average formatted with
        three decimals.
    """
    model.train()

    logger = utils.MetricLogger(delimiter="  ")
    meter_formats = (
        ('lr', '{value:.6f}'),
        ('loss_itm', '{value:.4f}'),
        ('loss_ita', '{value:.4f}'),
    )
    for meter_name, meter_fmt in meter_formats:
        logger.add_meter(meter_name, utils.SmoothedValue(window_size=1, fmt=meter_fmt))

    header = 'Train Epoch: [{}]'.format(epoch)
    log_interval = 50
    sched_interval = 100
    warmup_iterations = warmup_steps * sched_interval

    batches = logger.log_every(data_loader, log_interval, header)
    for step, (images, images_aug, captions, sample_ids) in enumerate(batches):
        images = images.to(device, non_blocking=True)
        images_aug = images_aug.to(device, non_blocking=True)
        sample_ids = sample_ids.to(device, non_blocking=True)

        # Adversarial training: swap the clean batch for its adversarial version.
        if args.attack is not None:
            images, captions = attack_batch(attacker, images, captions, device)
            images.detach_()

        text_input = tokenizer(captions, padding='longest', max_length=30, return_tensors="pt").to(device)

        # During the first epoch (with warm-up enabled) alpha ramps linearly
        # with the iteration index; afterwards it is the configured constant.
        if epoch <= 0 and config['warm_up']:
            alpha = config['alpha'] * min(1, step / len(data_loader))
        else:
            alpha = config['alpha']

        loss_ita, loss_itm = model(images, images_aug, text_input, alpha=alpha, idx=sample_ids)
        total_loss = loss_ita + loss_itm

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        logger.update(loss_itm=loss_itm.item())
        logger.update(loss_ita=loss_ita.item())
        logger.update(lr=optimizer.param_groups[0]["lr"])

        # Warm-up: advance the scheduler every `sched_interval` iterations of epoch 0.
        in_warmup = epoch == 0 and step % sched_interval == 0 and step <= warmup_iterations
        if in_warmup:
            scheduler.step(step // sched_interval)

    # gather the stats from all processes
    logger.synchronize_between_processes()
    print("Averaged stats:", logger.global_avg())
    summary = {}
    for name, meter in logger.meters.items():
        summary[name] = "{:.3f}".format(meter.global_avg)
    return summary



@torch.no_grad()
def evaluation(model, data_loader, tokenizer, device, config):
    """Compute image-to-text and text-to-image score matrices for retrieval.

    Text and image embeddings are computed first and compared with a dot
    product. For every row only the ``config['k_test']`` most similar
    candidates are re-scored through ``model.text_encoder(mode='fusion')`` and
    ``model.itm_head``; all other entries keep the fill value ``-100.0``.

    Args:
        model: Model exposing ``text_encoder``, ``text_proj``,
            ``visual_encoder``, ``vision_proj`` and ``itm_head``.
        data_loader: Yields ``(image, img_id)`` batches; its ``dataset`` must
            expose ``text`` and ``image``.
        tokenizer: Callable that tokenizes a list of captions.
        device: Device used for all tensors.
        config: Reads ``k_test``.

    Returns:
        Tuple ``(score_matrix_i2t, score_matrix_t2i)`` of NumPy arrays with
        shapes ``(num_images, num_texts)`` and ``(num_texts, num_images)``.

    Note:
        Reads the module-level ``args.distributed`` to decide whether to
        all-reduce the score matrices.
    """
    # test
    model.eval()

    metric_logger = utils.MetricLogger(delimiter="  ")
    header = 'Evaluation:'

    print('Computing features for evaluation...')
    start_time = time.time()

    # ---- Encode all captions in chunks of `text_bs` ----
    texts = data_loader.dataset.text
    num_text = len(texts)
    text_bs = 256
    text_feats = []
    text_embeds = []
    text_atts = []
    for i in range(0, num_text, text_bs):
        text = texts[i: min(num_text, i+text_bs)]
        text_input = tokenizer(text, padding='max_length', truncation=True, max_length=30, return_tensors="pt").to(device)
        text_output = model.text_encoder(text_input.input_ids, attention_mask = text_input.attention_mask, mode='text')
        text_feat = text_output.last_hidden_state
        # Project and normalize the feature at sequence position 0
        # (presumably the [CLS] token -- confirm against the tokenizer).
        text_embed = F.normalize(model.text_proj(text_feat[:,0,:]))
        text_embeds.append(text_embed)
        text_feats.append(text_feat)
        text_atts.append(text_input.attention_mask)
    text_embeds = torch.cat(text_embeds,dim=0)
    text_feats = torch.cat(text_feats,dim=0)
    text_atts = torch.cat(text_atts,dim=0)

    # ---- Encode all images ----
    image_feats = []
    image_embeds = []
    for image, img_id in data_loader:
        image = image.to(device)
        image_feat = model.visual_encoder(image)
        image_embed = model.vision_proj(image_feat[:,0,:])
        image_embed = F.normalize(image_embed,dim=-1)

        image_feats.append(image_feat)
        image_embeds.append(image_embed)

    image_feats = torch.cat(image_feats,dim=0)
    image_embeds = torch.cat(image_embeds,dim=0)

    # ---- Image -> text: similarity of every image to every caption ----
    sims_matrix = image_embeds @ text_embeds.t()
    # Entries that are never re-scored keep the fill value -100.0.
    score_matrix_i2t = torch.full((len(data_loader.dataset.image),len(texts)),-100.0).to(device)

    # Each process handles one contiguous slice [start, end) of the rows.
    num_tasks = utils.get_world_size()
    rank = utils.get_rank()
    step = sims_matrix.size(0)//num_tasks + 1
    start = rank*step
    end = min(sims_matrix.size(0),start+step)

    for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
        topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0)

        # Repeat this image's features once per candidate caption.
        encoder_output = image_feats[start+i].repeat(config['k_test'],1,1)
        encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device)
        output = model.text_encoder(encoder_embeds = text_feats[topk_idx],
                                    attention_mask = text_atts[topk_idx],
                                    encoder_hidden_states = encoder_output,
                                    encoder_attention_mask = encoder_att,
                                    return_dict = True,
                                    mode = 'fusion'
                                   )
        # Column 1 of the ITM head output is used as the score
        # (presumably the "match" logit -- TODO confirm).
        score = model.itm_head(output.last_hidden_state[:,0,:])[:,1]
        score_matrix_i2t[start+i,topk_idx] = score

    # ---- Text -> image: same procedure on the transposed similarities ----
    sims_matrix = sims_matrix.t()
    score_matrix_t2i = torch.full((len(texts),len(data_loader.dataset.image)),-100.0).to(device)

    step = sims_matrix.size(0)//num_tasks + 1
    start = rank*step
    end = min(sims_matrix.size(0),start+step)

    for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):

        topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0)
        encoder_output = image_feats[topk_idx]
        encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device)
        # Repeat this caption's features / mask once per candidate image.
        output = model.text_encoder(encoder_embeds = text_feats[start+i].repeat(config['k_test'],1,1),
                                    attention_mask = text_atts[start+i].repeat(config['k_test'],1),
                                    encoder_hidden_states = encoder_output,
                                    encoder_attention_mask = encoder_att,
                                    return_dict = True,
                                    mode = 'fusion'
                                   )
        score = model.itm_head(output.last_hidden_state[:,0,:])[:,1]
        score_matrix_t2i[start+i,topk_idx] = score

    # Sum the per-process matrices so every process ends up with all rows.
    # `args` here is the module-level namespace, not a parameter.
    if args.distributed:
        dist.barrier()
        torch.distributed.all_reduce(score_matrix_i2t, op=torch.distributed.ReduceOp.SUM)
        torch.distributed.all_reduce(score_matrix_t2i, op=torch.distributed.ReduceOp.SUM)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Evaluation time {}'.format(total_time_str))

    return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()


            
@torch.no_grad()
def itm_eval(scores_i2t, scores_t2i, txt2img, img2txt):
    """Compute recall@{1,5,10} for both retrieval directions.

    Args:
        scores_i2t: Array of shape (num_images, num_texts); higher is better.
        scores_t2i: Array of shape (num_texts, num_images); higher is better.
        txt2img: Indexable mapping text index -> ground-truth image index.
        img2txt: Indexable mapping image index -> iterable of ground-truth
            text indices.

    Returns:
        Dict with ``txt_r1/5/10``, ``img_r1/5/10``, the per-direction means
        ``txt_r_mean`` / ``img_r_mean`` and the overall ``r_mean``
        (all percentages).
    """

    def _recall_at(rank_array, k):
        # Percentage of queries whose best ground-truth rank is below k.
        hits = np.where(rank_array < k)[0]
        return 100.0 * len(hits) / len(rank_array)

    # Images -> Text: keep the best (lowest) rank among each image's captions.
    i2t_ranks = np.zeros(scores_i2t.shape[0])
    for img_pos, row in enumerate(scores_i2t):
        order = np.argsort(row)[::-1]
        i2t_ranks[img_pos] = min(
            (np.where(order == txt_id)[0][0] for txt_id in img2txt[img_pos]),
            default=1e20,
        )

    tr1 = _recall_at(i2t_ranks, 1)
    tr5 = _recall_at(i2t_ranks, 5)
    tr10 = _recall_at(i2t_ranks, 10)

    # Text -> Images: each caption has exactly one ground-truth image.
    t2i_ranks = np.zeros(scores_t2i.shape[0])
    for txt_pos, row in enumerate(scores_t2i):
        order = np.argsort(row)[::-1]
        t2i_ranks[txt_pos] = np.where(order == txt2img[txt_pos])[0][0]

    ir1 = _recall_at(t2i_ranks, 1)
    ir5 = _recall_at(t2i_ranks, 5)
    ir10 = _recall_at(t2i_ranks, 10)

    tr_mean = (tr1 + tr5 + tr10) / 3
    ir_mean = (ir1 + ir5 + ir10) / 3
    r_mean = (tr_mean + ir_mean) / 2

    return {
        'txt_r1': tr1,
        'txt_r5': tr5,
        'txt_r10': tr10,
        'txt_r_mean': tr_mean,
        'img_r1': ir1,
        'img_r5': ir5,
        'img_r10': ir10,
        'img_r_mean': ir_mean,
        'r_mean': r_mean,
    }




def main(args, config):
    """Train (or only evaluate) the retrieval model, optionally adversarially.

    Steps: initialise distributed mode, seed the RNGs, build the datasets and
    loaders, load the model, create the optimizer / scheduler for the selected
    model type, optionally build an attacker, then loop over epochs running
    training, validation / test evaluation, logging and checkpointing.

    Args:
        args: Parsed command-line namespace (see the ``__main__`` section).
        config: Training configuration; reads ``batch_size_train``,
            ``batch_size_test``, ``optimizer`` and ``schedular`` here.

    Raises:
        ValueError: If ``args.model`` is unsupported, or if ``TCL_PT`` is
            selected without ``--prompt_config``.

    Note:
        For ``TCL_PT`` the module-level ``prompt_config`` (set in the
        ``__main__`` section when ``--prompt_config`` is given) is used.
    """
    utils.init_distributed_mode(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility (offset by rank so processes differ)
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    # load data
    train_dataset, val_dataset, test_dataset = create_dataset('re', config)
    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        # Only the training set gets a distributed sampler.
        samplers = create_sampler([train_dataset], [True], num_tasks, global_rank) + [None, None]
    else:
        samplers = [None, None, None]

    train_loader, val_loader, test_loader = create_loader(
        [train_dataset, val_dataset, test_dataset],
        samplers,
        batch_size=[config['batch_size_train']] + [config['batch_size_test']] * 2,
        num_workers=[4, 4, 4],
        is_trains=[True, False, False],
        collate_fns=[None, None, None],
    )

    # load model
    print("Loading model")
    model, ref_model, tokenizer = load_model(config, args.model, args.ckpt, args.text_encoder, device=device)

    model = model.to(device)
    ref_model = ref_model.to(device)

    if args.model in ["TCL"]:
        arg_opt = utils.AttrDict(config['optimizer'])
        optimizer = create_optimizer(arg_opt, model)
        arg_sche = utils.AttrDict(config['schedular'])
        lr_scheduler, _ = create_scheduler(arg_sche, optimizer)
    elif args.model in ["TCL_PT"]:
        # `prompt_config` only exists as a module global when --prompt_config
        # was given; fail with a clear message instead of a NameError.
        if args.prompt_config is None:
            raise ValueError("--prompt_config is required for model TCL_PT")

        model.wrap_vision_encoder_with_prompter(prompt_config)
        model.to(device)

        parameters = [{"params": model.visual_encoder.prompt_proj.parameters()}]
        parameters += [{"params": model.visual_encoder.prompt_embeddings}]
        if prompt_config.DEEP:
            parameters += [{"params": model.visual_encoder.deep_prompt_embeddings}]

        # requires_grad = False for other params
        for name, param in model.named_parameters():
            if 'prompt' not in name:
                param.requires_grad = False
        # print trainable parameters
        for name, param in model.named_parameters():
            if param.requires_grad:
                print("----> trainable: ", name)

        arg_opt = utils.AttrDict(prompt_config['optimizer'])
        optimizer = optim.SGD(parameters, nesterov=True, weight_decay=arg_opt.weight_decay, lr=arg_opt.lr, momentum=arg_opt.momentum)
        arg_sche = utils.AttrDict(prompt_config['schedular'])
        lr_scheduler, _ = create_scheduler(arg_sche, optimizer)
    else:
        raise ValueError(f"Model not supported: {args.model}")

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    # define attacker (operates on the unwrapped model)
    attacker = None
    if args.attack is not None:
        attacker = get_attacker(args, model_without_ddp, ref_model, tokenizer)

    # NOTE(review): epochs / warm-up are always read from `config`, even for
    # TCL_PT whose scheduler is built from `prompt_config` -- confirm intent.
    max_epoch = config['schedular']['epochs']
    warmup_steps = config['schedular']['warmup_epochs']
    best = 0
    best_epoch = 0
    log_path = os.path.join(args.output_dir, "log.txt")

    print("Start training")
    start_time = time.time()
    for epoch in range(0, max_epoch):
        if not args.evaluate:
            if args.distributed:
                train_loader.sampler.set_epoch(epoch)
            train_stats = train(model, train_loader, optimizer, tokenizer, epoch, warmup_steps, device, lr_scheduler, config, attacker=attacker)

        score_val_i2t, score_val_t2i = evaluation(model_without_ddp, val_loader, tokenizer, device, config)
        score_test_i2t, score_test_t2i = evaluation(model_without_ddp, test_loader, tokenizer, device, config)

        if utils.is_main_process():
            val_result = itm_eval(score_val_i2t, score_val_t2i, val_loader.dataset.txt2img, val_loader.dataset.img2txt)
            print(val_result)
            test_result = itm_eval(score_test_i2t, score_test_t2i, test_loader.dataset.txt2img, test_loader.dataset.img2txt)
            print(test_result)

            # One JSON line per epoch; training stats are present only when training.
            log_stats = {}
            if not args.evaluate:
                log_stats.update({f'train_{k}': v for k, v in train_stats.items()})
            log_stats.update({f'val_{k}': v for k, v in val_result.items()})
            log_stats.update({f'test_{k}': v for k, v in test_result.items()})
            log_stats['epoch'] = epoch
            with open(log_path, "a") as f:
                f.write(json.dumps(log_stats) + "\n")

            if not args.evaluate:
                save_obj = {
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'config': config,
                    'epoch': epoch,
                }
                torch.save(save_obj, os.path.join(args.output_dir, f'checkpoint_ep{epoch}.pth'))
                # Model selection uses the validation mean recall.
                if val_result['r_mean'] > best:
                    torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth'))
                    best = val_result['r_mean']
                    best_epoch = epoch
                    print("best epoch: %d" % best_epoch)

        if args.evaluate:
            break

        lr_scheduler.step(epoch + warmup_steps + 1)
        # barrier() needs an initialised process group; calling it in a
        # single-process run raises, so only synchronise when distributed.
        if args.distributed:
            dist.barrier()
        torch.cuda.empty_cache()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))

    if utils.is_main_process():
        with open(log_path, "a") as f:
            # Terminate the line so later appended runs start on a fresh line.
            f.write("best epoch: %d\n" % best_epoch)
        print("best epoch: %d" % best_epoch)
    

if __name__ == "__main__":

    def _str2bool(value):
        """Parse a command-line boolean.

        ``type=bool`` treats every non-empty string (including "False") as
        True; this converter interprets the usual spellings instead.
        """
        if isinstance(value, bool):
            return value
        text = value.strip().lower()
        if text in ("1", "true", "t", "yes", "y", "on"):
            return True
        if text in ("", "0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="./configs/Retrieval_flickr_train.yaml")
    parser.add_argument("--seed", default=42, type=int)

    parser.add_argument("--model", default="ViT-B/16", type=str) # model architecture
    parser.add_argument("--model_name", default="CLIP-vit-b16", type=str) # id for the model
    parser.add_argument("--text_encoder", default="bert-base-uncased", type=str)
    parser.add_argument("--ckpt", default=None, type=str)

    # training config
    parser.add_argument('--evaluate', action='store_true')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    parser.add_argument('--distributed', default=False, type=_str2bool)

    # adversarial training config
    parser.add_argument(
        "--attack",
        default=None,
        type=str,
        choices=["SGA", "Co-Attack", "Sep-Attack", "PGD", "BERT", "Clean"],
    )
    parser.add_argument("--attack_fused_emb", default=False, type=_str2bool)
    parser.add_argument("--cls", default=False, type=_str2bool)
    parser.add_argument("--output_dir", default="../train_results", type=str)
    parser.add_argument("--epsilon", default=2.0, type=float)
    parser.add_argument("--alpha", default=3.0, type=float)

    # prompt tuning
    parser.add_argument("--prompt_config", default=None, type=str)

    args = parser.parse_args()

    # Only the TCL variants are supported by this script. Use parser.error
    # rather than `assert`, which is stripped when Python runs with -O.
    if args.model not in ["TCL", "TCL_PT"]:
        parser.error(f"--model must be one of ['TCL', 'TCL_PT'], got {args.model!r}")

    # NOTE: yaml.Loader can construct arbitrary Python objects; only load
    # trusted, local configuration files here.
    with open(args.config, "r") as f:
        config = yaml.load(f, Loader=yaml.Loader)
    config = edict(config)

    if args.prompt_config is not None:
        with open(args.prompt_config, "r") as f:
            prompt_tuning_config = yaml.load(f, Loader=yaml.Loader)
        # Read as a module-level global by main() for the TCL_PT model.
        prompt_config = edict(prompt_tuning_config)["PROMPT"]

    # create output directory: <output_dir>/<model_name>[_<attack>]
    DIR_NAME = args.model_name
    if args.attack is not None:
        DIR_NAME = f"{DIR_NAME}_{args.attack}"
    args.output_dir = os.path.join(args.output_dir, DIR_NAME)
    os.makedirs(args.output_dir, exist_ok=True)
    print("Output directory:", args.output_dir)

    # log print: mirror stdout into out.txt. The file is deliberately left
    # open for the lifetime of the process.
    sys.stdout = utils.Tee(sys.stdout, open(os.path.join(args.output_dir, "out.txt"), "w"))

    # save args
    with open(os.path.join(args.output_dir, "args.txt"), "w") as f:
        f.write(str(args))
    # save config
    with open(os.path.join(args.output_dir, "config.txt"), "w") as f:
        f.write(str(config))
    if args.prompt_config is not None:
        with open(os.path.join(args.output_dir, "prompt_config.txt"), "w") as f:
            f.write(str(prompt_tuning_config))

    main(args, config)
