from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from PIL import Image
import sys
import numpy as np
import torch
from diffusers import StableDiffusionPipeline, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import UNet2DConditionModel, PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler
from tqdm.auto import tqdm
from datasets import load_dataset
import time
import os



# Bare attribute container for run-time settings.
# It deliberately carries no docstring and no members: the module binds
# ``flags`` to this class object itself and later prints/writes
# ``flags.__dict__``, so anything added here would show up in that output.
class Flag:
    pass


# Launch timestamp (month/day_hour/minute/second, local time); the format
# contains no whitespace, so it is used as-is in output file names.
timestart = time.strftime('%m%d_%H%M%S')

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print('device ---- ', device)

# NOTE: ``flags`` is the Flag *class*, not an instance; settings are stored as
# class attributes and ``flags.__dict__`` is therefore the class dictionary.
flags = Flag

# Model components; populated by init_model().
unet = tokenizer = text_encoder = scheduler = vae = None


def init_model(diff_path):
    """Load the Stable Diffusion components found under ``diff_path``.

    The VAE, CLIP tokenizer, CLIP text encoder, UNet and DDIM scheduler are read
    from the ``vae``, ``tokenizer``, ``text_encoder``, ``unet`` and ``scheduler``
    sub-folders and stored in the module-level globals of the same names.  The
    three neural networks are moved to ``device`` and switched to eval mode.

    Args:
        diff_path: model name or directory passed to each ``from_pretrained`` call.
    """
    global unet, tokenizer, text_encoder, scheduler, vae

    print('init model...')
    vae = AutoencoderKL.from_pretrained(
        diff_path, subfolder='vae', use_auth_token=True)
    print('vae loaded.')

    tokenizer = CLIPTokenizer.from_pretrained(diff_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(diff_path, subfolder="text_encoder")
    print('tokenizer, textencoder loaded.')

    unet = UNet2DConditionModel.from_pretrained(diff_path, subfolder='unet')
    print('unet loaded.')

    scheduler = DDIMScheduler.from_pretrained(diff_path, subfolder="scheduler")
    print('sch loaded.', scheduler)

    vae = vae.to(device)
    vae.eval()
    text_encoder = text_encoder.to(device)
    # Previously ``unet.eval()`` was called here by mistake, so the text
    # encoder was never explicitly put into eval mode.
    text_encoder.eval()
    unet = unet.to(device)
    unet.eval()
    print('all model loaded.')


def get_data(flags=flags, dataset_name=None, max_samples=2500):
    """Build a non-shuffled DataLoader over the ``train`` split of a dataset.

    Every sample is converted to a normalised image tensor plus five tokenised
    views of its caption: the full caption, its three character-level thirds,
    and the empty prompt.  Tokenisation uses the module-level ``tokenizer``,
    which is set by ``init_model``.

    Args:
        flags: settings object.  Read: ``resolution``, ``image_column``,
            ``caption_column``, ``train_batch_size``, ``dataloader_num_workers``.
            Written (always reset to ``None``): ``dataset_config_name``,
            ``cache_dir``, ``train_data_dir``.
        dataset_name: name or path handed to ``load_dataset``; required.
        max_samples: serve at most this many leading samples of the split
            (``None`` serves the whole split).  The bound is capped at the
            split length so small datasets do not index out of range.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding dicts with the keys
        ``pixel_values``, ``input_ids``, ``input_ids_0``, ``input_ids_1``,
        ``input_ids_2`` and ``input_ids_null``.

    Raises:
        ValueError: if ``dataset_name`` is ``None`` or a configured column is
            missing from the dataset.  A caption of unsupported type raises
            ``ValueError`` lazily, when the sample is transformed.
    """
    import random
    from torch.utils.data import Subset

    if dataset_name is None:
        raise ValueError("dataset_name must be provided")

    def collate_fn(examples):
        # Stack every per-sample tensor into a batch tensor, key by key.
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
        batch = {"pixel_values": pixel_values}
        for key in ("input_ids", "input_ids_0", "input_ids_1", "input_ids_2", "input_ids_null"):
            batch[key] = torch.stack([example[key] for example in examples])
        return batch

    # Resize + centre-crop to a square of side ``flags.resolution``, then map
    # pixel values from [0, 1] to [-1, 1].
    train_transforms = transforms.Compose(
        [
            transforms.Resize(flags.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(flags.resolution),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    # These settings are deliberately forced to None on the shared flags object
    # (kept for compatibility: the main script serialises ``flags.__dict__``).
    flags.dataset_config_name = None
    flags.cache_dir = None
    flags.train_data_dir = None

    dataset = load_dataset(
        dataset_name,
        flags.dataset_config_name,
        cache_dir=flags.cache_dir,
        data_dir=flags.train_data_dir,
    )

    column_names = dataset["train"].column_names

    DATASET_NAME_MAPPING = {
        "lambdalabs/pokemon-blip-captions": ("image", "text"),
    }

    # Resolve the image / caption columns: explicit flag first, then the
    # per-dataset mapping, then the first / second column of the dataset.
    dataset_columns = DATASET_NAME_MAPPING.get(dataset_name, None)
    if flags.image_column is None:
        image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
    else:
        image_column = flags.image_column
        if image_column not in column_names:
            # The message used to reference an undefined ``args`` object and
            # crashed with NameError instead of reporting the problem.
            raise ValueError(
                f"--image_column' value '{flags.image_column}' needs to be one of: {', '.join(column_names)}"
            )
    if flags.caption_column is None:
        caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
    else:
        caption_column = flags.caption_column
        if caption_column not in column_names:
            raise ValueError(
                f"--caption_column' value '{flags.caption_column}' needs to be one of: {', '.join(column_names)}"
            )

    print('---\ncaption_column:', caption_column, '\n---')

    def _tokenize(texts):
        # Pad / truncate every text to the tokenizer's maximum length.
        return tokenizer(
            texts, max_length=tokenizer.model_max_length, padding="max_length", truncation=True,
            return_tensors="pt"
        ).input_ids

    def tokenize_captions_multi(examples, is_train=True):
        captions = []
        for caption in examples[caption_column]:
            if isinstance(caption, str):
                captions.append(caption)
            elif isinstance(caption, (list, np.ndarray)):
                # take a random caption if there are multiple
                captions.append(random.choice(caption) if is_train else caption[0])
            else:
                raise ValueError(
                    f"Caption column `{caption_column}` should contain either strings or lists of strings."
                )

        # The three partial prompts split each caption by *characters* into thirds.
        first_third = [e[:int(len(e) / 3)] for e in captions]
        middle_third = [e[int(len(e) / 3):int(2 * len(e) / 3)] for e in captions]
        last_third = [e[int(2 * len(e) / 3):] for e in captions]
        empty_prompts = ["" for _ in captions]

        return (_tokenize(captions), _tokenize(first_third), _tokenize(middle_third),
                _tokenize(last_third), _tokenize(empty_prompts))

    def preprocess_train_multi(examples):
        images = [image.convert("RGB") for image in examples[image_column]]
        examples["pixel_values"] = [train_transforms(image) for image in images]
        (examples["input_ids"], examples["input_ids_0"], examples["input_ids_1"],
         examples["input_ids_2"], examples["input_ids_null"]) = tokenize_captions_multi(examples)
        return examples

    test_dataset = dataset["train"].with_transform(preprocess_train_multi)

    # Only the leading samples are evaluated; never index past the end of the split.
    total = len(test_dataset)
    keep = total if max_samples is None else min(max_samples, total)
    subset_dataset = Subset(test_dataset, range(keep))

    test_dataloader = torch.utils.data.DataLoader(
        subset_dataset,
        shuffle=False,
        collate_fn=collate_fn,
        batch_size=flags.train_batch_size,
        num_workers=flags.dataloader_num_workers,
    )

    return test_dataloader


# CNT = 0
@torch.no_grad()
def extract(v, t, x_shape):
    """
    Pick the entries of the 1-D coefficient tensor ``v`` at the indices ``t``
    and return them as float32, shaped [len(t), 1, 1, ...] with as many
    dimensions as ``x_shape`` so they broadcast against a batch of that shape.
    """
    picked = v.gather(0, t).float()
    trailing_ones = (1,) * (len(x_shape) - 1)
    return picked.view((t.shape[0],) + trailing_ones)


def ddim_singlestep(model, FLAGS, x, t_c, t_target, requires_grad=False, device='cuda', encoder_hidden_states=None):
    """Move ``x`` from timestep ``t_c`` to ``t_target`` with one deterministic DDIM step.

    The noise is predicted by ``model`` at ``t_c``; the implied clean sample is
    then re-noised to ``t_target`` using the cumulative alpha products derived
    from the module-level ``scheduler.betas``.

    Args:
        model: callable ``model(x, t, encoder_hidden_states)`` whose result has
            a ``.sample`` attribute holding the predicted noise.
        FLAGS: unused; kept for interface compatibility.
        x: batch of samples at timestep ``t_c``.
        t_c: current timestep (scalar, applied to the whole batch).
        t_target: timestep to move to (scalar, applied to the whole batch).
        requires_grad: when True, autograd is enabled for the whole step.
            (The function used to be wrapped in ``torch.no_grad()``, which
            silently made this flag a no-op.)
        device: device the computation runs on.
        encoder_hidden_states: conditioning passed through to ``model``.

    Returns:
        dict with ``'x_t_target'`` (the sample at ``t_target``) and
        ``'epsilon'`` (the predicted noise).
    """
    if encoder_hidden_states is None:
        print('Error! encoder_hidden_states==None')

    # Explicitly set the grad mode so the flag also wins over an enclosing
    # ``torch.no_grad()`` in the caller.
    with torch.set_grad_enabled(requires_grad):
        x = x.to(device)
        t_c = x.new_ones([x.shape[0], ], dtype=torch.long) * (t_c)
        t_target = x.new_ones([x.shape[0], ], dtype=torch.long) * (t_target)

        # Use the Stable Diffusion scheduler's parameters.
        betas = scheduler.betas.double().to(device)
        alphas = 1. - betas
        alphas = torch.cumprod(alphas, dim=0)
        alphas_t_c = extract(alphas, t=t_c, x_shape=x.shape)
        alphas_t_target = extract(alphas, t=t_target, x_shape=x.shape)

        epsilon = model(x, t_c, encoder_hidden_states).sample

        # Clean-sample estimate implied by the predicted noise ...
        pred_x_0 = (x - ((1 - alphas_t_c).sqrt() * epsilon)) / (alphas_t_c.sqrt())
        # ... re-noised to the target timestep with the same noise.
        x_t_target = alphas_t_target.sqrt() * pred_x_0 \
                     + (1 - alphas_t_target).sqrt() * epsilon

    return {
        'x_t_target': x_t_target,
        'epsilon': epsilon
    }


def ddim_multistep(model, FLAGS, x, t_c, target_steps, clip=False, device='cuda', requires_grad=False,
                   encoder_hidden_states=None):
    """Chain ``ddim_singlestep`` through every timestep in ``target_steps``.

    Starting from ``x`` at timestep ``t_c``, each step moves the sample to the
    next entry of ``target_steps``.

    Args:
        model, FLAGS, device, requires_grad, encoder_hidden_states: forwarded
            unchanged to ``ddim_singlestep``.
        x: batch of samples at timestep ``t_c``.
        t_c: starting timestep.
        target_steps: non-empty iterable of timesteps to visit in order.
        clip: when True, clamp the final sample to [-1, 1].

    Returns:
        The result dict of the last single step (``'x_t_target'``, ``'epsilon'``).

    Raises:
        ValueError: if ``target_steps`` is empty (previously this surfaced as
            an ``UnboundLocalError``).
    """
    result = None
    for t_target in target_steps:
        result = ddim_singlestep(model, FLAGS, x, t_c, t_target, requires_grad=requires_grad, device=device,
                                 encoder_hidden_states=encoder_hidden_states)
        x = result['x_t_target']
        t_c = t_target

    if result is None:
        raise ValueError('target_steps must contain at least one timestep')

    if clip:
        # Same grad mode as the steps, so ``requires_grad`` is honoured here too.
        with torch.set_grad_enabled(requires_grad):
            result['x_t_target'] = torch.clip(result['x_t_target'], -1, 1)

    return result


@torch.no_grad()
def att_measure(diffusion, sample, metric, device='cuda'):
    """Per-sample distance between two batches of tensors.

    Args:
        diffusion: first tensor.  Either batch-first, or 5-D laid out as
            (timesteps, batch, channels, height, width).
        sample: second tensor, same shape as ``diffusion``.
        metric: ``'l2'`` for the sum of squared differences, or an int ``p``
            for the sum of ``|difference| ** p``.
        device: device the computation runs on.

    Returns:
        1-D float tensor with one score per batch element.

    Raises:
        NotImplementedError: for any other ``metric``.
    """
    diffusion = diffusion.to(device).float()
    sample = sample.to(device).float()

    if len(diffusion.shape) == 5:
        # Fold the timestep axis into the channel axis so axis 0 is the batch.
        # The real dimensions are used here; the former hard-coded
        # ``3, 32, 32`` only fit 3x32x32 inputs and failed (or mixed samples)
        # for anything else, e.g. 4-channel latents.
        num_timestep, batch_size, channels, height, width = diffusion.shape
        folded = (batch_size, num_timestep * channels, height, width)
        diffusion = diffusion.permute(1, 0, 2, 3, 4).reshape(folded)
        sample = sample.permute(1, 0, 2, 3, 4).reshape(folded)

    if metric == 'l2':
        score = ((diffusion - sample) ** 2).flatten(1).sum(dim=-1)
    elif isinstance(metric, int):
        score = (torch.abs(diffusion - sample) ** metric).flatten(1).sum(dim=-1)
    else:
        raise NotImplementedError

    return score


@torch.no_grad()
def sec_mi(model, batch, vae, text_encoder, device, x_sec_list, x_sec_recon_list):
    """Compute one batch of inputs for the 'sec' attack score.

    The images are encoded to VAE latents, moved deterministically with DDIM
    from timestep 0 up to ``flags.t_sec`` in strides of ``flags.timestep``,
    pushed one further step forward and then one step back.  The sample at
    ``flags.t_sec`` and its reconstruction are appended to the two output lists.

    Args:
        model: noise-prediction network handed to the DDIM helpers.
        batch: dict with ``"pixel_values"`` and ``"input_ids"``;
            ``"pixel_values"`` is replaced in place by its copy on ``device``.
        vae: autoencoder providing ``encode(...).latent_dist.sample()`` and
            ``config.scaling_factor``.
        text_encoder: callable whose first output is used as conditioning.
        device: device the computation runs on.
        x_sec_list: list receiving the sample at ``flags.t_sec``.
        x_sec_recon_list: list receiving the reconstructed sample.

    Raises:
        ValueError: if the configured schedule does not end at ``flags.t_sec``
            or ``flags.stpsnumi`` leaves no forward/backward step.
    """
    target_steps = list(range(0, flags.t_sec + flags.timestep, flags.timestep))[1:]  # e.g. 10 20 ... 90 100
    starttmp = flags.t_sec
    endtmp = starttmp + flags.stpsnumi

    forw_stps = list(range(starttmp, endtmp + 1))  # e.g. [100, 101]
    back_stps = forw_stps[::-1]                    # e.g. [101, 100]

    # Replaces a hard-coded ``assert forw_stps == [100, 101]``: these are the
    # conditions the computation below actually relies on.
    if not target_steps or target_steps[-1] != starttmp:
        raise ValueError('flags.t_sec must be a positive multiple of flags.timestep')
    if len(forw_stps) < 2:
        raise ValueError('flags.stpsnumi must be at least 1')

    batch["pixel_values"] = batch["pixel_values"].to(device)
    latents = vae.encode(batch["pixel_values"].to(torch.float32)).latent_dist.sample()
    x = latents * vae.config.scaling_factor

    # The prompt embedding is computed once and shared by all DDIM calls
    # (it used to be recomputed from the same ids before the extra steps).
    embd = text_encoder(batch["input_ids"].to(device))[0]

    x_sec = ddim_multistep(model, flags, x, t_c=0, target_steps=target_steps, device=device,
                           encoder_hidden_states=embd)
    x_sec = x_sec['x_t_target']

    x_sec_forw = ddim_singlestep(model, flags, x_sec,
                                 t_c=forw_stps[0], t_target=forw_stps[1],
                                 device=device, encoder_hidden_states=embd)

    x_sec_recon = ddim_singlestep(model, flags, x_sec_forw['x_t_target'],
                                  t_c=back_stps[0], t_target=back_stps[1],
                                  device=device, encoder_hidden_states=embd)
    x_sec_recon = x_sec_recon['x_t_target']

    x_sec_list.append(x_sec)
    x_sec_recon_list.append(x_sec_recon)
    # exit()


@torch.no_grad()
def prox_mi(model, batch, vae, text_encoder, device, x_sec_list, x_sec_recon_list):
    """Compute one batch of inputs for the 'pia' attack score.

    The images are encoded to VAE latents; the model's output at timestep 0 is
    used as the noise to diffuse the latents to timestep 500, and the model's
    output on that noised sample is compared against it later.  The timestep-0
    output is appended to ``x_sec_list`` and the timestep-500 output to
    ``x_sec_recon_list``.  ``batch["pixel_values"]`` is replaced in place by
    its copy on ``device``.
    """
    query_t = 500

    pixels = batch["pixel_values"].to(device)
    batch["pixel_values"] = pixels
    latents = vae.encode(pixels.to(torch.float32)).latent_dist.sample()
    x = latents * vae.config.scaling_factor

    embd_cond = text_encoder(batch["input_ids"].to(device))[0]

    x = x.to(device)
    t = x.new_ones([x.shape[0], ], dtype=torch.long) * (query_t)

    # Cumulative alpha products from the module-level scheduler's betas.
    alphas_bar = torch.cumprod(1. - scheduler.betas.double().to(device), dim=0)
    alphas_t = extract(alphas_bar, t=t, x_shape=x.shape)

    # Model output on the clean latents at timestep 0 ...
    eps = model(x, 0, embd_cond).sample
    # ... and on the latents noised to ``query_t`` with that output.
    noised = alphas_t.sqrt() * x + (1 - alphas_t).sqrt() * eps
    eps_pred = model(noised, t, embd_cond).sample

    x_sec_list.append(eps)
    x_sec_recon_list.append(eps_pred)


def get_asr_acu_tpr1():
    """Evaluate the attack from the two score files listed in ``output_paths``.

    ``output_paths[0]`` holds the training-set scores (label 0) and
    ``output_paths[1]`` the test-set scores (label 1).  Each file has one
    header line followed by one score per line (first tab-separated field).

    Returns:
        Tuple ``(best_accuracy, auc, tpr_at_1_percent_fpr)``: the best accuracy
        over a sweep of 2000 evenly spaced thresholds (a score above the
        threshold predicts label 1), the ROC AUC, and the TPR at the first ROC
        point whose FPR is at least 1%.

    Raises:
        ValueError: if all scores are identical, so no threshold can be swept
            (this used to fail with a division by zero).
    """
    from sklearn.metrics import accuracy_score, roc_auc_score, roc_curve

    print('\n-------- cal asr auc ------------\n')

    def _read_scores(path):
        # Entries recorded by the main loop already start with flags.outdir, so
        # prefixing them again pointed at a non-existent file.  The original
        # ``flags.outdir + path`` form is kept as a fallback for bare names.
        resolved = path if os.path.isfile(path) else flags.outdir + path
        with open(resolved, 'r', encoding='utf8') as fin:
            rows = fin.readlines()[1:]  # skip the header line
        return [float(row.split('\t')[0].strip()) for row in rows if row.strip()]

    train_scores = _read_scores(output_paths[0])
    test_scores = _read_scores(output_paths[1])
    datas = train_scores + test_scores
    labels = [0] * len(train_scores) + [1] * len(test_scores)

    print('len(datas), len(labels):', len(datas), len(labels))

    min_threshold = min(datas)
    max_threshold = max(datas)
    span = max_threshold - min_threshold
    if span == 0:
        raise ValueError('all scores are identical; cannot sweep a decision threshold')
    threshold_step = span / 2000

    best_threshold = None
    best_accuracy = 0.0
    score_array = np.asarray(datas)
    for threshold in np.arange(min_threshold, max_threshold, threshold_step):
        predicted_values = (score_array > threshold).astype(int)
        accuracy = accuracy_score(labels, predicted_values)
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            best_threshold = threshold

    print('\n**************:', output_paths)
    print('name', flags.attack, '|   best_accuracy, best_threshold, th% :', best_accuracy, best_threshold,
          (best_threshold - min_threshold) / span)

    # Min-max normalised scores, shared by the AUC and the ROC curve.
    normalised = [(e - min_threshold) / span for e in datas]

    auc = roc_auc_score(labels, normalised)
    print('name', flags.attack, "|    AUC Score:", auc)

    fpr, tpr, _ = roc_curve(labels, normalised)
    idx_1_percent_fpr = next(i for i, fpr_value in enumerate(fpr) if fpr_value >= 0.01)
    tpr_at_1_percent_fpr = tpr[idx_1_percent_fpr]

    print('name', flags.attack, "|   tpr_at_1_percent_fpr:", tpr_at_1_percent_fpr)

    return best_accuracy, auc, tpr_at_1_percent_fpr


####
def _dataset_key(task_name):
    """Map a task name to the key used in the train/test dataset tables."""
    # (substring looked for in the task name, prefix of the table key)
    families = (('coco', 'coco'), ('flickr', 'flickr'), ('pok', 'pokemon'))
    for marker, family in families:
        if marker in task_name:
            for variant in ('ori', 'split1'):
                if variant in task_name:
                    return family + '_' + variant
            raise ValueError(
                "task name {!r} matches dataset family {!r} but names neither 'ori' nor 'split1'".format(
                    task_name, family))
    # Any other task is looked up under its own name.
    return task_name


def init_paths(task_name):
    """Resolve the model path and the train/test dataset names for a task.

    Args:
        task_name: key of the model table.  Names containing ``coco``,
            ``flickr`` or ``pok`` select the ``<family>_ori`` or
            ``<family>_split1`` dataset entry; every other name is looked up
            in the dataset tables directly.

    Returns:
        Tuple ``(diff_path, dataset_train_name, dataset_test_name)``.

    Raises:
        KeyError: if ``task_name`` (or the derived dataset key) is not in the
            corresponding table.
        ValueError: if a known dataset family is named without ``ori`` or
            ``split1``.
    """
    diff_path = {
        "xx":"xx"

    }[task_name]

    train_data_dict = {
        "xx":"xx"

    }

    test_data_dict = {
        "xx":"xx"

    }

    # The fallback lookup previously read the global loop variable
    # ``Use_data_model_name`` instead of the ``task_name`` argument, so the
    # function only worked when called from the main loop.
    key = _dataset_key(task_name)
    dataset_train_name = train_data_dict[key]
    dataset_test_name = test_data_dict[key]

    return diff_path, dataset_train_name, dataset_test_name


# Script entry point: for every (task, attack) pair, score the first samples of
# the task's train and test datasets and write one score file per dataset.
if __name__ == '__main__':
    # Inference only.  (A bare ``torch.no_grad()`` statement stood here before;
    # it built a context manager and discarded it, so it had no effect.)
    torch.set_grad_enabled(False)

    flags.T = 1000
    flags.train_batch_size = 16
    flags.dataloader_num_workers = 0
    flags.resolution = 512
    flags.image_column = "image"
    flags.caption_column = "text"
    flags.t_sec = 100
    flags.timestep = 10
    flags.stpsnumi = 1
    flags.outdir = 'outputs'

    # The score files are opened for writing below; make sure the folder exists.
    os.makedirs(flags.outdir, exist_ok=True)

    # Example task names:
    #   pok_real_split1
    #   pok_real_ori
    #
    #   pok_overfit_split1
    #   pok_overfit_ori

    # ``Use_data_model_name`` and ``output_paths`` are module-level names that
    # other functions in this file read; do not rename them.
    # The two lists are zipped, so only as many runs happen as there are
    # (uncommented) task names.
    for Use_data_model_name, Attak in zip(
            [
                # "pok_real_split1", "pok_real_ori", "pok_overfit_split1", "pok_overfit_ori",
                # "pok_real_split1", "pok_real_ori", "pok_overfit_split1", "pok_overfit_ori"

                # "flickr_overfit_ori", "flickr_overfit_split1", "coco_overfit_ori", "coco_overfit_split1", "coco_real_ori", "coco_real_split1",
                # "flickr_overfit_ori", "flickr_overfit_split1", "coco_overfit_ori", "coco_overfit_split1",  "coco_real_ori", "coco_real_split1",

                # "coco_savelist_5000_ori", "coco_savelist_5000_split1",  "coco_savelist_25000_ori", "coco_savelist_25000_split1", "coco_savelist_125000_ori", "coco_savelist_125000_split1", "coco_savelist_150000_ori",
                #                 "coco_savelist_150000_split1",
                # "coco_savelist_5000_ori", "coco_savelist_5000_split1",  "coco_savelist_25000_ori", "coco_savelist_25000_split1", "coco_savelist_75000_ori", "coco_savelist_75000_split1", "coco_savelist_100000_ori", "coco_savelist_100000_split1","coco_savelist_125000_ori", "coco_savelist_125000_split1", "coco_savelist_150000_ori", "coco_savelist_150000_split1"
                # "my_wacv", "my_wacv",
                # "laion_cc", "laion_cc"
                # "coco_norand_50k_ori", "coco_norand_50k_split1",
                # "coco_norand_50k_ori", "coco_norand_50k_split1",

                # "flick_real_ori",
                # "flick_real_split1",
                # "flick_real_ori", "flick_real_split1",
            ],
            [
                "sec",
                "sec",
                "pia",
                "pia",
            ]
    ):
        print('\n====================================================================================')
        flags.attack = Attak
        flags.task_name = Use_data_model_name
        assert flags.attack in ['sec', 'noise', 'pia', ]
        print('flags.attack, flags.task_name:', flags.attack, flags.task_name)

        flags.diff_path, flags.dataset_train_name, flags.dataset_test_name = init_paths(flags.task_name)
        init_model(flags.diff_path)

        # Dataset tag used in the output file name.
        Template_name = flags.dataset_train_name.split('/')[-1].replace('train', '').replace('test', '')
        Time = timestart

        print(str(
            flags.__dict__) + '\n' + flags.diff_path + '\n' + flags.attack + '\n' + Template_name + '-------' + '\n')

        output_paths = []
        for trainOrtest, data_name in (("train", flags.dataset_train_name), ("test", flags.dataset_test_name)):
            x_sec_list = []
            x_sec_recon_list = []

            loader = get_data(flags, data_name)

            print("*** trainOrtest ***  ", trainOrtest)

            for step, batch in enumerate(tqdm(loader)):
                if flags.attack == 'sec':
                    sec_mi(unet, batch, vae, text_encoder, device, x_sec_list, x_sec_recon_list)
                elif flags.attack == 'pia':
                    prox_mi(unet, batch, vae, text_encoder, device, x_sec_list, x_sec_recon_list)
                else:
                    # 'noise' passes the assertion above but has no implementation.
                    print('Error, No implement!', flags.attack)
                    sys.exit(1)

            x_sec_s = torch.concat(x_sec_list)  # reference tensors, all batches
            x_sec_recon_s = torch.concat(x_sec_recon_list)  # tensors compared against them

            # NOTE(review): flags.attack can only be 'sec', 'noise' or 'pia', so
            # the 'prox' comparison never matches and the norm is always 'l2'.
            # Confirm whether 'pia' was meant before changing it.
            norm = 5 if flags.attack == 'prox' else 'l2'
            print('****   norm {} *****'.format(norm))
            scores = att_measure(x_sec_s, x_sec_recon_s, norm, device=device).cpu()
            scores = scores.numpy().tolist()

            print(scores[:3])

            path_temp = flags.outdir + '/Atk_{}_M_{}_DATA_{}_TRTE_{}_T_{}.txt'.format(flags.attack, flags.task_name,
                                                                                      Template_name,
                                                                                      trainOrtest, Time)
            # One header line (settings and model path), then one score per line.
            with open(path_temp, 'w', encoding='utf8') as f:
                f.write(str(flags.__dict__) + '\t' + flags.diff_path + '\t' + '\n')
                f.write('\n'.join(['{:.5g}'.format(e) for e in scores]))

            output_paths.append(path_temp)

