import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import numpy as np
from pytorch_lightning import seed_everything
from tqdm import tqdm
from rich.progress import track
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve
from itertools import chain
import importlib
import os
from copy import deepcopy
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from data_collate import DistributedBucketSampler
from tools import plot_tensor, save_plot
from model.utils import fix_len_compatibility
import json
import tools
from torchmetrics.classification import AUROC, ROC


class ModelEmaV2(torch.nn.Module):
    """Maintain an exponential moving average (EMA) of a model's state dict.

    A deep copy of the model's ``state_dict()`` is taken at construction time
    (unwrapping ``model.module`` when the model exposes one, e.g. when it is
    wrapped for data-parallel training) and then blended with the live model
    on every ``update`` call.

    Args:
        model: Module whose state is tracked. If it has a ``module``
            attribute, that inner module's state dict is used instead.
        decay: Weight given to the existing EMA value in each update; the
            live model value receives ``1 - decay``.
        device: Optional device on which the EMA copy is kept. When ``None``
            the copy stays on whatever device the model tensors were on.
    """

    def __init__(self, model, decay=0.9999, device=None):
        super().__init__()
        source = model.module if hasattr(model, "module") else model
        self.model_state_dict = deepcopy(source.state_dict())
        self.decay = decay
        self.device = device
        if self.device is not None:
            # Keep the EMA copy on the requested device so that `_update`,
            # which moves the live values there, never mixes devices.
            # Assign in place (rather than rebuilding the dict) so any extra
            # attributes carried by the state dict object are preserved.
            for key, value in list(self.model_state_dict.items()):
                self.model_state_dict[key] = value.to(device=self.device)

    def _update(self, model, update_fn):
        """Overwrite every EMA tensor with ``update_fn(ema_value, model_value)``.

        EMA and model entries are paired positionally, so ``model`` must have
        the same state dict layout as the model given to the constructor.
        """
        source = model.module if hasattr(model, "module") else model
        model_values = source.state_dict().values()
        with torch.no_grad():
            for ema_v, model_v in zip(self.model_state_dict.values(), model_values):
                if self.device is not None:
                    model_v = model_v.to(device=self.device)
                # copy_ writes into the existing EMA tensor (casting to its
                # dtype if the blended value has a different one).
                ema_v.copy_(update_fn(ema_v, model_v))

    def update(self, model):
        """Blend the current model state into the EMA using ``self.decay``."""
        self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)

    def set(self, model):
        """Replace the EMA state with the model's current state."""
        self._update(model, update_fn=lambda e, m: m)

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Return the averaged state dict.

        The ``destination``, ``prefix`` and ``keep_vars`` arguments are
        accepted for signature compatibility with ``torch.nn.Module`` and are
        ignored.
        """
        return self.model_state_dict



def run(rank, n_gpus, hps, ckpt,
        attacker_name="durmi",
        dataset="lj",
        seed=0,
        batch_size=1
        ):
    """Score train (member) and validation (non-member) batches and report ROC metrics.

    For every pair of batches drawn in lockstep from the training and
    validation loaders, the model's ``forward_decoder_ahead_duration`` output
    is recorded as a score. Members are labelled 0 and non-members 1, so the
    reported AUROC / ROC describe how well a *higher* score indicates a
    non-member. Rank 0 writes the metrics and raw scores to
    ``./vf_gradtts_{attacker_name}_{dataset}.json`` and a histogram to
    ``./vf_{attacker_name}_{dataset}_histogram.png``.

    Args:
        rank: Index of this process; also used as the CUDA device index.
        n_gpus: World size passed to ``dist.init_process_group``.
        hps: Hyper-parameter object; ``hps.model_dir``, ``hps.data``,
            ``hps.train``, ``hps.model`` and ``hps.xvector`` are read here.
        ckpt: Checkpoint path handed to ``tools.load_checkpoint``.
        attacker_name: Tag used in the output file names.
        dataset: Tag used in the output file names.
        seed: Seed passed to ``seed_everything``.
        batch_size: Batch size of both data loaders. The score of each batch
            is expected to be a single value (the histogram and metrics treat
            one batch as one sample).

    Raises:
        RuntimeError: If either loader yields no batches.
    """
    seed_everything(seed)

    logger_text = tools.get_logger(hps.model_dir)
    logger_text.info(hps)
    # Invariant across batches, so computed once and captured by recon_score.
    out_size = fix_len_compatibility(getattr(hps.data, "cut_segment_seconds", 2) * hps.data.sampling_rate // hps.data.hop_length)

    dist.init_process_group(backend='nccl', init_method='env://', world_size=n_gpus, rank=rank)
    logger = None
    try:
        torch.manual_seed(hps.train.seed + rank)
        torch.cuda.set_device(rank)
        np.random.seed(hps.train.seed + rank)

        device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
        log_dir = hps.model_dir

        if rank == 0:
            print('Initializing logger...')
            logger = SummaryWriter(log_dir=log_dir)

        train_dataset, collate_fn, model = tools.get_correct_class(hps, train=True)
        val_dataset, _, _ = tools.get_correct_class(hps, train=False)
        batch_collate = collate_fn

        train_loader = DataLoader(dataset=train_dataset, shuffle=False, pin_memory=True,
                                  collate_fn=batch_collate, batch_size=batch_size,
                                  num_workers=4)
        val_loader = DataLoader(dataset=val_dataset, shuffle=False, pin_memory=True,
                                collate_fn=batch_collate, batch_size=batch_size,
                                num_workers=4)

        # NOTE(review): the model is never switched to eval(); confirm that the
        # mode left by construction / load_checkpoint is the intended one.
        model = model(**hps.model).to(device)

        tools.load_checkpoint(ckpt, model, None)
        print(f"Loaded checkpoint from {ckpt}")

        use_gt_dur = getattr(hps.train, "use_gt_dur", False)
        if use_gt_dur:
            print("++++++++++++++> Using ground truth duration for attack")

        def recon_score(batch):
            """Move one collated batch to the device and return the model's duration loss."""
            x, x_lengths = batch['text_padded'].to(device), batch['input_lengths'].to(device)
            y, y_lengths = batch['mel_padded'].to(device), batch['output_lengths'].to(device)

            noise = batch['noise_padded']
            if noise is not None:
                noise = noise.to(device)

            if hps.xvector:
                spk = batch['xvector'].to(device)
            else:
                spk = batch['spk_ids'].to(torch.long).to(device)

            dur_loss = model.forward_decoder_ahead_duration(x, x_lengths,
                                                             y, y_lengths,
                                                             noise=noise,
                                                             spk=spk,
                                                             out_size=out_size,
                                                             use_gt_dur=use_gt_dur,
                                                             durs=batch['dur_padded'].to(device) if use_gt_dur else None)

            return dur_loss

        print("attack start...")
        member_dur_losses, nonmember_dur_losses = [], []

        # zip stops at the shorter loader, so that is the real iteration count.
        n_pairs = min(len(train_loader), len(val_loader))
        for member_batch, nonmember_batch in track(zip(train_loader, val_loader), total=n_pairs):
            # detach() drops the autograd graph; only the values are needed, and
            # keeping the graphs alive for every batch would grow memory unboundedly.
            member_dur_losses.append(recon_score(member_batch).detach().unsqueeze(0))
            nonmember_dur_losses.append(recon_score(nonmember_batch).detach().unsqueeze(0))

        if not member_dur_losses:
            raise RuntimeError("No batches were evaluated: the train or validation loader is empty.")

        member = torch.cat(member_dur_losses, dim=0)
        nonmember = torch.cat(nonmember_dur_losses, dim=0)

        all_scores = torch.cat([member, nonmember], dim=0)
        # Members are the negative class (0), non-members the positive class (1).
        all_labels = torch.cat([
            torch.zeros(member.size(0), dtype=torch.long, device=all_scores.device),
            torch.ones(nonmember.size(0), dtype=torch.long, device=all_scores.device)
        ], dim=0)

        # "binary" (lower case) is the documented task name.
        auroc_val = AUROC(task="binary").to(all_scores.device)(all_scores, all_labels).item()

        fpr, tpr, thresholds = ROC(task="binary").to(all_scores.device)(all_scores, all_labels)

        fpr_np = fpr.detach().cpu().numpy()
        tpr_np = tpr.detach().cpu().numpy()
        # Last ROC point whose FPR is <= 1% (falls back to the first point).
        idx = np.searchsorted(fpr_np, 0.01, side="right")
        tpr_at_1fpr = tpr_np[max(idx - 1, 0)]

        results = {
            "count": int(member.size(0)),
            "member_count": len(member_dur_losses),

            "auroc": auroc_val,
            "tpr_at_1fpr": float(tpr_at_1fpr),

            "fpr": fpr.tolist(),
            "tpr": tpr.tolist(),
            "thresholds": thresholds.tolist(),

            "member_dur_losses": member.tolist(),
            "nonmember_dur_losses": nonmember.tolist(),
        }

        # Every rank computes the same report for the same paths; let only
        # rank 0 write so concurrent processes do not clobber each other.
        if rank == 0:
            output_filename = f"./vf_gradtts_{attacker_name}_{dataset}.json"

            with open(output_filename, 'w') as json_file:
                json.dump(results, json_file, indent=4)

            plt.figure(figsize=(8, 6))
            dur_mem = member.flatten().tolist()
            dur_nonmem = nonmember.flatten().tolist()
            plt.hist(dur_mem, bins=50, alpha=0.5, label='Member dataset', color='red')
            plt.hist(dur_nonmem, bins=50, alpha=0.5, label='Hold-out dataset', color='gray')
            plt.xlabel("Duration Loss")
            plt.ylabel("Frequency")
            plt.legend()
            plt.grid(True)
            plt.tight_layout()
            plt.savefig(f"./vf_{attacker_name}_{dataset}_histogram.png")
            plt.close()
    finally:
        # Release the TensorBoard writer and the process group even on failure.
        if logger is not None:
            logger.close()
        dist.destroy_process_group()

    
   
if __name__ == "__main__":
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not torch.cuda.is_available():
        raise RuntimeError("CPU training is not allowed.")
    n_gpus = torch.cuda.device_count()
    print(f'============> using {n_gpus} GPUS')
    # Rendezvous settings read by the 'env://' init method used in run().
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '8001'

    hps, args = tools.get_hparams_decode()
    print(hps.model_dir)
    # Pick the EMA checkpoints when requested, the regular ones otherwise.
    ckpt_pattern = "EMA_grad_*.pt" if args.EMA else "grad_*.pt"
    ckpt = tools.latest_checkpoint_path(hps.model_dir, ckpt_pattern)
    # One worker process per visible GPU; spawn passes the rank as first argument.
    mp.spawn(run, nprocs=n_gpus, args=(n_gpus, hps, ckpt))