import numpy as np
import os
import json
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve
import seaborn as sns

import torch
from rich.progress import track
import fire
import logging
from rich.logging import RichHandler
from pytorch_lightning import seed_everything
from typing import Type, Dict
from grad_tts.model import GradTTS
from itertools import chain
import importlib
from grad_tts.text.symbols import symbols
from grad_tts.data import TextMelSpeakerDataset, TextMelSpeakerBatchCollate
from torch.utils.data import DataLoader
from grad_tts.model.utils import fix_len_compatibility
from grad_tts.data import TextMelDataset, TextMelBatchCollate
from sklearn.metrics import roc_curve, auc
from torchmetrics.classification import AUROC, ROC


def _lazy_params(module_name):
    """Return a zero-argument callable that imports ``module_name`` on demand."""
    def load():
        return importlib.import_module(module_name)
    return load


# Dataset name -> loader for that dataset's hyper-parameter module.
# The import is deferred until the loader is called, so only the module of the
# dataset actually requested gets imported.
params_dict = {
    name: _lazy_params(module_name)
    for name, module_name in (
        ('ljspeech', 'grad_tts.params_ljspeech'),
        ('libritts', 'grad_tts.params_libritts'),
        ('vctk', 'grad_tts.params_vctk'),
    )
}

# Torch device string for this script.
DEVICE = 'cuda'

@torch.no_grad()
def main(checkpoint,
         dataset,
         attacker_name="durmi",
         seed=0,
         batch_size=1):
    """Score a duration-loss membership-inference attack on a GradTTS checkpoint.

    Batches from the training filelist (members) and the validation filelist
    (non-members) are paired up, and for each batch the ``dur_loss`` returned
    by ``model.forward_decoder_ahead`` is used as the attack score. AUROC, the
    ROC curve and the TPR at FPR <= 0.01 are computed from these scores and
    written to ``./gt_{attacker_name}_{dataset}.json``; a histogram of both
    score populations is saved to
    ``./gt_{attacker_name}_{dataset}_histogram.png``.

    Args:
        checkpoint: Path of the state dict handed to ``torch.load``.
        dataset: Key into ``params_dict`` selecting the hyper-parameter module.
        attacker_name: Tag used only in the output file names.
        seed: Seed passed to ``seed_everything``.
        batch_size: Batch size of both data loaders. Incomplete final batches
            are dropped (``drop_last=True``). One score is collected per
            batch — presumably ``dur_loss`` is a scalar per batch; confirm
            against ``forward_decoder_ahead`` before using ``batch_size > 1``.

    Raises:
        KeyError: If ``dataset`` is not a key of ``params_dict``.
        RuntimeError: If the loaders yield no member/non-member batch pair.
    """
    seed_everything(seed)

    params = params_dict[dataset]()
    cmudict_path = params.cmudict_path
    add_blank = params.add_blank
    n_feats = params.n_feats
    n_spks = params.n_spks

    nsymbols = len(symbols) + 1 if add_blank else len(symbols)
    # NOTE(review): 22050 and 256 are hard-coded here rather than taken from
    # params.sample_rate / params.hop_length — confirm this is intentional.
    output_size = fix_len_compatibility(2 * 22050 // 256)

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    # main() configures the root logger; avoid stacking a second RichHandler
    # (and thus duplicated log lines) if it is called more than once.
    if not any(isinstance(handler, RichHandler) for handler in logger.handlers):
        logger.addHandler(RichHandler())

    logger.info("initializing model...")
    model = GradTTS(nsymbols, n_spks, None if n_spks == 1 else params.spk_emb_dim,
                    params.n_enc_channels, params.filter_channels,
                    params.filter_channels_dp, params.n_heads, params.n_enc_layers,
                    params.enc_kernel, params.enc_dropout, params.window_size,
                    n_feats, params.dec_dim, params.beta_min, params.beta_max,
                    params.pe_scale).to(DEVICE)

    logger.info("loading checkpoint...")
    # map_location receives (storage, location); returning the storage as-is
    # keeps the deserialized tensors where they are, and load_state_dict then
    # copies them into the model's parameters.
    state_dict = torch.load(checkpoint, map_location=lambda storage, loc: storage)
    model.load_state_dict(state_dict)
    model.eval()

    logger.info("loading dataset...")
    # Multi-speaker configurations use the speaker-aware dataset/collate pair.
    if n_spks > 1:
        dataset_cls, collate_cls = TextMelSpeakerDataset, TextMelSpeakerBatchCollate
    else:
        dataset_cls, collate_cls = TextMelDataset, TextMelBatchCollate
    batch_collate = collate_cls()

    def make_loader(filelist_path):
        """Build an unshuffled loader over the utterances in ``filelist_path``."""
        data = dataset_cls(filelist_path, cmudict_path, add_blank,
                           params.n_fft, n_feats, params.sample_rate,
                           params.hop_length, params.win_length,
                           params.f_min, params.f_max)
        return DataLoader(dataset=data, batch_size=batch_size,
                          collate_fn=batch_collate, drop_last=True,
                          shuffle=False)

    train_loader = make_loader(params.train_filelist_path)
    test_loader = make_loader(params.valid_filelist_path)

    def recon_score(batch):
        """Return the ``dur_loss`` of ``forward_decoder_ahead`` for one batch."""
        x, x_lengths = batch['x'].to(DEVICE), batch['x_lengths'].to(DEVICE)
        y, y_lengths = batch['y'].to(DEVICE), batch['y_lengths'].to(DEVICE)
        spk = batch['spk'].to(torch.long).to(DEVICE) if 'spk' in batch else None

        _, _, _, dur_loss, _ = model.forward_decoder_ahead(
            x, x_lengths, y, y_lengths, out_size=output_size, spk=spk)
        return dur_loss

    logger.info("attack start...")
    member_dur_losses, nonmember_dur_losses = [], []

    # zip() stops at the shorter loader, so that is the real iteration count.
    n_pairs = min(len(train_loader), len(test_loader))
    for member_batch, nonmember_batch in track(zip(train_loader, test_loader),
                                               total=n_pairs):
        dur_loss_m = recon_score(member_batch)
        dur_loss_nm = recon_score(nonmember_batch)

        print("dur_loss_m", dur_loss_m)
        print("dur_loss_nm", dur_loss_nm)

        member_dur_losses.append(dur_loss_m.unsqueeze(0))
        nonmember_dur_losses.append(dur_loss_nm.unsqueeze(0))

    if not member_dur_losses:
        raise RuntimeError(
            "no member/non-member batch pairs were produced; check the "
            "filelists and that batch_size does not exceed the dataset size")

    member = torch.cat(member_dur_losses, dim=0)
    nonmember = torch.cat(nonmember_dur_losses, dim=0)

    print("member_dur_losses: ", member)
    print("nonmember_dur_losses: ", nonmember)

    all_scores = torch.cat([member, nonmember], dim=0)
    # Members are labelled 0 and non-members 1, so the positive class for
    # AUROC/ROC is "non-member" and a larger duration loss counts towards it.
    # NOTE(review): tpr_at_1fpr below is therefore a non-member detection
    # rate — confirm this matches the intended reporting convention.
    all_labels = torch.cat([
        torch.zeros(member.size(0), dtype=torch.long, device=all_scores.device),
        torch.ones(nonmember.size(0), dtype=torch.long, device=all_scores.device),
    ], dim=0)

    # NOTE(review): newer torchmetrics releases expect a `task` argument
    # (e.g. task="binary") for AUROC/ROC — confirm against the pinned version.
    auroc_val = AUROC().to(DEVICE)(all_scores, all_labels).item()

    fpr, tpr, thresholds = ROC().to(DEVICE)(all_scores, all_labels)
    fpr_np = fpr.detach().cpu().numpy()
    tpr_np = tpr.detach().cpu().numpy()
    # TPR at the last ROC point whose FPR is <= 0.01 (searchsorted requires
    # fpr to be sorted ascending); fall back to the first point if none is.
    idx = np.searchsorted(fpr_np, 0.01, side="right")
    tpr_at_1fpr = tpr_np[max(idx - 1, 0)]

    results = {
        "count": int(member.size(0)),
        "member_count": len(member_dur_losses),

        "auroc": auroc_val,
        "tpr_at_1fpr": float(tpr_at_1fpr),

        "fpr": fpr.tolist(),
        "tpr": tpr.tolist(),
        "thresholds": thresholds.tolist(),

        "member_dur_losses": member.tolist(),
        "nonmember_dur_losses": nonmember.tolist(),
    }

    output_filename = f"./gt_{attacker_name}_{dataset}.json"

    with open(output_filename, 'w') as json_file:
        json.dump(results, json_file, indent=4)

    plt.figure(figsize=(8, 6))
    # One bulk device-to-host transfer instead of one .item() sync per score.
    dur_mem = member.flatten().tolist()
    dur_nonmem = nonmember.flatten().tolist()
    plt.hist(dur_mem, bins=50, alpha=0.5, label='Member dataset', color='red')
    plt.hist(dur_nonmem, bins=50, alpha=0.5, label='Hold-out dataset', color='gray')
    plt.xlabel("Duration Loss")
    plt.ylabel("Frequency")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(f"./gt_{attacker_name}_{dataset}_histogram.png")
    plt.close()
   
# Script entry point: Fire builds the command-line interface from main()'s
# signature, so its parameters can be supplied as command-line arguments.
if __name__ == '__main__':
    fire.Fire(main)