import logging
import os
import yaml
import torch
import numpy as np
from torch.nn import functional as F
from torch.utils.data import DataLoader
from omegaconf import DictConfig, OmegaConf
import hydra
from pathlib import Path
from tqdm import tqdm

from lhotse import CutSet, Fbank, FbankConfig
from lhotse.dataset import DynamicBucketingSampler, OnTheFlyFeatures, PrecomputedFeatures
from auden.auto.auto_model import AutoModel
from auden.data.dataset.audio_tag_dataset import AudioTaggingDataset
from auden.utils.checkpoint import resolve_checkpoint_filename, generate_and_save_averaged_model

from auden.models.zipformer.utils.padding import make_pad_mask
from auden.models.spk_veri.utils import compute_verification_metrics

def load_verification_pairs(pair_txt_path):
    """
    Load speaker verification pairs from a text file.

    Each non-blank line in the file should contain three whitespace-separated
    fields:
        <label> <utt_path_1> <utt_path_2>

    - <label>: 1 if the two utterances are from the same speaker, 0 otherwise
    - <utt_path_1>: the cut ID or relative path to the first utterance
    - <utt_path_2>: the cut ID or relative path to the second utterance

    Utterance paths are normalized so they can be used as lookup keys:
    every "/" becomes "__" and every occurrence of ".wav" is removed.

    Example:
        1 id10270/x6uYqmx31kE/00001.wav id10270/8jEAjG6SegY/00008.wav
        0 id10270/x6uYqmx31kE/00001.wav id10300/ize_eiCFEg0/00003.wav

    Args:
        pair_txt_path: path to the trial list file.

    Returns:
        List of (label: int, utt1: str, utt2: str)

    Raises:
        ValueError: if a non-blank line does not have exactly three fields,
            or its label is not an integer.
    """
    pairs = []
    with open(pair_txt_path, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(f, start=1):
            fields = line.split()
            if not fields:
                # Blank / whitespace-only lines (e.g. a trailing newline at the
                # end of the file) carry no trial; skip them instead of crashing.
                continue
            if len(fields) != 3:
                raise ValueError(
                    f"{pair_txt_path}:{line_no}: expected 3 fields "
                    f"'<label> <utt1> <utt2>', got {len(fields)}: {line.strip()!r}"
                )
            label, utt1, utt2 = fields
            utt1 = utt1.replace("/", "__").replace(".wav", "")
            utt2 = utt2.replace("/", "__").replace(".wav", "")
            pairs.append((int(label), utt1, utt2))
    return pairs


def get_test_dataloader(cfg):
    """
    Build the evaluation dataloaders described by the test-data YAML config.

    The file at ``cfg.data.test_data_config`` is expected to hold a list of
    entries, each with the keys ``manifest`` (cut manifest path), ``name``
    (test set name) and ``pair_txt`` (trial list path). Several entries may
    point at the same manifest; in that case the very same DataLoader object
    is returned for each of them so callers can detect the sharing.

    Cuts shorter than 1.0 second are filtered out of every manifest, so trial
    pairs that reference such cuts will have no matching cut downstream.

    Args:
        cfg: config object providing ``cfg.data.test_data_config``,
            ``cfg.data.max_duration`` and ``cfg.data.num_workers``.

    Returns:
        Tuple ``(test_names, test_dls, test_pair_txts)`` of three parallel
        lists, one element per config entry. All three are empty when the
        config file contains no entries.
    """
    test_names = []
    test_dls = []
    test_pair_txts = []

    with open(cfg.data.test_data_config, 'r') as file:
        # NOTE(review): FullLoader can still construct more than plain data;
        # if this config may come from an untrusted source, switch to
        # yaml.safe_load.
        test_data_config = yaml.load(file, Loader=yaml.FullLoader)

    if not test_data_config:
        # An empty YAML document loads as None; report it instead of failing
        # with a TypeError when iterating.
        logging.warning(f"No test sets found in {cfg.data.test_data_config}")
        return test_names, test_dls, test_pair_txts

    def remove_short_utterance(cut):
        return cut.duration >= 1.0

    # manifest path -> DataLoader, so a manifest shared by several trial
    # lists is only loaded once.
    dataloader_cache = {}

    for test_set in test_data_config:
        manifest_path = test_set['manifest']
        test_name = test_set['name']
        pair_txt = test_set['pair_txt']

        if manifest_path in dataloader_cache:
            logging.info(f"Reusing cached dataloader for {manifest_path}")
        else:
            logging.info(f"Getting {manifest_path} cuts")
            cutset = CutSet.from_file(manifest_path).filter(remove_short_utterance)

            testset = AudioTaggingDataset(
                # Precomputed features save a lot of time; use
                # OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
                # instead when the manifest has no stored features.
                input_strategy=PrecomputedFeatures(),
                return_cuts=True,
                label_field="speaker",
            )
            sampler = DynamicBucketingSampler(
                cutset,
                max_duration=cfg.data.max_duration,
                shuffle=False,
            )
            # batch_size=None: batching is left to the sampler, not the DataLoader.
            dataloader_cache[manifest_path] = DataLoader(
                testset,
                batch_size=None,
                sampler=sampler,
                num_workers=cfg.data.num_workers,
            )

        test_dls.append(dataloader_cache[manifest_path])
        test_names.append(test_name)
        test_pair_txts.append(pair_txt)

    return test_names, test_dls, test_pair_txts



@torch.no_grad()
def _extract_embeddings(model, test_dl, device):
    """Run the model over ``test_dl`` and return {base cut id: mean-pooled embedding on CPU}."""
    # Read once: optional per-layer fusion weights; models without the
    # attribute fall back to the plain encoder output.
    fusion_param = getattr(model, "encoder_fusion_weights", None)

    embedding_dict = {}
    for batch in tqdm(test_dl):
        cuts = batch["supervisions"]["cut"]
        features = batch["inputs"].to(device)
        feature_lens = batch["supervisions"]["num_frames"].to(device)

        # Same for all ZipformerEncoderModel: forward encoder + pooling
        encoder_output = model.forward_encoder(features, feature_lens)
        if fusion_param is not None:
            # Softmax-normalized weights, broadcast over encoder_out_full and
            # summed over its leading dimension.
            fusion_weights = F.softmax(fusion_param, dim=0).view(-1, 1, 1, 1)
            encoder_out = (encoder_output.encoder_out_full * fusion_weights).sum(dim=0)
        else:
            encoder_out = encoder_output.encoder_out

        # Masked mean over time: zero the padded frames, then divide the sum
        # by the number of valid frames per utterance.
        padding_mask = make_pad_mask(encoder_output.encoder_out_lens)  # (N, T)
        encoder_out[padding_mask] = 0
        valid_counts = (~padding_mask).sum(dim=1, keepdim=True)  # (N, 1)
        # Guard against a zero-length output producing NaN embeddings.
        valid_counts = valid_counts.clamp(min=1)
        pooled_embeddings = encoder_out.sum(dim=1) / valid_counts  # (N, D)

        for cut, emb in zip(cuts, pooled_embeddings):
            # cut.id has suffix -xxx, remove to match the pair ids
            base_id = cut.id.rsplit("-", 1)[0]
            embedding_dict[base_id] = emb.detach().cpu()
    return embedding_dict


def _score_pairs(pairs, embedding_dict):
    """Cosine-score every pair with both embeddings available; return (scores, labels, num_skipped)."""
    scores, labels = [], []
    num_skipped = 0
    for label, utt1, utt2 in pairs:
        if utt1 not in embedding_dict or utt2 not in embedding_dict:
            logging.warning(f"Missing embedding for pair: {utt1}, {utt2}, skipping.")
            num_skipped += 1
            continue
        sim = F.cosine_similarity(embedding_dict[utt1], embedding_dict[utt2], dim=0).item()
        scores.append(sim)
        labels.append(label)
    return scores, labels, num_skipped


@hydra.main(version_base=None, config_path='configs', config_name="evaluate")
@torch.no_grad()
def main(cfg: DictConfig):
    """
    Evaluate a speaker verification model on every configured test set.

    Steps:
      1. Build the test dataloaders (see ``get_test_dataloader``).
      2. Resolve the checkpoint from ``cfg.checkpoint`` (``pretrained_model``,
         ``epoch``, ``iter``, ``avg``) under ``cfg.exp_dir``, generating the
         averaged checkpoint first if it is missing, and load the model.
      3. Extract one mean-pooled embedding per cut. Embeddings are cached per
         dataloader object, so test sets sharing a dataloader are only
         processed once.
      4. Load the trial pairs, score them by cosine similarity and log
         AUC, EER, the EER threshold and minDCF.

    Pairs whose embeddings are missing are skipped; a test set with no
    scorable pair is reported and skipped rather than passed to the metrics.
    """
    logging.info("\n" + OmegaConf.to_yaml(cfg))

    # 1. Load dataloader and cutset
    test_sets, test_dls, test_pair_txts = get_test_dataloader(cfg)

    # 2. Load model
    ckpt_epoch = cfg.checkpoint.get("epoch", 0)
    ckpt_iter = cfg.checkpoint.get("iter", 0)
    ckpt_avg = cfg.checkpoint.get("avg", 1)
    checkpoint_filename = resolve_checkpoint_filename(
        checkpoint_filename=cfg.checkpoint.get("pretrained_model", None),
        epoch=ckpt_epoch,
        iter=ckpt_iter,
        avg=ckpt_avg,
    )

    ckpt_path = Path(cfg.exp_dir) / checkpoint_filename
    if not ckpt_path.exists() and checkpoint_filename.startswith("averaged"):
        # The averaged checkpoint has not been written yet: create it now.
        generate_and_save_averaged_model(
            exp_dir=cfg.exp_dir,
            epoch=ckpt_epoch,
            iter=ckpt_iter,
            avg=ckpt_avg,
        )

    model = AutoModel.from_pretrained(
        exp_dir=cfg.exp_dir,
        checkpoint_filename=checkpoint_filename,
    )

    device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)
    model.eval()
    num_param = sum(p.numel() for p in model.parameters())
    logging.info(f"Number of model parameters: {num_param}")

    # id(dataloader) -> embeddings. The dataloaders stay referenced by
    # test_dls for the whole loop, so their ids cannot be recycled.
    embedding_cache = {}

    for test_set_name, test_dl, pair_txt in zip(test_sets, test_dls, test_pair_txts):
        dl_id = id(test_dl)

        # 3. Generate embeddings
        if dl_id not in embedding_cache:
            print(f"[{test_set_name}] Extracting embeddings...")
            embedding_dict = _extract_embeddings(model, test_dl, device)
            logging.info(f"Extracted embeddings for {len(embedding_dict)} utterances")
            embedding_cache[dl_id] = embedding_dict
        else:
            print(f"[{test_set_name}] Reusing cached embeddings.")
            embedding_dict = embedding_cache[dl_id]

        # 4. Load pairs
        pairs = load_verification_pairs(pair_txt)
        logging.info(f"Loaded {len(pairs)} pairs from {pair_txt}")

        # 5. Compute cosine similarities
        scores, labels, num_skipped = _score_pairs(pairs, embedding_dict)
        if num_skipped:
            # Skipped trials change what the metrics are computed over, so
            # make the total visible in one place.
            logging.warning(
                f"Test set {test_set_name}: skipped {num_skipped} of {len(pairs)} "
                f"pairs with missing embeddings"
            )
        if not scores:
            logging.warning(
                f"Test set {test_set_name}: no scorable pairs, skipping metrics."
            )
            continue

        # 6. Metrics
        auc, eer, threshold, min_dcf = compute_verification_metrics(scores, labels)

        logging.info(
            f"Test set {test_set_name}: AUC: {auc:.4f}, EER: {eer:.4f}, "
            f"Threshold@EER: {threshold:.4f}, minDCF: {min_dcf:.4f}"
        )

# Script entry point. main() is called without arguments; the cfg argument is
# supplied by the @hydra.main decorator applied to it.
if __name__ == "__main__":
    main()
