import numpy as np
import torch
import torch.nn.functional as F
import math
import argparse
from pathlib import Path
import sklearn.metrics as metrics
import os
import pandas as pd
import pickle
import re
import scipy

def parse_args():
    """Parse the command-line options of the attack script.

    Returns:
        argparse.Namespace with ``init_checkpoint``, ``reference_checkpoint``
        and ``pretrained_losses_folder`` (each ``None`` unless given) and
        ``output_dir`` (``"results"`` unless given).
    """
    parser = argparse.ArgumentParser(
        description='Score membership inference attacks from saved loss arrays',
    )
    parser.add_argument(
        '--init_checkpoint', default=None, type=str,
        help='directory of the attacked model; holds losses_train_tokens.npy, '
             'losses_val_tokens.npy and losses_z_tokens.npy, and its name '
             'encodes the run settings',
    )
    parser.add_argument(
        '--reference_checkpoint', default=None, type=str,
        help='directory of the reference model holding the same three loss files',
    )
    parser.add_argument(
        '--pretrained_losses_folder', default=None, type=str,
        help='root folder with one <dataset>/<model_name>/ sub-directory of '
             'loss files per pretrained reference model',
    )
    parser.add_argument(
        '--output_dir', type=str, default="results",
        help='directory that receives mia_results.pkl and mia_scores.pkl',
    )

    return parser.parse_args()

def ref_attack(loss, loss_ref):
    """Return the negated ratio of the target mean loss to the reference mean loss.

    Both ``loss`` and ``loss_ref`` are averaged over their last axis before
    the division, so the result has one entry per sample.
    """
    target_mean = loss.mean(-1)
    reference_mean = loss_ref.mean(-1)
    return -(target_mean / reference_mean)

def _load_loss_array(directory, filename, limit=None):
    """Load ``directory/filename`` with NumPy and return it as a float64 tensor.

    When ``limit`` is given, only the first ``limit`` rows are kept.
    """
    values = torch.tensor(np.load(Path(directory) / filename)).double()
    return values if limit is None else values[:limit]


def get_pretrained_losses(path, info):
    """Collect the loss arrays of every pretrained model stored for a dataset.

    Args:
        path: Root folder; the models are looked up in ``<path>/<dataset>/``.
            ``None`` (no folder configured) yields an empty result.
        info: Mapping that provides at least ``'dataset'`` and ``'shadow_id'``.

    Returns:
        Dict mapping each model sub-directory name to
        ``{'base': tensor, 'z': tensor}``. ``'base'`` is the concatenation of
        the first 2000 rows of two loss files, and ``'z'`` is a seeded random
        subset of at most 500 rows of ``losses_z_tokens.npy``. The dict is
        empty when ``path`` is ``None`` or the dataset folder does not exist.
    """
    if path is None:
        # The CLI option feeding this argument defaults to None; string
        # concatenation on None would raise a TypeError.
        return {}

    dataset_dir = f"{path}/{info['dataset']}/"
    try:
        model_names = os.listdir(dataset_dir)
    except FileNotFoundError:
        print(f"Path {dataset_dir} not found")
        return {}

    losses_per_model = {}
    for model_name in model_names:
        model_dir = Path(dataset_dir) / model_name
        if not model_dir.is_dir():
            # Stray files next to the model folders hold no loss arrays.
            continue

        train_losses = _load_loss_array(model_dir, "losses_train_tokens.npy", 2_000)
        val_losses = _load_loss_array(model_dir, "losses_val_tokens.npy", 2_000)

        # NOTE(review): get_checkpoint_info in this file yields string values,
        # so for such `info` this comparison with the int 0 is never true and
        # the else branch is always taken. Confirm the intended behaviour
        # before changing it.
        if info["shadow_id"] == 0:
            members_loss_pret_ref, nonmembers_loss_pret_ref = train_losses, val_losses
        else:
            members_loss_pret_ref, nonmembers_loss_pret_ref = val_losses, train_losses

        z_losses = _load_loss_array(model_dir, "losses_z_tokens.npy")

        # Re-seeding for every model makes the selected row indices depend
        # only on the number of rows in the z file.
        torch.manual_seed(42)
        ids = torch.randperm(len(z_losses))[:500]

        losses_per_model[model_name] = {
            # NOTE(review): non-members come first here, whereas main()
            # concatenates members first. Verify the pairing is intended.
            'base': torch.cat([nonmembers_loss_pret_ref, members_loss_pret_ref], dim=0),
            'z': z_losses[ids],
        }

    return losses_per_model

def pretrained_ref_attack(loss, loss_ref):
    """Return minus the per-sample mean loss divided by the reference mean loss.

    ``loss_ref`` holds the losses of the reference model; both arguments are
    reduced with a mean over the last axis.
    """
    per_sample = loss.mean(-1)
    per_sample_ref = loss_ref.mean(-1)
    ratio = per_sample / per_sample_ref
    return -ratio

def diff_attack(loss, loss_ref):
    """Return the negated difference between target and reference mean loss.

    The previous body was a copy of the ratio used by ``ref_attack``, so the
    score stored under 'diff' was identical to the one stored under 'ref'.
    Here the reference mean loss is subtracted instead of divided out. The
    sign convention is unchanged: a lower target loss relative to the
    reference gives a higher score.

    Args:
        loss: Losses of the target model; averaged over the last axis.
        loss_ref: Losses of the reference model; averaged over the last axis.

    Returns:
        One score per sample: ``loss_ref.mean(-1) - loss.mean(-1)``.
    """
    return -(loss.mean(-1) - loss_ref.mean(-1))

def loss_attack(loss):
    """Return the negated mean of ``loss`` over its last axis."""
    mean_loss = loss.mean(-1)
    return -mean_loss

def mink_attack(loss, ratio):
    """Average the smallest ``ratio`` fraction of the negated losses per sample.

    The negated losses are sorted in ascending order along the last axis and
    the first ``int(length * ratio)`` of them are averaged, i.e. the entries
    with the highest loss.

    Args:
        loss: Tensor whose last axis is the one being selected from; any
            number of leading axes is accepted.
        ratio: Fraction of the last axis to keep.

    Returns:
        Tensor with the last axis reduced away.
    """
    # Keep at least one element: when int(length * ratio) is 0 the slice
    # would be empty and its mean NaN.
    keep = max(1, int(loss.shape[-1] * ratio))
    ascending = (-loss).sort(-1).values
    # The ellipsis allows inputs that are not exactly two-dimensional.
    return ascending[..., :keep].mean(-1)

def rmia(loss, loss_ref, z_loss, z_loss_ref, a, gamma):
    """Score each sample by the fraction of ``z`` rows it beats by ``log(gamma)``.

    Each loss tensor is summed over its last axis and negated. For the
    samples, the reference term is used as is when ``a == 0``; otherwise it
    is replaced by ``logsumexp([ref + log(1 + a), log(1 - a)]) - log(2)``.
    The score of a sample is the fraction of ``z`` rows whose
    (target - reference) value is more than ``log(gamma)`` below that of the
    sample.

    Args:
        loss: Target-model losses of the samples.
        loss_ref: Reference-model losses of the samples.
        z_loss: Target-model losses of the ``z`` rows.
        z_loss_ref: Reference-model losses of the ``z`` rows.
        a: Mixing parameter of the reference term; 0 disables the mixing.
        gamma: Threshold applied in log space.

    Returns:
        Float tensor with one score in [0, 1] per sample.
    """
    target_lp = -loss.sum(-1)
    reference_lp = -loss_ref.sum(-1)
    z_target_lp = -z_loss.sum(-1)
    z_reference_lp = -z_loss_ref.sum(-1)

    if a == 0:
        denominator_lp = reference_lp
    else:
        scaled_reference = reference_lp + math.log(1+a)
        constant_term = torch.full(reference_lp.shape, math.log(1-a), dtype=reference_lp.dtype)
        stacked = torch.stack([scaled_reference, constant_term], dim=0)
        denominator_lp = torch.logsumexp(stacked, dim=0) - math.log(2)

    sample_ratio = target_lp - denominator_lp
    z_ratio = z_target_lp - z_reference_lp

    # Rows index the z entries, columns index the samples.
    margin = sample_ratio.unsqueeze(0) - z_ratio.unsqueeze(1)
    wins = (margin > math.log(gamma)).float()
    return wins.mean(0)

def rmia_pop(loss, loss_ref, z_loss, z_loss_ref, a, gamma):
    """Run ``rmia`` with the ``z`` rows drawn from the evaluated samples.

    ``z_loss`` and ``z_loss_ref`` are accepted but not used: up to 500 rows of
    ``loss`` / ``loss_ref`` are picked with a fixed seed and passed to
    ``rmia`` in their place. Note that this re-seeds torch's global RNG.
    """
    torch.manual_seed(42)
    population = torch.randperm(len(loss))[:500]
    population_loss = loss[population]
    population_loss_ref = loss_ref[population]

    return rmia(loss, loss_ref, population_loss, population_loss_ref, a, gamma)

def score_attack(score, labels):
    """Summarise an attack's scores as AUC and TPR at fixed FPR budgets.

    Args:
        score: Per-sample attack scores, passed to scikit-learn as
            ``y_score``.
        labels: Per-sample ground truth, passed to scikit-learn as ``y_true``.

    Returns:
        Dict with the keys ``'AUC'``, ``'TPR @ 1% FPR'`` and
        ``'TPR @ 5% FPR'``.
    """
    fpr, tpr, _ = metrics.roc_curve(labels, score)

    def tpr_within(max_fpr):
        # Report the last ROC point whose FPR does not exceed the budget.
        # Picking the point *nearest* to the budget can select one above it
        # and overstate the TPR.
        last_allowed = np.searchsorted(fpr, max_fpr, side="right") - 1
        return tpr[max(last_allowed, 0)]

    return {
        'AUC': metrics.roc_auc_score(labels, score),
        'TPR @ 1% FPR': tpr_within(0.01),
        'TPR @ 5% FPR': tpr_within(0.05),
    }

# Matches exponent notation such as "5e-05" so that its "-" is not mistaken
# for a field separator.
reg = re.compile(r"(\d+)e-(\d+)")
def get_checkpoint_info(checkpoint):
    """Decode the run settings from the last component of a checkpoint path.

    The component is expected to consist of nine "-"-separated fields, in
    this order: dataset, setting, lr, epoch, batch_size, target_epsilon,
    prefix_length, prefix_type, shadow_id.

    Args:
        checkpoint: Path string using "/" separators; trailing "/" characters
            are ignored.

    Returns:
        Dict mapping each field name to its string value.

    Raises:
        IndexError: If the name has fewer than nine fields.
    """
    # Without stripping, a trailing "/" would make the last component empty.
    basename = checkpoint.rstrip("/").split("/")[-1]
    # Temporarily protect the "-" inside exponents while splitting.
    fields = reg.sub(r"\1e~\2", basename).split('-')
    keys = ['dataset', 'setting', 'lr', 'epoch', 'batch_size', 'target_epsilon', 'prefix_length', 'prefix_type', 'shadow_id']
    return {key: fields[i].replace('~', '-') for i, key in enumerate(keys)}


def main():
    """Run every attack on the saved losses and store metrics and raw scores.

    Loads ``losses_train_tokens.npy``, ``losses_val_tokens.npy`` and
    ``losses_z_tokens.npy`` from ``--init_checkpoint`` and
    ``--reference_checkpoint``, plus any pretrained-model losses returned by
    ``get_pretrained_losses``. Each attack score is evaluated with
    ``score_attack``. The metrics table is written to ``mia_results.pkl`` and
    the raw scores (with the labels) to ``mia_scores.pkl`` inside
    ``--output_dir``.
    """
    args = parse_args()
    print(args)
    os.makedirs(args.output_dir, exist_ok=True)
    info = get_checkpoint_info(args.init_checkpoint)
    print(info)

    def load_losses(checkpoint, filename):
        # Every loss file is promoted to float64 before any arithmetic.
        return torch.tensor(np.load(Path(checkpoint) / filename)).double()

    members_loss = load_losses(args.init_checkpoint, "losses_train_tokens.npy")[:2_000]
    nonmembers_loss = load_losses(args.init_checkpoint, "losses_val_tokens.npy")[:2_000]
    z_loss = load_losses(args.init_checkpoint, "losses_z_tokens.npy")

    # For the reference checkpoint the roles of the two files are swapped:
    # its train file is paired with the non-members and its val file with the
    # members. NOTE(review): presumably the reference model was trained on
    # the complementary split. Confirm.
    nonmembers_loss_ref = load_losses(args.reference_checkpoint, "losses_train_tokens.npy")[:2_000]
    members_loss_ref = load_losses(args.reference_checkpoint, "losses_val_tokens.npy")[:2_000]
    z_loss_ref = load_losses(args.reference_checkpoint, "losses_z_tokens.npy")

    pretraining_losses = get_pretrained_losses(args.pretrained_losses_folder, info)

    # Select the same (seeded) rows from both z arrays.
    torch.manual_seed(42)
    ids = torch.randperm(len(z_loss_ref))[:500]
    z_loss_ref = z_loss_ref[ids]
    z_loss = z_loss[ids]

    print(members_loss.shape, nonmembers_loss.shape, z_loss.shape)
    print(members_loss_ref.shape, nonmembers_loss_ref.shape, z_loss_ref.shape)

    assert len(members_loss) == len(members_loss_ref), f"{len(members_loss)} != {len(members_loss_ref)}"
    assert len(nonmembers_loss) == len(nonmembers_loss_ref), f"{len(nonmembers_loss)} != {len(nonmembers_loss_ref)}"
    assert len(z_loss) == len(z_loss_ref), f"{len(z_loss)} != {len(z_loss_ref)}"

    # Members are labelled 1 and come first; non-members are labelled 0.
    labels = torch.cat([torch.ones(len(members_loss)), torch.zeros(len(nonmembers_loss))])
    loss = torch.cat([members_loss, nonmembers_loss], dim=0)
    loss_ref = torch.cat([members_loss_ref, nonmembers_loss_ref], dim=0)

    for pre_model, pre_losses in pretraining_losses.items():
        assert len(loss_ref) == len(pre_losses['base']), f"[{pre_model}]: {len(loss_ref)} != {len(pre_losses['base'])}"
        assert len(z_loss) == len(pre_losses['z']), f"[{pre_model}]: {len(z_loss)} != {len(pre_losses['z'])}"

    print(loss.shape, loss_ref.shape, labels.shape)
    scores = {}

    scores['loss'] = loss_attack(loss)
    scores['ref'] = ref_attack(loss, loss_ref)
    for pre_model, pre_losses in pretraining_losses.items():
        scores[f'ref_{pre_model}'] = pretrained_ref_attack(loss, pre_losses['base'])
    scores['diff'] = diff_attack(loss, loss_ref)
    for ratio in [0.25, 0.5, 1.0]:
        scores[f'mink_{ratio}'] = mink_attack(loss, ratio)

    for a in [0.0, 0.1]:
        for gamma in [1.0, 2.0, 3.0, 4.0]:
            for pre_model, pre_losses in pretraining_losses.items():
                scores[f'rmia_pop_{pre_model}_{a}_{gamma}'] = rmia_pop(loss, pre_losses['base'], z_loss, pre_losses['z'], a, gamma)
            scores[f'rmia_pop_{a}_{gamma}'] = rmia_pop(loss, loss_ref, z_loss, z_loss_ref, a, gamma)

    # Collect the metrics column-wise: one list per metric plus the attack name.
    df = {}
    for name, score in scores.items():
        d = score_attack(score, labels)
        d['score'] = name
        for k, v in d.items():
            df.setdefault(k, []).append(v)

    output_dir = Path(args.output_dir)
    df = pd.DataFrame(df).sort_values('AUC', ascending=False)
    pd.to_pickle(df, output_dir / "mia_results.pkl")
    print(df)

    scores['labels'] = labels
    # The context manager closes the file even if pickling fails.
    with open(output_dir / "mia_scores.pkl", "wb") as handle:
        pickle.dump(scores, handle)
