import transformers
import argparse
import os
import logging
import torch
import numpy as np
transformers.logging.set_verbosity(transformers.logging.ERROR)
import sys
import shutil
import time
import random
import wandb
from PIL import Image
from pprint import pformat

from attackers.gbda.utils import *
from attackers.ea.ea import EvolutionSearcher
from attackers.prompt_optimizer import PromptOptimizer
from attackers.promptist import PROMPTIST

from utils.sd_utils import StableDiffusion
from utils.attack_utils import *
from utils.exp_utils import *

from prompts.common import validate_synonyms, negative_prompt_substitute
from prompts.paths import *


def main(args, wandb_table, device):
    """Run one prompt-optimization experiment and evaluate its candidates.

    Steps, in order: load Stable Diffusion, build the substitute-word search
    space for the prompt selected by ``args.prompt_id``, step the
    ``PromptOptimizer`` for ``args.num_iters`` iterations, run an
    ``EvolutionSearcher`` on the resulting distribution, then compare the top
    candidates with the original prompt (and, for non-attack tasks, with
    Promptist outputs).

    Args:
        args: Parsed command-line namespace built in the ``__main__`` section.
            ``args.seed`` must be set by the caller; it labels wandb metrics.
        wandb_table: Object whose ``add_data`` receives one row per top
            candidate. Only used when ``args.wandb`` is truthy.
        device: Device handed to ``StableDiffusion`` and ``PROMPTIST``.

    Returns:
        dict with keys ``ori_prompt``, ``ori_avg_loss``, ``avg_losses``,
        ``top10_avg_losses``, ``best_avg_loss``, ``num_evals`` and, when the
        task is not exactly ``'attack'``, ``pts_best_avg_loss``.

    Raises:
        IndexError: if the evolutionary search returns no candidate
            (``best_avg_loss`` reads the first sorted loss).
    """
    ######## load model
    sd = StableDiffusion(args.version, args.loss, device=device)
    text_encoder = sd.text_encoder
    tokenizer = sd.tokenizer


    #### substitutes
    substitutes_for_all_prompts = json_load(args.path, report_error=True)
    ori_prompt = list(substitutes_for_all_prompts.keys())[args.prompt_id]
    substitutes_all = substitutes_for_all_prompts[ori_prompt]

    substitutes = {'prompt': ori_prompt, 'keywords': substitutes_all['keywords']}
    if 'attack' in args.task.lower():
        substitutes['pos'] = substitutes_all['sub']
    else:
        if 'pos' in args.task:
            substitutes['pos'] = substitutes_all['sub']

        if 'opp' in args.task:
            substitutes['neg'] = substitutes_all['opp']
        if 'lib' in args.task:
            lib_len = 4
            subs_lib = negative_prompt_substitute(ori_prompt, length=lib_len, version=args.task)
            if 'neg' in substitutes:
                substitutes['neg'].update(subs_lib['sub'])
            else:
                substitutes['neg'] = subs_lib['sub']
        ## add empty string for this task
        ## TODO under dev ## also adding empty string to the positive too, not sure if this works better
        for _type in substitutes:
            if _type in ['prompt', 'keywords']:
                continue
            for ori_word in substitutes[_type]:
                substitutes[_type][ori_word].insert(0, "")

    # Default to None so the EvolutionSearcher call below does not raise a
    # NameError when no constraint is configured (e.g. the default thresh=0).
    # NOTE(review): confirm EvolutionSearcher accepts constraint_fn=None when
    # the threshold is not positive.
    constraint_fn = None
    if 'attack' in args.task.lower() and args.thresh > 0:
        constraint_fn = get_constraint_fn(args.constraint, tokenizer, text_encoder, ori_prompt)
        substitutes = validate_synonyms(substitutes, args.constraint, args.thresh, tokenizer, text_encoder, constraint_fn)

    train_seed_list = list(range(args.num_seeds_train)) if args.num_seeds_train > 0 else list(range(args.num_seeds))
    # NOTE: test_seed_list is only logged; every evaluation below uses train_seed_list.
    test_seed_list = [seed + 10 * args.transfer for seed in range(args.num_seeds)]
    scale = 7.5
    logging.info('ori prompt: {}'.format(ori_prompt))
    logging.info('seeds: {}'.format(test_seed_list))
    logging.info('scale: {}'.format(scale))


    ######## prompt optimization
    s = time.time()
    num_evals = 0

    prompt_opt = PromptOptimizer(text_encoder, tokenizer, substitutes, args)
    temp_scheduler = TempScheduler(*args.temp, args.num_iters)
    ts, avg_t = get_t_schedule(args)

    # ``best_infer_step`` is not registered by this script's argument parser;
    # fall back to the regular inference step count instead of raising
    # AttributeError as soon as ``--best`` is enabled.
    best_infer_step = getattr(args, 'best_infer_step', args.infer_step)


    #### train gumbel iterations
    best_log_coeffs, best_argmax_avg_loss, best_visited = prompt_opt.log_coeffs, 1e20, {}

    for i in range(args.num_iters):
        avg_train_loss = 0
        temp = temp_scheduler.get_temp()
        temp_scheduler.step()
        # The timestep depends only on the outer iteration; binding it here
        # keeps it defined for the wandb log even with zero accumulation steps.
        t = ts[i]

        for _ in range(args.grad_accum_iters):  ## gradient accumulation
            ## sample mixed embedding
            mixed_embeds_pos, mixed_embeds_neg = prompt_opt.get_mixed_embeds(temp)

            #### forward
            # Only the loss of the last accumulation step is kept for logging.
            avg_train_loss = sd.inference_grad(ori_prompt, mixed_embeds_pos, mixed_embeds_neg,
                                               args.infer_step, scale, train_seed_list, t, args)

        logging.info('-'*20 + ' optimizer step at iter = {}'.format(i + 1))
        prompt_opt.step()
        prompt_opt.display_log_coeffs()

        ## periodic argmax evaluation: every ``args.best`` iterations, keep the
        ## logits whose argmax prompt has the lowest loss so far
        early_stop = args.best and (i + 1) % args.best == 0

        if early_stop:
            prompt_pos, prompt_neg = prompt_opt.sample_prompt(argmax=True)
            if (prompt_pos, prompt_neg) in best_visited:
                # Re-use the cached loss of an already evaluated prompt pair.
                argmax_avg_loss = best_visited[(prompt_pos, prompt_neg)]
            else:
                argmax_avg_loss, _ = sd.inference(ori_prompt, prompt_pos, best_infer_step, scale, train_seed_list, args, negative_prompt=prompt_neg)
                best_visited[(prompt_pos, prompt_neg)] = argmax_avg_loss

            # NOTE(review): the cost is counted even on a cache hit, as in the
            # original accounting.
            num_evals += best_infer_step / 50

            if argmax_avg_loss <= best_argmax_avg_loss:
                best_argmax_avg_loss = argmax_avg_loss
                best_log_coeffs = [log_coeff.clone() for log_coeff in prompt_opt.log_coeffs]
                logging.info('best mean avg loss updated: {:.4f}'.format(best_argmax_avg_loss))

        ## visualization
        if args.wandb:
            message = {"train_loss-seed={}".format(args.seed): abs(avg_train_loss),
                       "max_grad_norm-seed={}".format(args.seed): prompt_opt.get_max_grad_norm(),
                       "t-seed={}".format(args.seed): t}
            wandb.log(message)

    # Evaluation budget is expressed in units of one 50-step inference.
    num_evals += (args.num_iters * avg_t / 50)
    num_evals += args.gumbel_samples

    ## register best log_coeffs
    logging.info('RESULT LOGITS')
    prompt_opt.set_log_coeffs(best_log_coeffs)
    prompt_opt.display_log_coeffs()

    search_cost_in_min = (time.time() - s) / 60
    logging.info('Search cost: {:.2f} min, num_evals (GPO + Best + EA): {}'.format(search_cost_in_min, num_evals))


    ######## sample from learned distribution (RS or EA)
    def eval_cand_fn(prompt_pos, prompt_neg):
        """Score one (positive, negative) prompt pair with a 50-step inference."""
        avg_loss, image_pil_loss = sd.inference(ori_prompt, prompt_pos, 50, scale, train_seed_list, args, negative_prompt=prompt_neg)
        logging.info('loss: {:.4f}'.format(avg_loss))
        return avg_loss, image_pil_loss

    logging.info('='*20 + ' EA ' + '='*20)
    searcher = EvolutionSearcher(args, eval_cand_fn, prompt_opt, constraint_fn, args.thresh)
    res_list = searcher.search()


    ######## final log
    topk = 5
    #### baselines
    logging.info('EVALUATING BASELINES')
    def eval_prompt(prompt, seed_list, prefix):
        """Evaluate a prompt (or a (prompt, negative prompt) pair).

        Returns the side-by-side concatenated image and the average loss;
        the image is also saved under ``prefix`` when ``args.save_image`` is set.
        """
        if isinstance(prompt, (tuple, list)):
            prompt, neg_prompt = prompt
        else:
            neg_prompt = None
        avg_loss, image_pil_loss = sd.inference(ori_prompt, prompt, 50, scale, seed_list, args, negative_prompt=neg_prompt)
        logging.info('ori loss: {}'.format(avg_loss))
        im = Image.fromarray(np.concatenate([it[0] for it in image_pil_loss], axis=1))
        if args.save_image:
            save_concat_image(args.save_image_path, im, prefix)
        return im, avg_loss
    ## original image
    logging.info('original prompt...')
    ori_im, ori_avg_loss = eval_prompt(ori_prompt, train_seed_list, prefix='ori')
    ## (REMOVED) human engineered prompts
    hep_im, hep_avg_loss = None, None
    ## promptist (train)
    pts_best_avg_loss = None
    if args.task.lower() != 'attack':
        logging.info('promptist...')
        pts_prompts = []
        for pts_prompt in PROMPTIST(device=device).generate(ori_prompt, n=topk):
            pts_avg_loss, image_pil_loss = sd.inference(ori_prompt, pts_prompt, 50, scale, train_seed_list, args)

            pts_prompts.append([pts_prompt, pts_avg_loss, image_pil_loss])
        # Lowest loss first, so index 0 is the best Promptist candidate.
        pts_prompts = sorted(pts_prompts, key=lambda x: x[1])
        pts_best_avg_loss = pts_prompts[0][1]

    #### DPO
    success_opt_prompts = [
        [cand_info['cand_prompt'], cand_info['avg_loss'], cand_info['loss_list']]
        for cand_info in res_list
    ]
    success_opt_prompts = success_opt_prompts[:topk]


    logging.info('RESULT IMAGE')
    with torch.no_grad():
        for cand_idx, (opt_prompt, avg_loss, image_pil_loss) in enumerate(success_opt_prompts):
            logging.info(f'successful opt prompt: {opt_prompt} - {abs(avg_loss)}')
            if opt_prompt[1] is None:
                # Normalize a missing negative prompt so it can be joined below.
                opt_prompt = (opt_prompt[0], '')

            ## train
            opt_im = Image.fromarray(np.concatenate([it[0] for it in image_pil_loss], axis=1))
            if args.task.lower() != 'attack':
                # Pair the cand_idx-th optimized prompt with the cand_idx-th Promptist one.
                pts_prompt, pts_avg_loss, pts_image_pil_loss = pts_prompts[cand_idx]
                pts_im = Image.fromarray(np.concatenate([it[0] for it in pts_image_pil_loss], axis=1))

            if args.wandb and cand_idx < 10: # save top10 images on wandb
                if args.task.lower() == 'attack':
                    opt_row = [
                        {'train':{'opt': wandb.Image(opt_im), 'ori': wandb.Image(ori_im)}},
                        {'train': {'opt': abs(avg_loss), 'ori': abs(ori_avg_loss)}},
                        {'opt': '\n'.join(list(opt_prompt)), 'ori': ori_prompt, 'keywords': substitutes['keywords']}
                    ]
                else:
                    opt_row = [
                        {'train':{'opt': wandb.Image(opt_im), 'ori': wandb.Image(ori_im), 'pts': wandb.Image(pts_im)}},
                        {'train': {'opt': abs(avg_loss), 'pts': abs(pts_avg_loss), 'ori': abs(ori_avg_loss)}},
                        {'opt': '\n'.join(list(opt_prompt)), 'pts': pts_prompt, 'ori': ori_prompt, 'keywords': substitutes['keywords']}
                    ]

                if hep_im is not None:  ## add hep
                    opt_row[0]['train']['hep'] = wandb.Image(hep_im)
                    opt_row[1]['train']['hep'] = abs(hep_avg_loss)
                wandb_table.add_data(*opt_row)

            if args.save_image and cand_idx < 5:  # save top5 images locally
                save_concat_image(args.save_image_path, opt_im, f'dpo-{cand_idx}')
                if args.task.lower() != 'attack':
                    save_concat_image(args.save_image_path, pts_im, f'pts-{cand_idx}')

    all_avg_losses = [item[1] for item in success_opt_prompts]
    top10_avg_losses = sorted(all_avg_losses)[:10]
    for i, loss in enumerate(top10_avg_losses):
        logging.info('top {}/10 avg losses {}'.format(i, loss))
    logging.info('ori avg loss: {}'.format(ori_avg_loss))

    result = {
        'ori_prompt': ori_prompt,
        'ori_avg_loss': ori_avg_loss,
        'avg_losses': all_avg_losses,
        'top10_avg_losses': top10_avg_losses,
        'best_avg_loss': top10_avg_losses[0],
        # 'cand_history': cand_history,
        'num_evals': num_evals
    }
    if args.task.lower() != 'attack':  ## has pts baseline
        result['pts_best_avg_loss'] = pts_best_avg_loss

    return result







if __name__ == "__main__":
    # Script entry point: parse arguments, prepare the experiment directory and
    # logging, optionally initialize wandb, run ``main`` ``--num_exps`` times
    # and store the aggregated results.
    parser = argparse.ArgumentParser(description="Hybrid Optimizer.")

    #### management
    parser.add_argument('--group', type=str, default='none', help='exps in the same group are all saved to [group] folder')
    parser.add_argument('--save', type=str, default='cmd', help='saving directory / expid')
    parser.add_argument('--tag', type=str, default='none', help='extra tag')
    parser.add_argument('--gpu', type=str, default='auto')
    parser.add_argument('--gentle', type=float, default=0, help='queue for empty gpu')
    #### visualization
    parser.add_argument('--wandb', type=int, default=0, help='use wandb')
    parser.add_argument('--save_image', type=int, default=0, help='whether to save generated image')
    parser.add_argument('--log_argmax', type=int, default=0, help="whether to log argmax loss")

    #### task
    parser.add_argument("--version", default='v1-4', type=str, help="Stable Diffusion Version")
    parser.add_argument('--reverse', type=float, default=-1, help='improve results')
    parser.add_argument('--task', type=str, default='attack')
    #### constraints
    parser.add_argument('--constraint', type=str, default='none')
    parser.add_argument('--thresh', type=float, default=0,
                        help='threshold for constraint optimization')
    #### evaluation
    parser.add_argument('--num_exps', type=int, default=1, help='number of runs')
    parser.add_argument('--path', type=str, default='substitutes-dev', help='substitutes json file path')
    parser.add_argument('--prompt_id', type=int, default=0)
    parser.add_argument('--num_seeds', type=int, default=3, help='testing seed')
    parser.add_argument('--transfer', type=int, default=0, help='if transfer, add 10 to each num_seeds')

    #### gradient-based
    parser.add_argument("--num_iters", default=20, type=int, help="number of iterations in search")
    parser.add_argument("--batch_size", default=1, type=int,
        help="batch size for gumbel-softmax samples, changing this to >1 will cause error")
    parser.add_argument("--init_coeff", default=1, type=float)
    parser.add_argument("--max_coeff", default=3, type=float)
    parser.add_argument("--min_coeff", default=0, type=float)
    parser.add_argument("--lr", default=0.1, type=float)
    parser.add_argument("--gumbel_samples", default=100, type=int, help="number of gumbel samples")
    parser.add_argument('--mixer', type=str, default='gumbel')
    parser.add_argument('--loss', type=str, default='spherical',
                        choices=['spherical', 'cos', 'hps', 'hps2'],
                        help='spherical distance loss v.s. cosine loss')
    parser.add_argument('--use_cutouts', type=int, default=0)
    parser.add_argument('--grad_accum_iters', type=int, default=1)
    parser.add_argument("--clip_grad", default=0.025, type=float)
    ## diffusion model
    parser.add_argument('--t', type=str, default='step-15-2', help="t scheduler")
    parser.add_argument('--infer_step', type=int, default=50, help="inference step for diffusion")
    parser.add_argument('--num_seeds_train', type=int, default=1, help='-1: random sampling from num_seeds')
    #### ablation
    parser.add_argument("--temp", default='1,1', type=str, help='min and max temperature for gumbel')
    parser.add_argument('--best', type=int, default=0, help=">0: return best log_coeffs, evaled at every X iters")
    parser.add_argument("--init_method", default='default', type=str, help="method for initial log coefficients")
    parser.add_argument('--sample_temp', type=float, default=0, help='temperature for sampling, 0 -> same as min_temp')
    parser.add_argument('--betas', type=str, default="(0.5,0.999)", help="beta for Optimizer")
    parser.add_argument('--opt', type=str, default='rmsprop', help="optimizer")
    #### post ea search
    ## ea args (different from pure ea search)
    parser.add_argument('--explore', type=float, default=0.0, help="exploration (uniform sampling) ratio")
    ## ea inherit
    parser.add_argument('--select_num', type=int, default=10)
    parser.add_argument('--population_num', type=int, default=20)
    parser.add_argument('--m_prob', type=float, default=0.1)
    parser.add_argument('--crossover_num', type=int, default=10)
    parser.add_argument('--mutation_num', type=int, default=10)
    parser.add_argument('--first_argmax', type=int, default=1, help='sample argmax as first')

    args = parser.parse_args()


    #### args augment
    args.gpu = str(pick_gpu_lowest_memory(args.gentle)) if args.gpu == 'auto' else args.gpu
    ## TODO under dev ##
    # args.device = f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu'
    # NOTE: the device is currently hard-coded; ``args.gpu`` is not used to build it.
    args.device = 'cuda:0'
    if args.task.lower() == 'attack':
        args.gs_name = 'Attack'
    else:
        args.gs_name = 'Improve'
        args.reverse = 1
    if 'debug' in args.tag: args.group = 'debug'

    ## output dir
    script_name = args.save
    exp_id = '{}'.format(script_name)
    exp_id += f"_task={args.task}-{args.path.replace('/usps-loss-', '-').replace('.json', '').replace('-{}', '')}"
    exp_id += f'_coeff={int(args.init_coeff)}{abs(args.min_coeff)}{args.max_coeff}'
    if args.constraint != 'none' and args.thresh > 0: exp_id += f'_cst={args.constraint}-{args.thresh}'
    if args.t != 'rand-15-25': exp_id += f'_t={args.t}'
    if args.transfer: exp_id += '_transfer'
    if args.best: exp_id += f'_best={args.best}'
    if args.loss != 'spherical': exp_id += f'_{args.loss}'
    if args.tag and args.tag != 'none': exp_id += f'_tag={args.tag}'
    exp_id += f'_bgt={args.gumbel_samples}'
    if 'debug' in args.tag: exp_id = args.tag
    # Layout: experiments/<group>/<exp_id>/<prompt_id>
    args.save = os.path.join('experiments/', f'{args.group}/', exp_id, f'{args.prompt_id}')

    ## override path
    if os.path.exists(args.save):
        # Debug runs are overwritten silently; anything else asks for confirmation.
        if 'debug' in args.tag or input('{} exists, override? [y/n]'.format(args.save)) == 'y':
            shutil.rmtree(args.save)
        else:
            sys.exit()
    create_exp_dir(args.save, run_script='./exp_scripts/{}'.format(script_name + '.sh'))
    ## output files
    args.save_image_path = os.path.join(args.save, 'gen_images')
    args.save_plot_path = os.path.join(args.save, 'plots')
    # Results shared by all prompt ids live in the parent of the per-prompt
    # directory. Taking the parent directly avoids corrupting the path when
    # "/<prompt_id>" also occurs earlier in it (e.g. a group or exp id that
    # starts with the same digits), which a string replace would strip too.
    shared_dir = os.path.dirname(args.save)
    args.save_json_path = os.path.join(shared_dir, 'all_results.json')
    args.save_cand_path = os.path.join(shared_dir, 'cand_history.json')
    os.mkdir(args.save_image_path)
    os.mkdir(args.save_plot_path)
    # logging: messages go both to stdout and to <save>/log.txt
    log_format = '%(message)s'
    logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=log_format)
    log_file = 'log.txt'
    log_path = os.path.join(args.save, log_file)
    if os.path.exists(log_path) and input(f'log: {log_file} exists, override? [y/n]') != 'y':
        sys.exit(0)
    fh = logging.FileHandler(log_path, mode='w')
    fh.setFormatter(logging.Formatter(log_format))
    logging.getLogger().addHandler(fh)
    logging.info('\n================== Args ==================\n')
    logging.info(pformat(vars(args)))


    #### initialize wandb
    if args.wandb:
        wandb.init(
            project=f"{args.gs_name}-{args.group}",
            # track hyperparameters and run metadata
            config=args.__dict__,
            name=exp_id + f'_{args.prompt_id}',
            dir=args.save
        )

        #### wandb image table
        columns = ["im", "L", "P"]
        wandb_table = wandb.Table(columns=columns)
    else:
        wandb_table = None

    ## args transform
    # SECURITY NOTE: ``eval`` executes arbitrary expressions taken from the
    # command line (--betas / --temp). Acceptable for a locally launched
    # research script, but never feed these options from untrusted input;
    # consider ``ast.literal_eval`` if that assumption changes.
    args.betas = eval(args.betas)
    args.infer_step = int(args.infer_step)
    args.temp = sorted(list(eval(args.temp)))
    if args.sample_temp == 0:
        args.sample_temp = args.temp[0]
    # Resolve the substitutes file to <PROJECT_DIR>/data/<name>.json unless
    # the given path already carries those parts.
    if 'data' not in args.path[:6]: args.path = os.path.join('data', args.path)
    if 'nfs' not in args.path: args.path = os.path.join(PROJECT_DIR, args.path)
    if '.json' not in args.path: args.path += '.json'


    all_results = []
    for i in range(args.num_exps):
        logging.info('\n' * 10 + '='*40 + '\n' + f'Exp: {i}' + '\n' + '='*40)
        # The run index doubles as the random seed of the run.
        deterministic_mode(seed=i)
        args.seed = i

        result = main(args, wandb_table, args.device)

        all_results.append(result)
        if args.wandb: wandb.log({'ori-loss={:.4f}-seed={}'.format(result['ori_avg_loss'], args.seed): wandb_table})
    torch.save(all_results, os.path.join(args.save, 'all_results.pth'))


    #### final log
    ## save result to common json (one entry per prompt id, keys sorted as strings)
    new_entry = result_to_json(all_results)
    json_result = json_load(args.save_json_path)
    json_result[str(args.prompt_id)] = new_entry
    json_result_sorted = {key: json_result[key] for key in sorted(json_result)}
    json_save(json_result_sorted, args.save_json_path)
    ## print
    (avg_best, std_best), (avg_top10, std_top10) = compute_exp_stats(all_results)
    logging.info('Best:  {:.4f} \u00B1 {:.4f}'.format(avg_best, std_best))
    logging.info('Top10: {:.4f} \u00B1 {:.4f}'.format(avg_top10, std_top10))
