import logging
import os
import yaml
from pathlib import Path
from collections import defaultdict
import hydra
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from omegaconf import DictConfig, OmegaConf

from lhotse import CutSet, Fbank, FbankConfig
from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
    DynamicBucketingSampler,
    K2SpeechRecognitionDataset,
    OnTheFlyFeatures
)
from auden.auto.auto_model import AutoModel
from auden.utils.text_normalization import text_normalization
from auden.utils.icefall_asr_decode import save_results
from auden.utils.checkpoint import resolve_checkpoint_filename, generate_and_save_averaged_model

def get_test_dataloaders(cfg):
    """Build one evaluation DataLoader per test set listed in the test-data YAML.

    ``cfg.data.test_data_config`` is the path to a YAML file holding a list of
    entries; each entry must provide a ``manifest`` (loaded with
    ``CutSet.from_file``) and a ``name``. Features are computed on the fly as
    80-bin fbank, and cuts are batched by ``cfg.data.max_duration`` without
    shuffling.

    Args:
        cfg: config providing ``data.test_data_config``, ``data.max_duration``
            and ``data.num_workers``.

    Returns:
        ``(test_names, test_dls)``: parallel lists of test-set names and their
        DataLoaders, in the order they appear in the YAML file.

    Raises:
        ValueError: if the YAML file is empty or lists no test sets.
    """
    test_dls = []
    test_names = []
    # Decode explicitly as UTF-8 instead of the platform's locale encoding, so
    # non-ASCII test-set names or manifest paths are read correctly everywhere.
    # NOTE(review): FullLoader is kept for compatibility with existing configs;
    # prefer yaml.safe_load if this file can ever come from an untrusted source.
    with open(cfg.data.test_data_config, 'r', encoding='utf-8') as file:
        test_data_config = yaml.load(file, Loader=yaml.FullLoader)

    # An empty file loads as None, which would otherwise surface as an opaque
    # "NoneType is not iterable"; an empty list would silently evaluate nothing.
    if not test_data_config:
        raise ValueError(
            f"No test sets found in {cfg.data.test_data_config!r}; expected a "
            "non-empty list of entries with 'manifest' and 'name' keys."
        )

    for test_set in test_data_config:
        logging.info(f"Getting {test_set['manifest']} cuts")
        cutset = CutSet.from_file(test_set['manifest'])
        test_name = test_set['name']
        testset = K2SpeechRecognitionDataset(
            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
            return_cuts=True,
        )
        sampler = DynamicBucketingSampler(
            cutset,
            max_duration=cfg.data.max_duration,
            shuffle=False,
        )
        test_dl = DataLoader(
            testset,
            batch_size=None,  # the sampler already yields whole batches of cuts
            sampler=sampler,
            num_workers=cfg.data.num_workers,
        )
        test_dls.append(test_dl)
        test_names.append(test_name)
    return test_names, test_dls


# Normalization settings applied to BOTH hypotheses and references. Keeping them
# in one mapping prevents the two sides from drifting apart by accident.
_COMMON_NORMALIZATION_KWARGS = dict(
    case='lower',
    remove_symbols=True,
    remove_diacritics=True,
    space_between_cjk=True,
    simplified_chinese=True,
    remove_erhua=True,
    remove_fillers=True,
)


def _normalize_words(text, **extra_kwargs):
    """Normalize ``text`` with ``text_normalization`` and split it on whitespace.

    ``extra_kwargs`` are passed to ``text_normalization`` in addition to
    ``_COMMON_NORMALIZATION_KWARGS``.
    """
    return text_normalization(text, **_COMMON_NORMALIZATION_KWARGS, **extra_kwargs).split()


def _results_file_suffix(epoch, iter_, avg):
    """Return the results-file suffix naming the evaluated checkpoint.

    An iteration-based selection (``iter_ > 0``) takes precedence over the epoch.
    """
    if iter_ > 0:
        return f"iter-{iter_}-avg-{avg}"
    return f"epoch-{epoch}-avg-{avg}"


def _greedy_decode_test_set(model, test_dl, device, log_interval=50):
    """Greedy-decode one test set and collect per-cut word sequences.

    Args:
        model: model exposing ``generate(input=..., decoding_method=...)``.
        test_dl: DataLoader whose batches contain ``"inputs"`` and a
            ``"supervisions"`` dict with ``"num_frames"``, ``"text"`` and
            ``"cut"`` entries.
        device: device the features and their lengths are moved to.
        log_interval: log progress every ``log_interval`` batches.

    Returns:
        ``defaultdict(list)`` mapping ``"greedy_search"`` to a list of
        ``(cut_id, ref_words, hyp_words)`` tuples.

    Raises:
        RuntimeError: if a batch yields a different number of hypotheses than
            reference transcripts.
    """
    try:
        num_batches = len(test_dl)
    except TypeError:
        # The sampler may not define a length; only used for progress logging.
        num_batches = "?"

    results = defaultdict(list)
    num_cuts = 0
    for batch_idx, batch in enumerate(test_dl):
        # at entry, feature is (N, T, C)
        feature = batch["inputs"].to(device)
        feature_lens = batch["supervisions"]["num_frames"].to(device)
        hyps = model.generate(input=(feature, feature_lens), decoding_method='greedy_search')

        # NOTE(review): merge_single_char is applied to hypotheses only, not to
        # references; preserved as-is, but confirm this asymmetry is intended
        # since it directly affects the scored word sequences.
        hyp_words = [_normalize_words(hyp, merge_single_char=True) for hyp in hyps]
        ref_words = [_normalize_words(text) for text in batch["supervisions"]["text"]]
        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

        # Explicit check (not assert): it must survive `python -O`, otherwise
        # zip() below would silently drop unmatched entries.
        if len(hyp_words) != len(ref_words):
            raise RuntimeError(
                f"Batch {batch_idx}: got {len(hyp_words)} hypotheses for "
                f"{len(ref_words)} reference transcripts"
            )
        results["greedy_search"].extend(zip(cut_ids, ref_words, hyp_words))

        num_cuts += len(ref_words)
        if batch_idx % log_interval == 0:
            batch_str = f"{batch_idx}/{num_batches}"
            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
    return results


@hydra.main(version_base=None, config_path="configs", config_name="evaluate")
@torch.no_grad()
def main(cfg: DictConfig):
    """Greedy-decode every configured test set and save the results.

    The checkpoint is selected from ``cfg.checkpoint`` (``pretrained_model``,
    ``epoch``, ``iter``, ``avg``); if the resolved filename starts with
    ``'averaged'`` the averaged model is generated first. Results for each test
    set are written under ``<exp_dir>/greedy_search`` via ``save_results``.
    """
    logging.info("\n" + OmegaConf.to_yaml(cfg))

    # initialize dataloader
    test_sets, test_dls = get_test_dataloaders(cfg)

    # Read the checkpoint selectors once, with the same defaults everywhere, so a
    # config that omits e.g. `iter` does not fail later when naming result files.
    epoch = cfg.checkpoint.get("epoch", 0)
    iter_ = cfg.checkpoint.get("iter", 0)
    avg = cfg.checkpoint.get("avg", 1)

    # Initialize model
    checkpoint_filename = resolve_checkpoint_filename(
        checkpoint_filename=cfg.checkpoint.get("pretrained_model", None),
        epoch=epoch,
        iter=iter_,
        avg=avg,
    )
    if checkpoint_filename.startswith('averaged'):
        generate_and_save_averaged_model(cfg.exp_dir, epoch=epoch, iter=iter_, avg=avg)
    model = AutoModel.from_pretrained(
        exp_dir=cfg.exp_dir,
        checkpoint_filename=checkpoint_filename,
    )
    device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)
    model.eval()
    num_param = sum(p.numel() for p in model.parameters())
    logging.info(f"Number of model parameters: {num_param}")

    # result dir
    res_dir = Path(cfg.exp_dir) / 'greedy_search'
    res_dir.mkdir(parents=True, exist_ok=True)
    results_file_suffix = _results_file_suffix(epoch, iter_, avg)

    for test_set_name, test_dl in zip(test_sets, test_dls):
        results = _greedy_decode_test_set(model, test_dl, device)
        save_results(res_dir, test_set_name, results, suffix=results_file_suffix)
    
if __name__ == "__main__":
    # main() is called without arguments: the @hydra.main decorator composes
    # `cfg` from the "evaluate" config under "configs" plus any CLI overrides.
    main()