import argparse
import os
import torch
import yaml
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import audio as Audio
from utils.model import get_model, get_param_num, set_noise_schedule
from utils.tools import to_device, log, synth_one_sample
from dataset import Dataset
from evaluate import evaluate
from itertools import chain
from rich.progress import track
from pytorch_lightning import seed_everything
import json
import numpy as np
import matplotlib.pyplot as plt
from torchmetrics.classification import AUROC, ROC
import pdb

from model import WaveGrad2
from model.modules import (
    TextEncoder,
    DurationPredictor,
    RangeParameterPredictor,
    GaussianUpsampling,
    SamplingWindow,
)
from wavegrad import WaveGrad
from utils.tools import get_mask_from_lengths

# Module-wide compute device: the current CUDA device when one is available,
# otherwise the CPU. Announced once at import time.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("GPU: ", device)

def recon_score(wg_model, attacker_name, attack_num, interval, batch):
    """Score one DataLoader batch with the selected membership-inference attack.

    Args:
        wg_model: model exposing ``forward_decoder`` and a ``decoder`` that
            provides ``naive_attack`` / ``secmi_attack`` / ``pia_attack`` /
            ``pian_attack`` methods.
        attacker_name: one of ``"naive"``, ``"secmi"``, ``"pia"``, ``"pian"``.
        attack_num: forwarded unchanged to the attack method.
        interval: forwarded unchanged to the attack method.
        batch: iterable of sub-batches, as produced by the loader's
            ``collate_fn`` (presumably a list of sub-batches -- confirm
            against ``Dataset.collate_fn``). Elements from index 2 onwards of
            each sub-batch are passed to ``forward_decoder``.

    Returns:
        The attack outputs of every sub-batch, concatenated along the last
        dimension. This matches how ``main`` concatenates per-batch results.

    Raises:
        ValueError: if ``attacker_name`` is unknown or ``batch`` is empty.
    """
    # Resolve only the requested method, so a decoder lacking some other
    # attack does not fail here.
    method_names = {
        'naive': 'naive_attack',
        'secmi': 'secmi_attack',
        'pia': 'pia_attack',
        'pian': 'pian_attack',
    }
    try:
        attack = getattr(wg_model.decoder, method_names[attacker_name])
    except KeyError:
        raise ValueError(
            f"unknown attacker {attacker_name!r}; expected one of {sorted(method_names)}"
        ) from None

    scores = []
    # Previously this loop returned after the first sub-batch and silently
    # dropped the rest; now every sub-batch is scored.
    for sub_batch in batch:
        sub_batch = to_device(sub_batch, device)
        decoder_out = wg_model.forward_decoder(*(sub_batch[2:]))
        scores.append(
            attack(decoder_out[0], decoder_out[1], attack_num=attack_num, interval=interval)
        )

    if not scores:
        raise ValueError("recon_score received an empty batch")
    return torch.cat(scores, dim=-1)
    

def _normalise_pair(member_scores, nonmember_scores, eps=1e-8):
    """Divide both score tensors by their shared maximum (floored at ``eps``).

    The floor avoids division by zero or by a non-positive maximum. The
    original metric code did not have it; results are unchanged whenever the
    maximum exceeds ``eps``.
    """
    scale = max(member_scores.max().item(), nonmember_scores.max().item(), eps)
    return member_scores / scale, nonmember_scores / scale


def _timestep_metrics(member, nonmember):
    """Compute AUROC and the ROC curve independently for every row of the scores.

    Args:
        member: 2-D score tensor. Rows are treated as timesteps (as the plot
            labels assume) and columns as member samples.
        nonmember: same layout for non-member samples.

    Returns:
        ``(aurocs, curves)``: a list of float AUROCs and a list of
        ``(fpr, tpr)`` tensor pairs, one entry per row.
    """
    n_member, n_nonmember = member.shape[1], nonmember.shape[1]
    dev = member.device
    # AUROC: score = normalised value, positive class (1) = non-member.
    auroc_target = torch.cat([
        torch.zeros(n_member, dtype=torch.long),
        torch.ones(n_nonmember, dtype=torch.long),
    ]).to(dev)
    # ROC: score = 1 - normalised value, positive class (1) = member.
    # Each label segment is sized by the group whose scores it labels. The
    # original sized them the other way round, which breaks when group sizes
    # differ.
    roc_target = torch.cat([
        torch.zeros(n_nonmember, dtype=torch.long),
        torch.ones(n_member, dtype=torch.long),
    ]).to(dev)

    aurocs, curves = [], []
    for member_t, nonmember_t in zip(member, nonmember):
        member_norm, nonmember_norm = _normalise_pair(member_t, nonmember_t)
        # Fresh metric objects per row so no accumulated state carries over.
        auroc_metric = AUROC(task="binary").to(dev)
        roc_metric = ROC(task="binary").to(dev)
        aurocs.append(
            auroc_metric(torch.cat([member_norm, nonmember_norm]), auroc_target).item()
        )
        fpr, tpr, _ = roc_metric(
            torch.cat([1 - nonmember_norm, 1 - member_norm]), roc_target
        )
        curves.append((fpr, tpr))
    return aurocs, curves


def _tpr_at_fpr(fpr, tpr, max_fpr):
    """Return the TPR at the last ROC point whose FPR is strictly below ``max_fpr``.

    If no point qualifies, index -1 wraps to the final point. This matches
    the original behaviour.
    """
    return tpr[(fpr < max_fpr).sum() - 1].item()


def _plot_per_timestep(values, label, ylabel, title, path):
    """Line-plot one value per timestep on a [0, 1] y-axis and save it to ``path``."""
    plt.figure(figsize=(8, 6))
    plt.plot(range(len(values)), values, label=label)
    plt.xlabel('Timestep')
    plt.ylabel(ylabel)
    plt.ylim(0.0, 1.0)
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.savefig(path)
    plt.close()


def _plot_loss_frequency(member_values, nonmember_values, path):
    """Overlay member / non-member histograms (50 bins each) and save to ``path``."""
    plt.figure(figsize=(8, 6))
    plt.hist(member_values, bins=50, alpha=0.5, label='Member', color='red')
    plt.hist(nonmember_values, bins=50, alpha=0.5, label='Nonmember', color='gray')
    plt.xlabel('Loss')
    plt.ylabel('Frequency')
    plt.title('Frequency of Loss')
    plt.legend()
    plt.grid(True)
    plt.savefig(path)
    # The original never closed this figure.
    plt.close()


@torch.no_grad()
def main(args,
         configs,
         dataset="libritts",
         attacker_name="hybrid",
         attack_num=100, interval=10,
         seed=0,
         batch_size=10):
    """Run a membership-inference attack on "member.txt" vs "nonmember.txt".

    Writes ``./{attacker_name}_{dataset}.json`` with per-timestep AUROC and
    TPR@1%/0.1%-FPR plus the raw scores. Also saves three plots with the same
    prefix: ``_auroc.png``, ``_tpr_fpr.png`` and ``_plot_loss_frequency.png``.

    Args:
        args: argparse namespace, forwarded to ``get_model``.
        configs: ``(preprocess_config, model_config, train_config)``.
        dataset: tag used only in output file names.
        attacker_name: attack selector passed to ``recon_score``.
            NOTE(review): the default ``"hybrid"`` is not one of the names
            ``recon_score`` dispatches on (naive/secmi/pia/pian) -- confirm
            the intended default.
        attack_num: forwarded to the attack method.
        interval: forwarded to the attack method.
        seed: passed to ``seed_everything``.
        batch_size: DataLoader batch size; also scales ``count`` in the JSON.

    Raises:
        RuntimeError: if no batch pairs were scored.
    """
    seed_everything(seed)
    preprocess_config, model_config, train_config = configs

    print(preprocess_config)

    print("loading dataset ...")
    member_dataset = Dataset(
        "member.txt", preprocess_config, train_config, sort=True, drop_last=False
    )
    nonmember_dataset = Dataset(
        "nonmember.txt", preprocess_config, train_config, sort=True, drop_last=False
    )

    print("member set:", len(member_dataset), " / non-member set:", len(nonmember_dataset))

    member_loader = DataLoader(
        member_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=member_dataset.collate_fn,
    )
    nonmember_loader = DataLoader(
        nonmember_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=nonmember_dataset.collate_fn,
    )

    wg_model = get_model(args, configs, device, train=False)
    wg_model.eval()

    # Set noise schedule
    noise_schedule_path = os.path.join(
        train_config["path"]["noise_schedule_path"], "{}iters.pt".format(train_config["window"]["noise_iter"])
    )
    print("noise schdule: ", noise_schedule_path)
    set_noise_schedule(wg_model, noise_schedule_path)

    member_batches, nonmember_batches = [], []
    count = 0
    # zip() stops at the shorter loader, so that is the real number of steps.
    n_steps = min(len(member_loader), len(nonmember_loader))
    for member_batch, nonmember_batch in track(zip(member_loader, nonmember_loader), total=n_steps):
        member_batches.append(recon_score(wg_model, attacker_name, attack_num, interval, member_batch))
        nonmember_batches.append(recon_score(wg_model, attacker_name, attack_num, interval, nonmember_batch))
        count += 1

    if not member_batches:
        raise RuntimeError("no batches were scored; check member.txt / nonmember.txt")

    # Concatenate samples along the last dimension.
    member = torch.cat(member_batches, dim=-1)
    nonmember = torch.cat(nonmember_batches, dim=-1)

    auroc, tpr_fpr = _timestep_metrics(member, nonmember)
    tpr_fpr_1 = [_tpr_at_fpr(fpr, tpr, 0.01) for fpr, tpr in tpr_fpr]
    tpr_fpr_01 = [_tpr_at_fpr(fpr, tpr, 0.001) for fpr, tpr in tpr_fpr]
    # The JSON keeps the descending-sorted copies, as before.
    cp_tpr_fpr_1 = sorted(tpr_fpr_1, reverse=True)
    cp_tpr_fpr_01 = sorted(tpr_fpr_01, reverse=True)

    output_filename = f"./{attacker_name}_{dataset}.json"

    results = {
        "count": count * batch_size,
        "member": member.shape[1],
        "nonmember": nonmember.shape[1],

        "avg_auroc": float(np.mean(auroc)),
        "max_auroc": float(np.max(auroc)),
        "avg_tpr_fpr_1": float(np.mean(cp_tpr_fpr_1)),
        "max_tpr_fpr_1": float(np.max(cp_tpr_fpr_1)),

        "avg_tpr_fpr_01": float(np.mean(cp_tpr_fpr_01)),
        "max_tpr_fpr_01": float(np.max(cp_tpr_fpr_01)),

        "auc": auroc,
        "tpr_fpr_1": cp_tpr_fpr_1,
        "tpr_fpr_01": cp_tpr_fpr_01,

        "loss_member": member.tolist(),
        "loss_nonmember": nonmember.tolist(),

        "fprs": [fpr.cpu().tolist() for fpr, _ in tpr_fpr],
        "tprs": [tpr.cpu().tolist() for _, tpr in tpr_fpr],
    }

    with open(output_filename, 'w') as json_file:
        json.dump(results, json_file, indent=4)
    print(f"Results saved to {output_filename}")

    # For each row, average the normalised scores over samples.
    # NOTE(review): the original plotting call referenced the undefined names
    # all_member_losses / all_nonmember_losses. These averages were computed
    # right before it and otherwise unused, so they appear to be what was
    # intended -- confirm.
    member_avg_loss, nonmember_avg_loss = [], []
    for member_t, nonmember_t in zip(member, nonmember):
        member_norm, nonmember_norm = _normalise_pair(member_t, nonmember_t)
        member_avg_loss.append(torch.mean(member_norm).item())
        nonmember_avg_loss.append(torch.mean(nonmember_norm).item())

    prefix = f"{attacker_name}_{dataset}"
    _plot_per_timestep(auroc, 'AUROC', 'AUROC Value', 'AUROC Curve', f"./{prefix}_auroc.png")
    # Plot in timestep order: the x-axis is labelled 'Timestep', so the
    # sorted copy used previously misaligned values with their timesteps.
    _plot_per_timestep(tpr_fpr_1, 'TPR @ 1% FPR', 'TPR @ 1% FPR', 'TPR @ 1% FPR', f"./{prefix}_tpr_fpr.png")
    _plot_loss_frequency(member_avg_loss, nonmember_avg_loss, f"./{prefix}_plot_loss_frequency.png")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--restore_step", type=int, default=100000)
    parser.add_argument(
        "-p",
        "--preprocess_config",
        type=str,
        required=True,
        help="path to preprocess.yaml",
    )
    parser.add_argument(
        "-m", "--model_config", type=str, required=True, help="path to model.yaml"
    )
    parser.add_argument(
        "-t", "--train_config", type=str, required=True, help="path to train.yaml"
    )
    # Expose main()'s attack parameters. The defaults mirror main()'s own, so
    # running without these flags behaves exactly as before.
    parser.add_argument("--dataset", type=str, default="libritts",
                        help="tag used in output file names")
    parser.add_argument("--attacker_name", type=str, default="hybrid",
                        help="attack to run (recon_score knows naive, secmi, pia, pian)")
    parser.add_argument("--attack_num", type=int, default=100)
    parser.add_argument("--interval", type=int, default=10)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--batch_size", type=int, default=10)
    args = parser.parse_args()

    def _load_yaml(path):
        """Parse one YAML config file, closing the handle afterwards."""
        with open(path, "r") as f:
            return yaml.load(f, Loader=yaml.FullLoader)

    preprocess_config = _load_yaml(args.preprocess_config)
    model_config = _load_yaml(args.model_config)
    train_config = _load_yaml(args.train_config)
    configs = (preprocess_config, model_config, train_config)

    main(
        args,
        configs,
        dataset=args.dataset,
        attacker_name=args.attacker_name,
        attack_num=args.attack_num,
        interval=args.interval,
        seed=args.seed,
        batch_size=args.batch_size,
    )