import numpy as np
import os
import json
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
from sklearn import metrics

import time
import torch
from rich.progress import track
import fire
import logging
from rich.logging import RichHandler
from pytorch_lightning import seed_everything
from typing import Type, Dict
from grad_tts.model import GradTTS
from itertools import chain
import importlib
from grad_tts.text.symbols import symbols
from grad_tts.data import TextMelSpeakerDataset, TextMelSpeakerBatchCollate
from torch.utils.data import DataLoader
from grad_tts.model.utils import fix_len_compatibility
from grad_tts.data import TextMelDataset, TextMelBatchCollate
from torchmetrics.classification import AUROC, ROC


def _lazy_params(module_name):
    """Return a zero-argument callable that imports ``grad_tts.<module_name>`` when called."""
    def load():
        return importlib.import_module(f'grad_tts.{module_name}')
    return load


# Dataset key -> deferred loader for that dataset's hyper-parameter module.
# The import only happens when the loader is called, so selecting one dataset
# never imports the parameter modules of the others.
params_dict = {
    key: _lazy_params(f'params_{key}')
    for key in ('ljspeech', 'libritts', 'vctk')
}


# Name of the torch device this script targets.
DEVICE = 'cuda'


def _plot_series(values, label, ylabel, title, path):
    """Save a line plot of *values* against their index (x-axis 'Timestep') to *path*."""
    plt.figure(figsize=(8, 6))
    plt.plot(range(len(values)), values, label=label)
    plt.xlabel('Timestep')
    plt.ylabel(ylabel)
    plt.ylim(0.0, 1.0)
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.savefig(path)
    plt.close()


def _plot_loss_histogram(member_values, nonmember_values, path):
    """Save overlaid 50-bin histograms of member and nonmember values to *path*."""
    plt.figure(figsize=(8, 6))
    plt.hist(member_values, bins=50, alpha=0.5, label='Member', color='red')
    plt.hist(nonmember_values, bins=50, alpha=0.5, label='Nonmember', color='gray')
    plt.xlabel('Loss')
    plt.ylabel('Frequency')
    plt.title('Frequency of Loss')
    plt.legend()
    plt.grid(True)
    plt.savefig(path)
    # Release the figure; without this every call leaves a figure alive in pyplot.
    plt.close()


@torch.no_grad()
def main(checkpoint,
         dataset,
         attacker_name="PIA",
         attack_num=100, interval=10,
         seed=0,
         batch_size=1,
         max_batches=10):
    """Run a membership-inference attack against a GradTTS checkpoint and save the results.

    Batches from the training filelist are scored as "members" and batches from
    the validation filelist as "nonmembers". The per-batch score tensors are
    concatenated along their last dimension; AUROC and ROC are then computed
    independently for every index of the first dimension (the plots label that
    axis "Timestep").

    Outputs, all written to the current working directory:
      * ``{attacker_name}_{dataset}.json`` - summary metrics, raw scores, ROC curves
      * ``{attacker_name}_{dataset}_auroc.png``
      * ``{attacker_name}_{dataset}_tpr_fpr.png``
      * ``{attacker_name}_{dataset}_plot_loss_frequency.png``

    Args:
        checkpoint: Path handed to ``torch.load``. When ``'libritts'`` occurs in
            *dataset*, the state dict is read from the ``'ckpt'`` entry of the
            loaded object; otherwise the loaded object is the state dict.
        dataset: Key into ``params_dict`` (``'ljspeech'``, ``'libritts'``, ``'vctk'``).
        attacker_name: One of ``'SecMI'``, ``'naive'``, ``'PIA'``, ``'PIAN'``;
            selects the attack method on ``model.decoder``.
        attack_num: Passed to the attack as its ``n_timesteps`` argument.
        interval: The attack's ``terminal_time`` is ``interval / 1000 * attack_num``.
        seed: Passed to ``seed_everything``.
        batch_size: Batch size of both data loaders (``drop_last=True``).
        max_batches: Maximum number of (train, valid) batch pairs to evaluate;
            ``None`` evaluates every available pair. Defaults to 10.

    Raises:
        KeyError: If *dataset* or *attacker_name* is not a known key.
        ValueError: If *max_batches* is neither ``None`` nor a positive integer.
        RuntimeError: If the loaders yield no batch pair.
    """
    if max_batches is not None and max_batches < 1:
        raise ValueError("max_batches must be a positive integer or None")

    T = 1000
    seed_everything(seed)

    params = params_dict[dataset]()
    add_blank = params.add_blank
    n_spks = params.n_spks
    nsymbols = len(symbols) + 1 if add_blank else len(symbols)
    # NOTE(review): presumably ~2 s of 22.05 kHz audio at hop 256 - confirm.
    output_size = fix_len_compatibility(2 * 22050 // 256)

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    # Attach the rich handler only once so repeated calls do not duplicate log lines.
    if not any(isinstance(handler, RichHandler) for handler in logger.handlers):
        logger.addHandler(RichHandler())

    logger.info("initializing model...")
    model = GradTTS(nsymbols, n_spks, None if n_spks == 1 else params.spk_emb_dim,
                    params.n_enc_channels, params.filter_channels, params.filter_channels_dp,
                    params.n_heads, params.n_enc_layers, params.enc_kernel,
                    params.enc_dropout, params.window_size,
                    params.n_feats, params.dec_dim, params.beta_min, params.beta_max,
                    params.pe_scale).to(DEVICE)

    logger.info("loading checkpoint...")
    # Returning the storage unchanged keeps the loaded tensors where torch.load
    # deserialised them; load_state_dict then copies them into the model.
    state = torch.load(checkpoint, map_location=lambda storage, loc: storage)
    if 'libritts' in dataset:
        state = state['ckpt']
    model.load_state_dict(state)
    model.eval()

    # Resolve the attack once, before the (slow) dataset loading, so an unknown
    # attacker_name fails early instead of on the first batch.
    attack_fn = {
        'SecMI': model.decoder.SecMI,
        'naive': model.decoder.naive_attack,
        'PIA': model.decoder.PIA,
        'PIAN': model.decoder.PIAN,
    }[attacker_name]

    logger.info("loading dataset...")
    if n_spks > 1:
        dataset_cls, collate_cls = TextMelSpeakerDataset, TextMelSpeakerBatchCollate
    else:
        dataset_cls, collate_cls = TextMelDataset, TextMelBatchCollate
    batch_collate = collate_cls()

    def build_loader(filelist_path):
        # Both splits share every setting except the filelist they read.
        split = dataset_cls(filelist_path, params.cmudict_path, add_blank,
                            params.n_fft, params.n_feats, params.sample_rate,
                            params.hop_length, params.win_length,
                            params.f_min, params.f_max)
        return DataLoader(dataset=split, batch_size=batch_size,
                          collate_fn=batch_collate, drop_last=True,
                          shuffle=False)

    train_loader = build_loader(params.train_filelist_path)
    test_loader = build_loader(params.valid_filelist_path)

    def sync_device():
        # CUDA kernels run asynchronously; wait for them so wall-clock timing is meaningful.
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def recon_score(n_timesteps, terminal_time, batch):
        x, x_lengths = batch['x'].to(DEVICE), batch['x_lengths'].to(DEVICE)
        y, y_lengths = batch['y'].to(DEVICE), batch['y_lengths'].to(DEVICE)
        spk = batch['spk'].to(torch.long).to(DEVICE) if 'spk' in batch else None

        sync_device()
        start_time = time.perf_counter()

        y_mask, mu_y, y, _ = model.forward_decoder_ahead(x, x_lengths, y, y_lengths, out_size=output_size, spk=spk)
        if hasattr(model, "spk_emb"):
            spk = model.spk_emb(spk)

        result = attack_fn(y, mu_y, y_mask, n_timesteps, 0, 0, terminal_time=terminal_time, spk=spk)

        sync_device()
        inference_time = time.perf_counter() - start_time
        print(f"Inference time: {inference_time:.4f} seconds")

        return result

    logger.info("attack start...")
    terminal_time = interval / T * attack_num
    n_pairs = min(len(train_loader), len(test_loader))
    if max_batches is not None:
        n_pairs = min(n_pairs, max_batches)

    member_scores, nonmember_scores = [], []
    count = 0
    for member_batch, nonmember_batch in track(zip(train_loader, test_loader), total=n_pairs):
        member_scores.append(recon_score(attack_num, terminal_time, member_batch))
        nonmember_scores.append(recon_score(attack_num, terminal_time, nonmember_batch))

        count += 1
        # Checked after the work so no extra batch is fetched from the loaders.
        if count == max_batches:
            break

    if not member_scores:
        raise RuntimeError(
            "no batches were evaluated: the train or validation loader is empty "
            "(check the filelists and batch_size, since drop_last=True)"
        )

    # Samples are stacked along the last dimension; the first dimension is the
    # per-row axis the metrics iterate over (assumed 2-D: rows x samples).
    member = torch.cat(member_scores, dim=-1)
    nonmember = torch.cat(nonmember_scores, dim=-1)
    n_member, n_nonmember = member.shape[1], nonmember.shape[1]

    # Scale each row by the larger of the two row maxima so member and
    # nonmember scores of the same row share one scale.
    normalized = []
    for member_row, nonmember_row in zip(member, nonmember):
        scale = max(member_row.max().item(), nonmember_row.max().item())
        normalized.append((member_row / scale, nonmember_row / scale))

    # Label tensors must follow the concatenation order of the predictions.
    # AUROC: [member, nonmember] scores, nonmember is the positive class.
    auroc_labels = torch.cat([
        torch.zeros(n_member).long(),
        torch.ones(n_nonmember).long(),
    ]).to(DEVICE)
    # ROC: [nonmember, member] inverted scores, member is the positive class.
    # The label lengths therefore have to be (n_nonmember, n_member).
    roc_labels = torch.cat([
        torch.zeros(n_nonmember).long(),
        torch.ones(n_member).long(),
    ]).to(DEVICE)

    # NOTE(review): AUROC()/ROC() are built without arguments as in the original
    # script; confirm this matches the installed torchmetrics version.
    auroc = [
        AUROC().to(DEVICE)(torch.cat([member_norm, nonmember_norm]), auroc_labels).item()
        for member_norm, nonmember_norm in normalized
    ]
    tpr_fpr = [
        ROC().to(DEVICE)(torch.cat([1 - nonmember_norm, 1 - member_norm]), roc_labels)
        for member_norm, nonmember_norm in normalized
    ]

    # TPR at the last ROC point whose FPR is still below 1%.
    tpr_fpr_1 = [curve[1][(curve[0] < 0.01).sum() - 1].item() for curve in tpr_fpr]
    cp_tpr_fpr_1 = sorted(tpr_fpr_1, reverse=True)

    fprs = [curve[0].cpu().numpy() for curve in tpr_fpr]
    tprs = [curve[1].cpu().numpy() for curve in tpr_fpr]

    file_prefix = f"{attacker_name}_{dataset}"
    output_filename = f"./{attacker_name}_{dataset}.json"

    results = {
        "count": count * batch_size,
        "member": n_member,
        "nonmember": n_nonmember,

        "avg_auroc": float(np.mean(auroc)),
        "max_auroc": float(np.max(auroc)),
        "avg_tpr_fpr_1": float(np.mean(cp_tpr_fpr_1)),
        "max_tpr_fpr_1": float(np.max(cp_tpr_fpr_1)),

        "auc": auroc,
        # Kept sorted in descending order for compatibility with existing result files.
        "tpr_fpr_1": cp_tpr_fpr_1,

        "loss_member": member.tolist(),
        "loss_nonmember": nonmember.tolist(),

        "fprs": [fpr.tolist() for fpr in fprs],
        "tprs": [tpr.tolist() for tpr in tprs],
    }

    with open(output_filename, 'w') as json_file:
        json.dump(results, json_file, indent=4)
    print(f"Results saved to {output_filename}")

    _plot_series(auroc, 'AUROC', 'AUROC Value', 'AUROC Curve',
                 f"./{file_prefix}_auroc.png")
    # Plot the values in row order (not the sorted copy) so the 'Timestep'
    # axis is meaningful and lines up with the AUROC plot.
    _plot_series(tpr_fpr_1, 'TPR @ 1% FPR', 'TPR @ 1% FPR', 'TPR @ 1% FPR',
                 f"./{file_prefix}_tpr_fpr.png")

    all_member_losses = torch.cat(
        [member_norm for member_norm, _ in normalized], dim=0).cpu().numpy().flatten()
    all_nonmember_losses = torch.cat(
        [nonmember_norm for _, nonmember_norm in normalized], dim=0).cpu().numpy().flatten()
    _plot_loss_histogram(all_member_losses, all_nonmember_losses,
                         f"./{file_prefix}_plot_loss_frequency.png")

# Script entry point: python-fire turns main's parameters into command-line arguments.
if __name__ == "__main__":
    fire.Fire(main)