import numpy as np
import torch
from diffusers import AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import UNet2DConditionModel, DDIMScheduler
from tqdm.auto import tqdm
from dataset import load_pokemon_datasets, load_laion5_dataset, load_coco_datasets, load_flickr_datasets
from sklearn import metrics
from filter import Fourier_filter


class Flag:
    """Empty attribute bag; settings are attached to it dynamically at run time."""


# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Shared run configuration. This is an *instance* of Flag so that settings are
# stored on the object rather than being written onto the Flag class itself.
flags = Flag()
# Model components; all of them stay None until init_model() fills them in.
unet = None
tokenizer = None
text_encoder = None
scheduler = None
vae = None


def init_model(diff_path):
    """
    Load every model component from ``diff_path`` into the module-level globals.

    The VAE, tokenizer, text encoder, UNet and DDIM scheduler are read from the
    ``vae``, ``tokenizer``, ``text_encoder``, ``unet`` and ``scheduler``
    sub-folders of ``diff_path``. The VAE, text encoder and UNet are moved to
    ``device`` and switched to evaluation mode.

    Args:
        diff_path: Path (or model identifier) handed to each ``from_pretrained`` call.
    """
    global unet, tokenizer, text_encoder, scheduler, vae

    vae = AutoencoderKL.from_pretrained(
        diff_path, subfolder='vae', use_auth_token=True)

    tokenizer = CLIPTokenizer.from_pretrained(diff_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(diff_path, subfolder="text_encoder")

    unet = UNet2DConditionModel.from_pretrained(diff_path, subfolder='unet')
    # unet = UNet2DConditionModel.from_pretrained("./checkpoint/sd1.4_flickr/checkpoint-20000/unet", subfolder='unet')

    scheduler = DDIMScheduler.from_pretrained(diff_path, subfolder="scheduler")

    vae = vae.to(device)
    vae.eval()
    text_encoder = text_encoder.to(device)
    # Put the text encoder itself in eval mode (previously the UNet was
    # switched twice and the text encoder never was).
    text_encoder.eval()
    unet = unet.to(device)
    unet.eval()
    print('all model loaded.')


@torch.no_grad()
def extract(v, t, x_shape):
    """
    Look up the entries of ``v`` at the indices ``t`` and return them as a
    float tensor of shape [len(t), 1, 1, ...] (one trailing singleton axis per
    non-leading axis of ``x_shape``) so they broadcast against a batch.
    """
    batch = t.shape[0]
    singleton_axes = (1,) * (len(x_shape) - 1)
    picked = v.gather(0, t).float()
    return picked.view(batch, *singleton_axes)


def ddim_singlestep(model, FLAGS, x, t_c, t_target, requires_grad=False, device='cuda', encoder_hidden_states=None):
    """
    Move ``x`` from timestep ``t_c`` to timestep ``t_target`` with one deterministic DDIM update.

    The cumulative alpha products are rebuilt from the module-level
    ``scheduler.betas``. The model's noise prediction at ``t_c`` is used to
    estimate the clean sample, which is then re-noised to ``t_target`` with the
    same predicted noise.

    Args:
        model: Callable invoked as ``model(x, t, encoder_hidden_states)``; its
            ``.sample`` attribute is used as the predicted noise.
        FLAGS: Unused; kept for interface compatibility.
        x: Input tensor at timestep ``t_c``.
        t_c: Current timestep (scalar, broadcast over the batch).
        t_target: Timestep to move to (scalar, broadcast over the batch).
        requires_grad: Track gradients through the step. Gradients are only
            tracked when the caller has not already disabled autograd.
        device: Device the computation runs on.
        encoder_hidden_states: Conditioning passed to the model.

    Returns:
        dict with ``'x_t_target'`` (the sample at ``t_target``) and
        ``'epsilon'`` (the predicted noise).
    """
    if encoder_hidden_states is None:
        print('Error! encoder_hidden_states==None')

    # A blanket no_grad decorator used to make ``requires_grad`` a no-op.
    # Autograd is now enabled only when requested AND permitted by the caller,
    # so the default (requires_grad=False) behaves exactly as before.
    track_grad = bool(requires_grad) and torch.is_grad_enabled()
    with torch.set_grad_enabled(track_grad):
        x = x.to(device)
        t_c = x.new_ones([x.shape[0], ], dtype=torch.long) * (t_c)
        t_target = x.new_ones([x.shape[0], ], dtype=torch.long) * (t_target)

        # Cumulative product of (1 - beta), computed in double precision.
        betas = scheduler.betas.double().to(device)
        alphas = torch.cumprod(1. - betas, dim=0)
        alphas_t_c = extract(alphas, t=t_c, x_shape=x.shape)
        alphas_t_target = extract(alphas, t=t_target, x_shape=x.shape)

        epsilon = model(x, t_c, encoder_hidden_states).sample

        # Clean-sample estimate implied by the predicted noise at t_c.
        pred_x_0 = (x - ((1 - alphas_t_c).sqrt() * epsilon)) / (alphas_t_c.sqrt())
        # Re-noise that estimate to t_target, reusing the same predicted noise.
        x_t_target = alphas_t_target.sqrt() * pred_x_0 \
                     + (1 - alphas_t_target).sqrt() * epsilon

    return {
        'x_t_target': x_t_target,
        'epsilon': epsilon
    }


def ddim_multistep(model, FLAGS, x, t_c, target_steps, clip=False, device='cuda', requires_grad=False,
                   encoder_hidden_states=None):
    """
    Chain ``ddim_singlestep`` through every timestep in ``target_steps``.

    Each step starts from the previous step's output and timestep.

    Args:
        model, FLAGS, device, encoder_hidden_states: Forwarded to ``ddim_singlestep``.
        x: Starting tensor at timestep ``t_c``.
        t_c: Starting timestep.
        target_steps: Non-empty iterable of timesteps to visit, in order.
        clip: When true, clamp the final ``'x_t_target'`` to [-1, 1].
        requires_grad: Track gradients through the steps. Gradients are only
            tracked when the caller has not already disabled autograd.

    Returns:
        The dict returned by the last ``ddim_singlestep`` call.

    Raises:
        ValueError: If ``target_steps`` is empty (there would be nothing to return).
    """
    target_steps = list(target_steps)
    if not target_steps:
        raise ValueError('target_steps must contain at least one timestep')

    # Mirrors ddim_singlestep: never enable autograd unless asked to and the
    # caller has not switched it off.
    track_grad = bool(requires_grad) and torch.is_grad_enabled()
    with torch.set_grad_enabled(track_grad):
        result = None
        for t_target in target_steps:
            result = ddim_singlestep(model, FLAGS, x, t_c, t_target, requires_grad=requires_grad, device=device,
                                     encoder_hidden_states=encoder_hidden_states)
            x = result['x_t_target']
            t_c = t_target

        if clip:
            result['x_t_target'] = torch.clip(result['x_t_target'], -1, 1)

    return result


@torch.no_grad()
def att_measure(diffusion, sample, metric, device='cuda'):
    """
    Compute one distance score per sample between ``diffusion`` and ``sample``.

    Both tensors are moved to ``device`` and cast to float. 5-D inputs are
    treated as having the timestep axis first and the per-sample axis second:
    the first two axes are swapped and the timestep axis is folded into the
    third axis before scoring. Sizes are taken from the tensor itself, so any
    channel count and spatial size is accepted.

    Args:
        diffusion: 4-D or 5-D tensor.
        sample: Tensor with the same shape as ``diffusion``.
        metric: ``'l2'`` for the sum of squared differences, or an int ``p``
            for the sum of ``|difference| ** p``.
        device: Device the computation runs on.

    Returns:
        1-D tensor with one score per sample.

    Raises:
        NotImplementedError: If ``metric`` is neither ``'l2'`` nor an int.
    """
    diffusion = diffusion.to(device).float()
    sample = sample.to(device).float()

    if diffusion.dim() == 5:
        num_timestep, batch_size, channels, height, width = diffusion.shape
        # Swap the two leading axes, then fold timesteps into the channel axis.
        merged_shape = (batch_size, num_timestep * channels, height, width)
        diffusion = diffusion.permute(1, 0, 2, 3, 4).reshape(merged_shape)
        sample = sample.permute(1, 0, 2, 3, 4).reshape(merged_shape)

    if metric == 'l2':
        score = ((diffusion - sample) ** 2).flatten(1).sum(dim=-1)
    elif isinstance(metric, int):
        score = (torch.abs(diffusion - sample) ** metric).flatten(1).sum(dim=-1)
    else:
        raise NotImplementedError
    return score


@torch.no_grad()
def sec_mi(model, batch, vae, text_encoder, device, x_sec_list, x_sec_recon_list, Filter, threshold, scale, scores, highs):
    """
    Run the 'sec' attack on one batch and append its two tensors to the output lists.

    The batch is encoded to scaled VAE latents, pushed with deterministic DDIM
    steps from timestep 0 up to ``flags.t_sec`` (stride ``flags.timestep``),
    stepped once further forward and then once back. The latent at
    ``flags.t_sec`` and its reconstruction after that round trip are the
    recorded pair.

    Args:
        model: Noise-prediction model forwarded to the DDIM helpers.
        batch: Mapping with ``"pixel_values"`` and ``"input_ids"``;
            ``"pixel_values"`` is replaced in place by its copy on ``device``.
        vae: Autoencoder used to encode the pixels.
        text_encoder: Encoder producing the conditioning from ``"input_ids"``.
        device: Device the computation runs on.
        x_sec_list: List that receives the latent at ``flags.t_sec``.
        x_sec_recon_list: List that receives the reconstructed latent.
        Filter: Both tensors go through ``Fourier_filter`` when this equals 1.
        threshold, scale: Forwarded to ``Fourier_filter``.
        scores, highs: Unused; kept for a uniform attack signature.
    """
    target_steps = list(range(0, flags.t_sec + flags.timestep, flags.timestep))[1:]  # 10 20 ... 90 100
    starttmp = flags.t_sec

    batch["pixel_values"] = batch["pixel_values"].to(device)
    latents = vae.encode(batch["pixel_values"].to(torch.float32)).latent_dist.sample()
    x = latents * vae.config.scaling_factor

    # The conditioning is the same for every DDIM call below, so encode the
    # prompt once instead of running the text encoder twice on identical input.
    embd = text_encoder(batch["input_ids"].to(device))[0]

    x_sec = ddim_multistep(model, flags, x, t_c=0, target_steps=target_steps, device=device,
                           encoder_hidden_states=embd)
    x_sec = x_sec['x_t_target']

    endtmp = starttmp + flags.stpsnumi

    forw_stps = list(range(starttmp, endtmp + 1))
    back_stps = forw_stps[::-1]

    # NOTE(review): this pins the attack to t_sec == 100 with stpsnumi >= 1;
    # other settings fail here. Left as-is pending confirmation of intent.
    assert forw_stps[0] == 100 and forw_stps[1] == 101

    # One step forward from t_sec ...
    x_sec_forw = ddim_singlestep(model, flags, x_sec,
                                 t_c=forw_stps[0], t_target=forw_stps[1],
                                 device=device, encoder_hidden_states=embd)

    # ... and one step back from the end of the forward range.
    x_sec_recon = ddim_singlestep(model, flags, x_sec_forw['x_t_target'],
                                  t_c=back_stps[0], t_target=back_stps[1],
                                  device=device, encoder_hidden_states=embd)
    x_sec_recon = x_sec_recon['x_t_target']

    if Filter == 1:
        x_sec = Fourier_filter(x_sec, threshold=threshold, scale=scale)
        x_sec_recon = Fourier_filter(x_sec_recon, threshold=threshold, scale=scale)

    x_sec_list.append(x_sec)
    x_sec_recon_list.append(x_sec_recon)


@torch.no_grad()
def prox_mi(model, batch, vae, text_encoder, device, x_sec_list, x_sec_recon_list, Filter, threshold, scale, scores, highs):
    """
    Run the 'pia' attack on one batch and append its two tensors to the output lists.

    The noise injected at timestep 500 is the model's own prediction at
    timestep 0. The clean latent and the clean-latent estimate recovered from
    the model's prediction at timestep 500 are recorded (after optional Fourier
    filtering when ``Filter == 1``). ``scores`` and ``highs`` are not used.
    ``batch["pixel_values"]`` is replaced in place by its copy on ``device``.
    """
    pixels = batch["pixel_values"].to(device)
    batch["pixel_values"] = pixels
    clean = vae.encode(pixels.to(torch.float32)).latent_dist.sample() * vae.config.scaling_factor

    cond = text_encoder(batch["input_ids"].to(device))[0]

    # Per-sample timestep vector and the matching cumulative alpha product.
    clean = clean.to(device)
    steps = clean.new_ones([clean.shape[0], ], dtype=torch.long) * (500)
    alpha_bar = torch.cumprod(1. - scheduler.betas.double().to(device), dim=0)
    alpha_bar_t = extract(alpha_bar, t=steps, x_shape=clean.shape)
    signal_scale = alpha_bar_t.sqrt()
    noise_scale = (1 - alpha_bar_t).sqrt()

    # Noise comes from the model's output at timestep 0.
    init_noise = model(clean, 0, cond).sample
    noised = signal_scale * clean + noise_scale * init_noise
    predicted_noise = model(noised, steps, cond).sample

    # Invert the noising with the predicted noise to estimate the clean latent.
    recovered = (noised - noise_scale * predicted_noise) / signal_scale
    reference = clean

    if Filter == 1:
        reference = Fourier_filter(reference, threshold=threshold, scale=scale)
        recovered = Fourier_filter(recovered, threshold=threshold, scale=scale)

    x_sec_list.append(reference)
    x_sec_recon_list.append(recovered)


@torch.no_grad()
def loss_mi(model, batch, vae, text_encoder, device, x_sec_list, x_sec_recon_list, Filter, threshold, scale, scores, highs):
    """
    Run the 'naive' attack on one batch and append its two tensors to the output lists.

    The clean latent is mixed with freshly drawn Gaussian noise at timestep
    500. The clean latent and the clean-latent estimate recovered from the
    model's noise prediction are recorded (after optional Fourier filtering
    when ``Filter == 1``). ``scores`` and ``highs`` are not used.
    ``batch["pixel_values"]`` is replaced in place by its copy on ``device``.
    """
    pixels = batch["pixel_values"].to(device)
    batch["pixel_values"] = pixels
    encoded = vae.encode(pixels.to(torch.float32)).latent_dist.sample()
    clean = encoded * vae.config.scaling_factor

    cond = text_encoder(batch["input_ids"].to(device))[0]

    clean = clean.to(device)
    t_vec = clean.new_ones([clean.shape[0], ], dtype=torch.long) * (500)
    betas = scheduler.betas.double().to(device)
    alpha_bar = torch.cumprod(1. - betas, dim=0)
    alpha_bar_t = extract(alpha_bar, t=t_vec, x_shape=clean.shape)
    keep = alpha_bar_t.sqrt()
    spread = (1 - alpha_bar_t).sqrt()

    # Random Gaussian noise is drawn after the text encoding, as before.
    gaussian = torch.randn_like(clean)
    noised = keep * clean + spread * gaussian
    predicted_noise = model(noised, t_vec, cond).sample

    # Clean-latent estimate implied by the predicted noise.
    recovered = (noised - spread * predicted_noise) / keep
    reference = clean

    if Filter == 1:
        reference = Fourier_filter(reference, threshold=threshold, scale=scale)
        recovered = Fourier_filter(recovered, threshold=threshold, scale=scale)

    x_sec_list.append(reference)
    x_sec_recon_list.append(recovered)


def _collect_attack_pairs(loader, attack_fn, threshold, scale):
    """Run ``attack_fn`` on every batch of ``loader``; return the concatenated (reference, reconstruction) tensors."""
    references = []
    reconstructions = []
    for batch in tqdm(loader):
        attack_fn(unet, batch, vae, text_encoder, device, references, reconstructions, flags.filter,
                  threshold=threshold, scale=scale, scores=[], highs=[])
    return torch.concat(references), torch.concat(reconstructions)


def _sweep_thresholds(member_scores, nonmember_scores, num_steps=10000):
    """Sweep decision thresholds over the score range; return (TPR_list, FPR_list, max_acc, TPR@FPR=1%, TPR@FPR=0.1%)."""
    lowest = min(member_scores.min(), nonmember_scores.min()).item()
    highest = max(member_scores.max(), nonmember_scores.max()).item()
    total = member_scores.size(0) + nonmember_scores.size(0)

    TPR_list = []
    FPR_list = []
    max_acc = 0.0
    TPRatFPR_1, gap_1 = 0, 999
    TPRatFPR_01, gap_01 = 0, 999

    # linspace replaces the deprecated torch.range: it always reaches the
    # maximum score and tolerates min == max (zero step).
    for threshold in torch.linspace(lowest, highest, num_steps + 1):
        # A sample is predicted "member" when its score is at or below the threshold.
        TP = (member_scores <= threshold).sum()
        TN = (nonmember_scores > threshold).sum()
        FP = (nonmember_scores <= threshold).sum()
        FN = (member_scores > threshold).sum()

        acc = (TP + TN) / total
        TPR = TP / (TP + FN)
        FPR = FP / (FP + TN)

        # Keep the TPR at the threshold whose FPR is closest to 1% / 0.1%.
        if gap_1 > (0.01 - FPR).abs():
            gap_1 = (0.01 - FPR).abs()
            TPRatFPR_1 = TPR

        if gap_01 > (0.001 - FPR).abs():
            gap_01 = (0.001 - FPR).abs()
            TPRatFPR_01 = TPR

        if max_acc < acc:
            max_acc = acc

        TPR_list.append(TPR.item())
        FPR_list.append(FPR.item())

    return TPR_list, FPR_list, max_acc, TPRatFPR_1, TPRatFPR_01


def run(filter=1, t=5, s=0.2, attack='naive'):
    """
    Evaluate one membership-inference attack and print its metrics.

    Configures the shared ``flags``, loads the models and the dataset, scores
    every batch of the train loader (members) and of the test loader
    (non-members) with the chosen attack, then sweeps a decision threshold
    over the scores to obtain AUC, best accuracy and TPR at 1% / 0.1% FPR.

    Args:
        filter: Stored in ``flags.filter`` and handed to the attack; the
            attacks apply ``Fourier_filter`` when it equals 1.
        t: Passed to the attack as ``threshold`` (forwarded to ``Fourier_filter``).
        s: Passed to the attack as ``scale`` (forwarded to ``Fourier_filter``).
        attack: One of ``'sec'``, ``'naive'`` or ``'pia'``.

    Returns:
        Tuple ``(auc, FPR_list, TPR_list)``.
    """
    flags.T = 1000
    flags.train_batch_size = 1
    flags.dataloader_num_workers = 0
    flags.resolution = 512
    flags.image_column = "image"
    flags.caption_column = "text"
    flags.t_sec = 100
    flags.timestep = 10
    flags.stpsnumi = 1
    flags.outdir = 'outputs'

    flags.attack = attack
    flags.dataset = 'coco'

    flags.filter = filter
    assert flags.attack in ['sec', 'naive', 'pia']
    assert flags.dataset in ['pokemon', 'laion', 'coco', 'flickr']

    # Resolve the attack once, up front, instead of re-dispatching per batch.
    attack_fns = {'naive': loss_mi, 'sec': sec_mi, 'pia': prox_mi}
    attack_fn = attack_fns.get(flags.attack)
    if attack_fn is None:
        print('Error, No implement!', flags.attack)
        raise NotImplementedError(flags.attack)

    flags.diff_path = "./checkpoint/sd1.4_coco_aug"
    dataset_root = "./datasets"

    init_model(flags.diff_path)
    print("loading finish!")

    if flags.dataset == 'pokemon':
        _, _, train_loader, test_loader = load_pokemon_datasets(dataset_root, num_samples=415, tokenizer=tokenizer)
    elif flags.dataset == 'laion':
        train_loader, test_loader = load_laion5_dataset(dataset_root, num_samples=100, tokenizer=tokenizer)
    elif flags.dataset == 'coco':
        _, _, train_loader, test_loader = load_coco_datasets(dataset_root, num_samples=100, tokenizer=tokenizer)
    elif flags.dataset == 'flickr':
        train_loader, test_loader = load_flickr_datasets(dataset_root, num_samples=1000, tokenizer=tokenizer)
    else:
        raise NotImplementedError

    print("start attack!")
    print(flags.attack)
    print(flags.dataset)

    # NOTE(review): 'prox' is not an accepted attack name, so this always
    # selects 'l2'. Possibly meant 'pia' - confirm before changing, since it
    # would alter the reported metrics.
    norm = 5 if flags.attack == 'prox' else 'l2'

    # Train-loader samples are the members ...
    x_sec_s, x_sec_recon_s = _collect_attack_pairs(train_loader, attack_fn, t, s)
    member_scores = att_measure(x_sec_s, x_sec_recon_s, norm, device=device).cpu()

    print("******************************************")

    # ... and test-loader samples are the non-members.
    x_sec_s, x_sec_recon_s = _collect_attack_pairs(test_loader, attack_fn, t, s)
    nonmember_scores = att_measure(x_sec_s, x_sec_recon_s, norm, device=device).cpu()

    TPR_list, FPR_list, max_acc, TPRatFPR_1, TPRatFPR_01 = _sweep_thresholds(member_scores, nonmember_scores)

    auc = metrics.auc(np.asarray(FPR_list), np.asarray(TPR_list))
    print(f'AUC: {auc} \t ASR: {max_acc} \t TPR@FPR=1%: {TPRatFPR_1} \t TPR@FPR=0.1%: {TPRatFPR_01}')

    return auc, FPR_list, TPR_list


if __name__ == '__main__':
    # Evaluate the 'naive' attack twice: first without, then with, the Fourier filter.
    for use_filter in (0, 1):
        run(filter=use_filter, t=5, s=0.2, attack='naive')
    # The 'pia' and 'sec' attacks can be evaluated the same way via ``attack=``.





