import numpy as np
import random
import torch
import components
from typing import Type, Dict
from model import UNet
from dataset_utils import load_member_data
from torchmetrics.classification import BinaryAUROC, BinaryROC
from tqdm import tqdm


def set_seeds(seed):
    """Seed every RNG this script relies on so runs are reproducible.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch's CPU RNG. When CUDA
    is available it also seeds the CUDA RNGs and switches cuDNN to deterministic,
    non-benchmarking mode.

    Args:
        seed: Integer seed applied to all generators.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for bit-for-bit repeatable kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def get_FLAGS():
    """Return the UNet / diffusion hyper-parameters as attributes on a callable.

    The returned object is an identity function whose attributes (``T``, ``ch``,
    ``ch_mult``, ``attn``, ``num_res_blocks``, ``dropout``, ``beta_1``, ``beta_T``)
    hold the configuration. A fresh object with fresh list values is built on
    every call, so callers may mutate it freely.
    """
    def FLAGS(x):
        return x

    settings = {
        'T': 1000,
        'ch': 128,
        'ch_mult': [1, 2, 2, 2],
        'attn': [1],
        'num_res_blocks': 2,
        'dropout': 0.1,
        'beta_1': 0.0001,
        'beta_T': 0.02,
    }
    for attr_name, attr_value in settings.items():
        setattr(FLAGS, attr_name, attr_value)

    return FLAGS


def get_model(ckpt, WA=True, map_location='cpu'):
    """Build the UNet described by ``get_FLAGS`` and load checkpoint weights into it.

    Args:
        ckpt: Path (or file-like object) accepted by ``torch.load``.
        WA: If True load the ``'ema_model'`` weights, otherwise ``'net_model'``.
        map_location: Forwarded to ``torch.load``. Defaults to ``'cpu'``: the UNet is
            constructed on the CPU and ``load_state_dict`` copies into it anyway, so
            mapping to CPU avoids a transient GPU copy and lets checkpoints saved
            on CUDA load on CPU-only machines. Callers move the model afterwards.

    Returns:
        The UNet with weights loaded, in eval mode.
    """
    FLAGS = get_FLAGS()
    model = UNet(
        T=FLAGS.T, ch=FLAGS.ch, ch_mult=FLAGS.ch_mult, attn=FLAGS.attn,
        num_res_blocks=FLAGS.num_res_blocks, dropout=FLAGS.dropout)

    checkpoint = torch.load(ckpt, map_location=map_location)
    weights = checkpoint['ema_model' if WA else 'net_model']

    # Strip the 'module.' prefix that key names carry when the weights were saved
    # from a wrapped (e.g. DataParallel) model.
    prefix = 'module.'
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): val
        for key, val in weights.items()
    }

    model.load_state_dict(state_dict)
    model.eval()
    return model


class EpsGetter(components.EpsGetter):
    """Adapter that calls ``self.model`` with a per-sample timestep tensor.

    ``condition`` and ``noise_level`` are accepted for interface compatibility with
    ``components.EpsGetter`` but are not used here.
    """

    def __call__(self, xt: torch.Tensor, condition: torch.Tensor = None, noise_level=None, t: int = None) -> torch.Tensor:
        # Repeat the scalar timestep once per sample in the batch, as a long tensor on xt's device.
        timesteps = xt.new_ones(xt.shape[0], dtype=torch.long) * t
        return self.model(xt, t=timesteps)


# Registry of attacker classes, looked up by the ``attacker_name`` argument of DDIM_Attack.
attackers: Dict[str, Type[components.DDIMAttacker]] = dict(
    sec=components.SecMIAttacker,
    pia=components.PIA,
    naive=components.NaiveAttacker,
)

# Device that DDIM_Attack moves the model, the beta schedule and data batches to.
DEVICE = 'cuda'

# Dataset names accepted by DDIM_Attack (forwarded unchanged to load_member_data).
_SUPPORTED_DATASETS = ('TINY-IN', 'CIFAR100', 'STL10-U')


def _collect_scores(attacker, member_loader, nonmember_loader, device):
    """Run the attacker on paired member/non-member batches; concatenate scores on the last dim."""
    member_batches, nonmember_batches = [], []
    total = min(len(member_loader), len(nonmember_loader))  # zip stops at the shorter loader
    for member_batch, nonmember_batch in tqdm(zip(member_loader, nonmember_loader), total=total):
        # Member batch is attacked first, as before, so RNG consumption order is unchanged.
        member_batches.append(attacker(member_batch[0].to(device)))
        nonmember_batches.append(attacker(nonmember_batch[0].to(device)))
    if not member_batches:
        raise ValueError("member/non-member loaders produced no batches")
    # Concatenate once here instead of re-concatenating on every iteration (quadratic copying).
    return torch.cat(member_batches, dim=-1), torch.cat(nonmember_batches, dim=-1)


def _scale_pair(member_row, nonmember_row):
    """Divide both score rows by their joint maximum (a shared positive scale keeps the ranking)."""
    scale = max(member_row.max().item(), nonmember_row.max().item())
    return member_row / scale, nonmember_row / scale


def _auroc(member_row, nonmember_row):
    """AUROC with members labelled 0 and non-members 1, i.e. a higher score means non-member."""
    mem, non = _scale_pair(member_row, nonmember_row)
    preds = torch.cat([mem, non])
    target = torch.cat([
        torch.zeros(mem.shape[0], dtype=torch.long),
        torch.ones(non.shape[0], dtype=torch.long),
    ]).to(preds.device)
    return BinaryAUROC().to(preds.device)(preds, target).item()


def _tpr_at_fpr(member_row, nonmember_row, max_fpr=0.01):
    """TPR at the last ROC point with FPR < max_fpr; members are positives (lower score => member)."""
    mem, non = _scale_pair(member_row, nonmember_row)
    preds = torch.cat([1 - non, 1 - mem])
    # Labels follow the concatenation order above: non-members (0) first, then members (1).
    target = torch.cat([
        torch.zeros(non.shape[0], dtype=torch.long),
        torch.ones(mem.shape[0], dtype=torch.long),
    ]).to(preds.device)
    fpr, tpr, _ = BinaryROC().to(preds.device)(preds, target)
    # Assumes fpr is ascending and starts at 0 (torchmetrics ROC convention), so this
    # indexes the last point whose FPR is below max_fpr.
    return tpr[(fpr < max_fpr).sum() - 1].item()


def _best_threshold_asr(member_row, nonmember_row, num_steps=2000):
    """Best accuracy over a threshold sweep, predicting 'member' when score <= threshold."""
    min_score = min(member_row.min().item(), nonmember_row.min().item())
    max_score = max(member_row.max().item(), nonmember_row.max().item())
    if max_score > min_score:
        thresholds = torch.arange(min_score, max_score, (max_score - min_score) / num_steps)
    else:
        # All scores are equal: torch.arange rejects a zero step, and one threshold suffices.
        thresholds = torch.tensor([min_score])

    def count_at_or_below(scores):
        # searchsorted(..., right=True) == number of sorted scores <= each threshold.
        # Thresholds are cast to the score dtype, matching how the element-wise
        # comparison promoted them (scores assumed floating point).
        ordered = torch.sort(scores).values
        return torch.searchsorted(
            ordered, thresholds.to(device=scores.device, dtype=scores.dtype), right=True)

    true_pos = count_at_or_below(member_row)
    true_neg = nonmember_row.numel() - count_at_or_below(nonmember_row)
    return (true_pos + true_neg).max().item() / (member_row.numel() + nonmember_row.numel())


@torch.no_grad()
def DDIM_Attack(checkpoint='experiments/CIFAR100/checkpoint.pt',
                dataset='CIFAR100',
                attacker_name="naive",
                Filter=0,
                t=5, s=0.2,
                attack_num=1, interval=100,  # sec : 10 10
                seed=0):
    """Run a membership-inference attack on a DDIM checkpoint and report AUC, TPR@1%FPR and ASR.

    Members are taken from the ``train_loader`` and non-members from the ``test_loader``
    returned by ``load_member_data``. Batches are paired with ``zip``, so the shorter
    loader bounds how many samples are scored. The attacker output is concatenated
    along its last dim; each leading-dim row is evaluated independently (presumably
    one row per attack configuration/timestep — TODO confirm against ``components``).

    Args:
        checkpoint: Checkpoint path; EMA weights are loaded (``WA=True``).
        dataset: One of ``_SUPPORTED_DATASETS``.
        attacker_name: Key into ``attackers``.
        Filter, t, s: Forwarded as keyword arguments to the attacker constructor.
        attack_num, interval: Forwarded positionally to the attacker constructor.
        seed: Passed to ``set_seeds`` before the model is built.

    Returns:
        dict with keys ``'auc'``, ``'tpr@1%fpr'`` and ``'asr'``, each a list holding one
        float per score row.

    Raises:
        ValueError: If ``dataset`` is unsupported or the loaders yield no batches.
        KeyError: If ``attacker_name`` is not registered in ``attackers``.
    """
    # Validate before any expensive work (previously an unknown dataset surfaced as an
    # UnboundLocalError after the model had already been loaded).
    if dataset not in _SUPPORTED_DATASETS:
        raise ValueError(f"unsupported dataset {dataset!r}; expected one of {_SUPPORTED_DATASETS}")

    set_seeds(seed)
    FLAGS = get_FLAGS()

    print("loading model...")
    model = get_model(checkpoint, WA=True).to(DEVICE)  # get_model already sets eval mode

    print("loading dataset...")
    _, _, train_loader, test_loader = load_member_data(dataset_name=dataset, batch_size=64,
                                                       shuffle=False, randaugment=False)

    betas = torch.from_numpy(np.linspace(FLAGS.beta_1, FLAGS.beta_T, FLAGS.T)).to(DEVICE)
    attacker = attackers[attacker_name](
        betas, interval, attack_num, EpsGetter(model), lambda x: x * 2 - 1,
        Filter=Filter, t=t, s=s)

    print("attack start...")
    member, nonmember = _collect_scores(attacker, train_loader, test_loader, DEVICE)
    rows = range(member.shape[0])

    auroc = [_auroc(member[i], nonmember[i]) for i in rows]
    tpr_fpr_1 = [_tpr_at_fpr(member[i], nonmember[i]) for i in rows]
    print('auc', auroc)
    print('tpr @ 1% fpr', tpr_fpr_1)

    asr_list = [_best_threshold_asr(member[i, :], nonmember[i, :]) for i in rows]
    print("ASR list:", asr_list)

    return {'auc': auroc, 'tpr@1%fpr': tpr_fpr_1, 'asr': asr_list}



if __name__ == '__main__':
    # Run the naive attacker with Filter=0 and then Filter=1 (same t and s).
    for filter_flag in (0, 1):
        DDIM_Attack(Filter=filter_flag, t=5, s=0.2, attacker_name="naive")
    # The other registered attackers can be run the same way, e.g.:
    # for filter_flag in (0, 1):
    #     DDIM_Attack(Filter=filter_flag, t=5, s=0.2, attacker_name="pia")
    # for filter_flag in (0, 1):
    #     DDIM_Attack(Filter=filter_flag, t=5, s=0.2, attacker_name="sec")
