import logging
import json
import os
import yaml
import hydra
import torch
from torch.utils.data import DataLoader
from omegaconf import DictConfig, OmegaConf

from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy
from lhotse.dataset import DynamicBucketingSampler, OnTheFlyFeatures
from auden.auto.auto_model import AutoModel
from auden.models.audio_tag.utils import compute_acc
from auden.data.dataset.audio_dataset import AudioDataset
from auden.utils.checkpoint import resolve_checkpoint_filename, generate_and_save_averaged_model
from sklearn.metrics import average_precision_score
from sklearn.metrics import precision_recall_curve
import numpy as np

def get_test_dataloaders(cfg):
    """Build one evaluation DataLoader per entry of the test-data YAML config.

    ``cfg.data.test_data_config`` is the path of a YAML file holding a list of
    entries, each read for a ``manifest`` path and a ``name``.

    For every entry, the code:
    - loads the manifest lazily;
    - resamples it to 16000 Hz;
    - drops cuts shorter than 1.0 second;
    - wraps it in an ``AudioDataset`` that computes 80-bin Fbank features on
      the fly;
    - batches it with a non-shuffling ``DynamicBucketingSampler``.

    Returns:
        A ``(names, dataloaders)`` pair of parallel lists, in config order.
    """
    with open(cfg.data.test_data_config, 'r') as config_file:
        entries = yaml.load(config_file, Loader=yaml.FullLoader)

    def is_long_enough(cut):
        # Keep only cuts lasting at least one second.
        return cut.duration >= 1.0

    names, loaders = [], []
    for entry in entries:
        manifest_path = entry['manifest']
        logging.info(f"Getting {manifest_path} cuts")
        cuts = load_manifest_lazy(manifest_path).resample(16000).filter(is_long_enough)
        set_name = entry['name']

        dataset = AudioDataset(
            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
            return_cuts=True,
        )
        sampler = DynamicBucketingSampler(
            cuts,
            max_duration=cfg.data.max_duration,
            shuffle=False,
            num_buckets=30,
            buffer_size=30 * 2000,
        )
        # batch_size=None: the sampler already yields whole batches of cuts.
        loaders.append(
            DataLoader(
                dataset,
                batch_size=None,
                sampler=sampler,
                num_workers=cfg.data.num_workers,
            )
        )
        names.append(set_name)
    return names, loaders

def load_thresholds(thresholds_file):
    """Load per-class decision thresholds from a JSON file.

    The file must hold a JSON object mapping class indices, written as strings
    (``"0"``, ``"1"``, ...), to numeric thresholds. With ``N`` entries, every
    index ``0..N-1`` must be present.

    Args:
        thresholds_file: Path to the JSON thresholds file.

    Returns:
        list[float]: The ``N`` thresholds ordered by class index.

    Raises:
        FileNotFoundError: If ``thresholds_file`` does not exist.
        ValueError: If some class index in ``0..N-1`` has no threshold.
    """
    # No existence pre-check: open() raises a clear FileNotFoundError,
    # whereas skipping the load would leave `thresholds` undefined below.
    with open(thresholds_file, 'r', encoding='utf-8') as f:
        thresholds = json.load(f)

    threshold_list = []
    for c in range(len(thresholds)):
        key = str(c)
        if key not in thresholds:
            raise ValueError(f"Threshold for class {c} not found in {thresholds_file}")
        threshold_list.append(float(thresholds[key]))
    return threshold_list

# Fallback locations used when the Hydra config does not set the optional
# `thresholds_file` / `output_dir` keys.
_DEFAULT_THRESHOLDS_FILE = "/apdcephfs_cq12/share_302080740/user/raytseng/data/ckpt/best_thresholds_audioset-eval.json"
_DEFAULT_OUTPUT_DIR = "/apdcephfs_cq12/share_302080740/user/raytseng/data/CaptionStew/labels"


@hydra.main(version_base=None, config_path='configs', config_name="evaluate")
@torch.no_grad()
def main(cfg: DictConfig):
    """Tag every test set and write thresholded predictions as TSV files.

    For each test set returned by ``get_test_dataloaders``:
    - the model logits are passed through a sigmoid;
    - each probability is compared (``>=``) against its per-class threshold;
    - the indices of the passing classes are written one cut per line as
      ``<cut_id>\\t<idx>;<idx>;...`` to ``<output_dir>/<name>_predictions.tsv``.

    Optional config keys:
        thresholds_file: JSON thresholds file read by ``load_thresholds``
            (defaults to ``_DEFAULT_THRESHOLDS_FILE``).
        output_dir: Directory for the prediction files; created if missing
            (defaults to ``_DEFAULT_OUTPUT_DIR``).

    Raises:
        ValueError: If the model's class count differs from the number of
            thresholds.
    """
    logging.info("\n" + OmegaConf.to_yaml(cfg))

    # Get dataloaders
    test_sets, test_dls = get_test_dataloaders(cfg)

    # Initialize model
    model = AutoModel.from_pretrained(
        exp_dir=cfg.exp_dir,
        checkpoint_filename=cfg.checkpoint.filename,
    )

    device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)
    model.eval()
    num_param = sum(p.numel() for p in model.parameters())
    logging.info(f"Number of model parameters: {num_param}")

    # Paths are configurable; the defaults preserve the original behaviour.
    thresholds_file = OmegaConf.select(cfg, "thresholds_file", default=_DEFAULT_THRESHOLDS_FILE)
    output_dir = OmegaConf.select(cfg, "output_dir", default=_DEFAULT_OUTPUT_DIR)
    os.makedirs(output_dir, exist_ok=True)

    # Thresholds are always required; one float per class, on the model device.
    thresholds = torch.tensor(load_thresholds(thresholds_file), dtype=torch.float32, device=device)

    for test_set_name, test_dl in zip(test_sets, test_dls):
        num_cuts = 0
        out_path = os.path.join(output_dir, f"{test_set_name}_predictions.tsv")

        with open(out_path, 'w', encoding='utf-8') as f:
            for batch_idx, batch in enumerate(test_dl):
                cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
                num_cuts += len(cut_ids)

                feature = batch["inputs"].to(device)
                feature_lens = batch["supervisions"]["num_frames"].to(device)

                # Presumably [B, C] logits (one row per cut) -- TODO confirm
                # against AutoModel.generate(return_full_logits=True).
                audio_logits = model.generate((feature, feature_lens), return_full_logits=True)
                audio_probs = torch.sigmoid(audio_logits)

                # Fail loudly instead of broadcasting a mismatched thresholds vector.
                if audio_probs.shape[-1] != thresholds.numel():
                    raise ValueError(
                        f"Model produced {audio_probs.shape[-1]} classes but "
                        f"{thresholds.numel()} thresholds were loaded from {thresholds_file}"
                    )

                # Boolean "class present" mask, moved to NumPy for writing.
                audio_labels = (audio_probs >= thresholds).cpu().numpy()

                lines = []
                for i, cut_id in enumerate(cut_ids):
                    active_idx = np.flatnonzero(audio_labels[i])
                    lines.append(f"{cut_id}\t{';'.join(map(str, active_idx))}\n")
                f.writelines(lines)

                if batch_idx % 20 == 1:
                    logging.info(f"Processed {num_cuts} cuts already.")
                    # Echo a sample prediction line; guarded so an empty batch
                    # does not hit unbound loop variables.
                    if lines:
                        print(lines[-1], end="")
            logging.info("Finish collecting audio logits")



if __name__ == "__main__":
    # Script entry point: hydra.main builds `cfg` from the "evaluate" config in
    # the `configs` directory (plus any command-line overrides) and passes it in.
    main()
