import logging
import json
import os
import yaml
import hydra
import torch
from torch.utils.data import DataLoader
from omegaconf import DictConfig, OmegaConf

from lhotse import CutSet, Fbank, FbankConfig
from lhotse.dataset import DynamicBucketingSampler, OnTheFlyFeatures

from auden.auto.auto_model import AutoModel
from auden.data.dataset.audio_caption_dataset import AudioCaptionDataset
from auden.utils.checkpoint import resolve_checkpoint_filename, generate_and_save_averaged_model

from aac_metrics import Evaluate

def get_test_dataloaders(cfg):
    """Build one evaluation DataLoader per test set listed in the test-data config.

    The YAML file at ``cfg.data.test_data_config`` is iterated as a sequence of
    entries, each providing a ``'manifest'`` (passed to ``CutSet.from_file``)
    and a ``'name'``. Every cut set is filtered (see
    ``remove_short_utterance``), wrapped in an ``AudioCaptionDataset`` that
    computes 80-bin Fbank features on the fly, and batched by a non-shuffling
    ``DynamicBucketingSampler`` capped at ``cfg.data.max_duration``.

    Args:
        cfg: Config object exposing ``data.test_data_config``,
            ``data.max_duration`` and ``data.num_workers``.

    Returns:
        A ``(test_names, test_dls)`` tuple of parallel lists: the ``'name'`` of
        each test set and the DataLoader built for it, in config order.
    """
    # NOTE(review): yaml.load with FullLoader accepts more than plain-data
    # YAML. If this config can ever come from an untrusted source, switch to
    # yaml.safe_load.
    with open(cfg.data.test_data_config, 'r', encoding="utf-8") as file:
        test_data_config = yaml.load(file, Loader=yaml.FullLoader)

    def remove_short_utterance(c):
        """Filter predicate: keep cuts of at least 1 second, trimmed to one supervision."""
        if c.duration < 1.0:
            return False
        # Indexing supervisions[0] below would raise IndexError on a cut that
        # has no supervisions; skip such cuts instead of aborting evaluation.
        if not c.supervisions:
            logging.warning(f"Skipping cut {c.id}: it has no supervisions")
            return False
        # Side effect on the cut itself: only the first supervision is kept.
        c.supervisions = [c.supervisions[0]]
        return True

    test_dls = []
    test_names = []
    for test_set in test_data_config:
        logging.info(f"Getting {test_set['manifest']} cuts")
        cutset = CutSet.from_file(test_set['manifest'])
        cutset = cutset.filter(remove_short_utterance)
        test_name = test_set['name']
        testset = AudioCaptionDataset(
            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
            return_cuts=True,
        )
        sampler = DynamicBucketingSampler(
            cutset,
            max_duration=cfg.data.max_duration,
            shuffle=False,
        )
        # batch_size=None: the sampler already yields whole batches, so the
        # DataLoader must not batch again.
        test_dl = DataLoader(
            testset,
            batch_size=None,
            sampler=sampler,
            num_workers=cfg.data.num_workers,
        )
        test_dls.append(test_dl)
        test_names.append(test_name)
    return test_names, test_dls

@hydra.main(version_base=None, config_path="configs", config_name="evaluate")
@torch.no_grad()
def main(cfg: DictConfig):
    """Evaluate a captioning checkpoint on every configured test set.

    Steps:
      1. Build the test dataloaders via ``get_test_dataloaders(cfg)``.
      2. Resolve the checkpoint filename from ``cfg.checkpoint``
         (``pretrained_model`` / ``epoch`` / ``iter`` / ``avg``); when the
         resolved name starts with ``'averaged'``, generate and save the
         averaged model under ``cfg.exp_dir`` first.
      3. Load the model with ``AutoModel.from_pretrained``, move it to the
         first CUDA device when available (CPU otherwise) and put it in eval
         mode.
      4. For each test set, run the model over all batches, collect predicted
         and reference captions, and log BLEU-4, METEOR, ROUGE-L and SPIDEr
         computed by ``aac_metrics.Evaluate``.

    If metric computation fails for a test set, the error is logged and that
    test set is skipped; evaluation continues with the next one.

    Args:
        cfg: Hydra config (``configs/evaluate``). This function reads
            ``cfg.checkpoint`` and ``cfg.exp_dir``; ``cfg`` is also handed to
            ``get_test_dataloaders``.
    """
    logging.info("\n" + OmegaConf.to_yaml(cfg))

    # initialize dataloader
    test_sets, test_dls = get_test_dataloaders(cfg)

    # Initialize model
    # Read the checkpoint selectors once; they are needed both to resolve the
    # filename and (possibly) to build the averaged checkpoint.
    epoch = cfg.checkpoint.get("epoch", 0)
    ckpt_iter = cfg.checkpoint.get("iter", 0)
    avg = cfg.checkpoint.get("avg", 1)
    checkpoint_filename = resolve_checkpoint_filename(
        checkpoint_filename=cfg.checkpoint.get("pretrained_model", None),
        epoch=epoch,
        iter=ckpt_iter,
        avg=avg,
    )
    if checkpoint_filename.startswith('averaged'):
        generate_and_save_averaged_model(
            cfg.exp_dir,
            epoch=epoch,
            iter=ckpt_iter,
            avg=avg,
        )
    model = AutoModel.from_pretrained(
        exp_dir=cfg.exp_dir,
        checkpoint_filename=checkpoint_filename,
    )

    device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)
    model.eval()
    num_param = sum(p.numel() for p in model.parameters())
    logging.info(f"Number of model parameters: {num_param}")

    evaluator = Evaluate(metrics=["bleu_4", "meteor", "rouge_l", "spider"])

    for test_set_name, test_dl in zip(test_sets, test_dls):
        num_cuts = 0
        pred_str_all, ref_str_all = [], []

        for batch_idx, batch in enumerate(test_dl):
            supervisions = batch["supervisions"]
            num_cuts += len(supervisions["cut"])

            feature = batch["inputs"].to(device)
            # at entry, feature is (N, T, C)
            ref_text = supervisions["audio_caption"]
            feature_lens = supervisions["num_frames"].to(device)

            # The second element of the model output is used as the predicted
            # captions; the first element is not needed here.
            _, output_text = model(
                x=feature,
                x_lens=feature_lens,
                text=ref_text,
            )

            pred_str_all.extend(output_text)
            ref_str_all.extend(ref_text)

            # Show the first reference/prediction pair of every batch.
            print("ref_text:", ref_text[0])
            print("output_text:", output_text[0])

            if batch_idx % 20 == 1:
                logging.info(f"Processed {num_cuts} cuts already.")
        # Message text kept as-is; what has been collected at this point is
        # the predicted and reference caption strings.
        logging.info("Finish collecting audio logits")

        assert len(pred_str_all) == len(ref_str_all), "Mismatch in number of predictions and references"

        # NOTE(review): aac_metrics evaluators take multiple references per
        # candidate (a list of lists of strings) -- confirm that the
        # `audio_caption` entries already have that shape.
        try:
            corpus_scores, _ = evaluator(pred_str_all, ref_str_all)
        except Exception as e:
            logging.error(f"Error computing metrics: {e}")
            # Without this, `corpus_scores` below would be undefined (first
            # test set) or still hold the previous test set's scores, which
            # would then be reported under the wrong name.
            continue

        logging.info(f"Evaluation results for {test_set_name}:")
        logging.info(f"BLEU-4: {corpus_scores['bleu_4']}")
        logging.info(f"Meteor: {corpus_scores['meteor']}")
        logging.info(f"ROUGE-L: {corpus_scores['rouge_l']}")
        logging.info(f"SPIDER: {corpus_scores['spider']}")

        logging.info("Done")


# Script entry point. main() is called without arguments because it is wrapped
# by @hydra.main, which builds the config and passes it in as `cfg`.
if __name__ == "__main__":
    main()