import argparse
import os
import torch
import yaml
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import audio as Audio
from utils.model import get_model, get_param_num, set_noise_schedule
from utils.tools import to_device, log, synth_one_sample
from dataset import Dataset
from evaluate import evaluate
from itertools import chain
from rich.progress import track
from pytorch_lightning import seed_everything
import json
import numpy as np
import matplotlib.pyplot as plt
from torchmetrics.classification import AUROC, ROC
import pdb

from model import WaveGrad2
from model import WaveGrad2Loss
from model.modules import (
    TextEncoder,
    DurationPredictor,
    RangeParameterPredictor,
    GaussianUpsampling,
    SamplingWindow,
)
from wavegrad import WaveGrad
from utils.tools import get_mask_from_lengths
import time

# Pick the compute device once for the whole script: CUDA when a GPU is
# visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("GPU: ", device)


def recon_score(
        wg_model,
        Loss,
        attack_num,
        interval,
        batch,
        ):
    """Score one DataLoader item by the model's duration loss.

    ``batch`` is whatever the loader yields for one step. It is iterated as
    a sequence of sub-batches; presumably ``Dataset.collate_fn`` returns a
    list of them (TODO confirm against ``dataset.Dataset``). Each sub-batch
    is processed in three steps:

    1. It is moved to the module-level ``device``.
    2. It is fed to ``wg_model`` using its elements from index 2 onward.
    3. It is scored with ``Loss``, and element 2 of the returned loss
       sequence is taken as the duration loss.

    The per-sub-batch duration losses are averaged along the stacking axis.
    When the loader yields a single sub-batch, the result is therefore
    exactly that sub-batch's duration loss.

    Args:
        wg_model: the WaveGrad2 model under attack, called as
            ``wg_model(*m_batch[2:])``.
        Loss: loss module called as ``Loss(m_batch, output)``.
        attack_num: accepted for interface compatibility; not used here.
        interval: accepted for interface compatibility; not used here.
        batch: iterable of sub-batches for one loader step.

    Returns:
        A tensor with the same shape as a single duration loss.

    Raises:
        ValueError: if ``batch`` yields no sub-batches.
    """
    duration_losses = []
    for m_batch in batch:
        m_batch = to_device(m_batch, device)
        output = wg_model(*(m_batch[2:]))
        losses = Loss(m_batch, output)
        # Index 2 of the loss sequence is the duration term (per the original
        # naming) — TODO confirm against WaveGrad2Loss's return order.
        duration_losses.append(losses[2])

    if not duration_losses:
        # Returning None here would only surface later as an obscure
        # AttributeError in the caller's `.unsqueeze(0)`.
        raise ValueError("recon_score received a batch with no sub-batches")

    # mean(dim=0) over the stacking axis keeps each loss's own shape, so the
    # caller's `.unsqueeze(0)` / `.item()` handling is unchanged.
    return torch.stack(duration_losses).mean(dim=0)

def _tpr_at_fpr(fpr, tpr, max_fpr=0.01):
    """Return the TPR of the last ROC point whose FPR is <= ``max_fpr``.

    ``fpr`` and ``tpr`` are the tensors produced by torchmetrics ``ROC``.
    ``fpr`` must be sorted ascending, as ``np.searchsorted`` requires. If no
    point qualifies, the first ROC point's TPR is returned.
    """
    fpr_np = fpr.detach().cpu().numpy()
    tpr_np = tpr.detach().cpu().numpy()
    idx = np.searchsorted(fpr_np, max_fpr, side="right")
    return float(tpr_np[max(idx - 1, 0)])


def _save_score_histogram(member_scores, nonmember_scores, path):
    """Overlay 50-bin histograms of member / hold-out scores and save to ``path``."""
    plt.figure(figsize=(8, 6))
    plt.hist(member_scores, bins=50, alpha=0.5, label='Member dataset', color='red')
    plt.hist(nonmember_scores, bins=50, alpha=0.5, label='Hold-out dataset', color='gray')
    plt.xlabel("Duration Loss")
    plt.ylabel("Frequency")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(path)
    plt.close()


@torch.no_grad()
def main(args,
         configs,
         dataset="lj",
         attacker_name="durmi",
         attack_num=100,
         interval=10,
         seed=0,
         batch_size=1):
    """Run the duration-loss membership-inference attack against WaveGrad2.

    Each pair of member / non-member loader items (from ``member.txt`` and
    ``nonmember.txt``) is scored with ``recon_score``. The function then
    writes two outputs:

    * AUROC, TPR at 1% FPR, the ROC curve and the raw scores, to
      ``./wg_{attacker_name}_{dataset}.json``.
    * A histogram of both score distributions, to
      ``./wg_{attacker_name}_{dataset}_histogram.png``.

    Args:
        args: parsed CLI namespace, forwarded to ``get_model``.
        configs: ``(preprocess_config, model_config, train_config)``.
        dataset: tag used only in the output file names.
        attacker_name: tag used only in the output file names.
        attack_num: forwarded to ``recon_score``.
        interval: combined with ``attack_num`` and passed to ``recon_score``.
        seed: seed passed to ``seed_everything``.
        batch_size: DataLoader batch size for both loaders.

    Raises:
        ValueError: if no member/non-member pair could be scored.
    """
    T = 1000
    seed_everything(seed)
    preprocess_config, model_config, train_config = configs

    print("loading dataset ...")
    member_dataset = Dataset(
        "member.txt", preprocess_config, train_config, sort=True, drop_last=False
    )
    nonmember_dataset = Dataset(
        "nonmember.txt", preprocess_config, train_config, sort=True, drop_last=False
    )

    print("member set:", len(member_dataset), " / non-member set:", len(nonmember_dataset))

    member_loader = DataLoader(
        member_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=member_dataset.collate_fn,
    )
    # Each loader collates with its own dataset's collate_fn.
    nonmember_loader = DataLoader(
        nonmember_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=nonmember_dataset.collate_fn,
    )

    wg_model = get_model(args, configs, device, train=False)
    wg_model.eval()
    Loss = WaveGrad2Loss(preprocess_config, model_config).to(device)

    noise_schedule_path = os.path.join(
        train_config["path"]["noise_schedule_path"], "{}iters.pt".format(train_config["window"]["noise_iter"])
    )
    print("noise schdule: ", noise_schedule_path)
    set_noise_schedule(wg_model, noise_schedule_path)

    # zip() stops at the shorter loader, so size the progress bar to match
    # and warn instead of silently dropping the surplus.
    n_pairs = min(len(member_loader), len(nonmember_loader))
    if len(member_loader) != len(nonmember_loader):
        print(
            f"warning: member/non-member loaders differ in length "
            f"({len(member_loader)} vs {len(nonmember_loader)}); scoring only {n_pairs} pairs"
        )

    # Loop-invariant; recon_score currently ignores it.
    scaled_interval = interval / T * attack_num

    member_dur_losses, nonmember_dur_losses = [], []
    for member_batch, nonmember_batch in track(zip(member_loader, nonmember_loader), total=n_pairs):
        dur_loss_m = recon_score(wg_model, Loss, attack_num, scaled_interval, member_batch)
        dur_loss_nm = recon_score(wg_model, Loss, attack_num, scaled_interval, nonmember_batch)

        member_dur_losses.append(dur_loss_m.unsqueeze(0))
        nonmember_dur_losses.append(dur_loss_nm.unsqueeze(0))

    if not member_dur_losses:
        raise ValueError(
            "no member/non-member batches were scored; check member.txt and nonmember.txt"
        )

    member = torch.cat(member_dur_losses, dim=0)
    nonmember = torch.cat(nonmember_dur_losses, dim=0)

    print("member_dur_losses: ", member)
    print("nonmember_dur_losses: ", nonmember)

    # Labels: 0 = member, 1 = non-member. A higher duration loss is therefore
    # treated as evidence of NON-membership.
    # NOTE(review): with this polarity, "tpr_at_1fpr" is the non-member
    # detection rate at 1% of members misflagged. Confirm this is the
    # intended MIA metric.
    all_scores = torch.cat([member, nonmember], dim=0)
    all_labels = torch.cat([
        torch.zeros(member.size(0), dtype=torch.long),
        torch.ones(nonmember.size(0), dtype=torch.long),
    ], dim=0).to(all_scores.device)

    # The metrics follow the scores' device instead of assuming CUDA, so the
    # script also runs on the CPU fallback.
    # NOTE(review): torchmetrics may treat float preds outside [0, 1] as
    # logits and sigmoid them. AUROC/ROC are rank-based and unaffected, but
    # "thresholds" would then be in the transformed space. Verify this
    # against the installed torchmetrics version.
    auroc_val = AUROC(task="binary").to(all_scores.device)(all_scores, all_labels).item()
    fpr, tpr, thresholds = ROC(task="binary").to(all_scores.device)(all_scores, all_labels)
    tpr_at_1fpr = _tpr_at_fpr(fpr, tpr, 0.01)

    results = {
        "count": int(member.size(0)),
        "member_count": len(member_dur_losses),

        "auroc": auroc_val,
        "tpr_at_1fpr": float(tpr_at_1fpr),

        "fpr": fpr.tolist(),
        "tpr": tpr.tolist(),
        "thresholds": thresholds.tolist(),

        "member_dur_losses": member.tolist(),
        "nonmember_dur_losses": nonmember.tolist(),
    }

    output_filename = f"./wg_{attacker_name}_{dataset}.json"
    with open(output_filename, 'w', encoding="utf-8") as json_file:
        json.dump(results, json_file, indent=4)

    _save_score_histogram(
        member.flatten().tolist(),
        nonmember.flatten().tolist(),
        f"./wg_{attacker_name}_{dataset}_histogram.png",
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--restore_step", type=int, default=1000000)
    parser.add_argument(
        "-p",
        "--preprocess_config",
        type=str,
        required=True,
        help="path to preprocess.yaml",
    )
    parser.add_argument(
        "-m", "--model_config", type=str, required=True, help="path to model.yaml"
    )
    parser.add_argument(
        "-t", "--train_config", type=str, required=True, help="path to train.yaml"
    )
    args = parser.parse_args()

    def _load_yaml(path):
        """Parse one YAML config file and close the handle afterwards."""
        # FullLoader is kept for compatibility with existing configs. It can
        # construct more than plain data, so only point these flags at
        # trusted local files.
        with open(path, "r") as f:
            return yaml.load(f, Loader=yaml.FullLoader)

    preprocess_config = _load_yaml(args.preprocess_config)
    model_config = _load_yaml(args.model_config)
    train_config = _load_yaml(args.train_config)
    configs = (preprocess_config, model_config, train_config)

    main(args, configs)