import logging
import yaml
import hydra
import torch
from functools import partial
from collections import defaultdict
from torch.utils.data import DataLoader
from omegaconf import DictConfig, OmegaConf

from lhotse import (
    CutSet,
    Fbank,
    FbankConfig,
    set_audio_duration_mismatch_tolerance
)
from lhotse.dataset import (
    DynamicBucketingSampler,
    OnTheFlyFeatures
)
from auden.auto.auto_model import AutoModel
from auden.auto.auto_tokenizer import AutoTokenizer
from auden.utils.checkpoint import (
    resolve_checkpoint_filename,
    generate_and_save_averaged_model
)
from auden.data.dataset.speech2text_dataset import Speech2TextDataset
from auden.models.tta.decode import (
    greedy_search_batch,
    attention_beam_search,
    save_results,
    save_bleu_results
)
from auden.utils.text_normalization import text_normalization

def remove_long_short_utterance(c, *, min_duration=1.0, max_duration=50.0):
    """Predicate for ``CutSet.filter``: keep cuts whose duration is in range.

    Side effect: the cut is modified in place so that it keeps only its first
    supervision. Later steps (e.g. ``unified_language_code``) then operate on a
    single supervision per cut.

    Args:
        c: a cut exposing ``supervisions`` (a list) and ``duration``.
        min_duration: shortest duration kept, inclusive (default 1.0).
        max_duration: longest duration kept, inclusive (default 50.0).

    Returns:
        True if ``min_duration <= c.duration <= max_duration`` and the cut has
        at least one supervision, otherwise False.
    """
    # A cut without supervisions has no reference text to score against.
    # Drop it here instead of raising IndexError on ``supervisions[0]``.
    if not c.supervisions:
        return False
    c.supervisions = [c.supervisions[0]]
    return min_duration <= c.duration <= max_duration

def unified_language_code(c, lang):
    """Overwrite the language of ``c``'s first supervision with ``lang``.

    Intended for ``CutSet.map``. If ``lang`` is None, the cut is left untouched.
    The cut is modified in place and also returned.
    """
    if lang is None:
        return c
    first_supervision = c.supervisions[0]
    first_supervision.language = lang
    return c

def unified_sampling_rate(c):
    """Return ``c`` at 16000 Hz: unchanged if already there, else ``c.resample(16000)``."""
    if c.sampling_rate == 16000:
        return c
    return c.resample(16000)

def unified_text_normalize(text: str, lang: str = None):
    """Normalize ``text`` with one fixed set of ``text_normalization`` options.

    The options are applied identically to hypotheses and references:
    lower-casing, diacritic removal, symbol preservation, conversion to
    simplified Chinese, and spaces between CJK characters.

    ``lang`` is accepted so call sites can pass the target language, but it is
    currently ignored: every language gets the same normalization.
    """
    options = dict(
        case='lower',
        remove_diacritics=True,
        remove_symbols=False,
        simplified_chinese=True,
        space_between_cjk=True,
    )
    return text_normalization(text, **options)

def get_test_dataloaders(cfg):
    """Build one evaluation DataLoader per entry of the test-data YAML config.

    ``cfg.data.test_data_config`` must point to a YAML file whose top level is
    a list of mappings. Each mapping needs ``manifest`` (a path passed to
    ``CutSet.from_file``) and ``name``, and may set ``lang``.

    For each entry the cuts are processed in these steps:

    - filtered with ``remove_long_short_utterance``;
    - resampled to 16000 Hz;
    - relabelled with ``lang``;
    - wrapped in a ``Speech2TextDataset`` with on-the-fly 80-bin Fbank
      features;
    - batched by a non-shuffling ``DynamicBucketingSampler`` bounded by
      ``cfg.data.max_duration``.

    Args:
        cfg: Hydra config; reads ``cfg.data.test_data_config``,
            ``cfg.data.max_duration`` and ``cfg.data.num_workers``.

    Returns:
        ``(test_names, test_dls)``: parallel lists of test-set names and
        DataLoaders, in config order. Both are empty if the YAML file is empty.
    """
    with open(cfg.data.test_data_config, 'r', encoding='utf-8') as f:
        # The config is plain data (a list of mappings), so safe_load is
        # sufficient. An empty file loads as None; treat it as "no test sets".
        test_data_config = yaml.safe_load(f) or []
    if not test_data_config:
        logging.warning(f"No test sets found in {cfg.data.test_data_config}")

    test_names = []
    test_dls = []
    for test_set in test_data_config:
        logging.info(f"Getting {test_set['manifest']} cuts")
        cutset = CutSet.from_file(test_set['manifest'])
        cutset = cutset.filter(remove_long_short_utterance)
        cutset = cutset.map(unified_sampling_rate)
        # NOTE: `lang` defaults to "zh". Unless an entry explicitly sets
        # `lang: null`, every cut's supervision language is overwritten.
        lang = test_set.get("lang", "zh")
        cutset = cutset.map(partial(unified_language_code, lang=lang))

        testset = Speech2TextDataset(
            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
            return_cuts=True,
        )
        testset.s2t_translate_ratio = 1.0  # enforce AST if possible
        sampler = DynamicBucketingSampler(
            cutset,
            max_duration=cfg.data.max_duration,
            shuffle=False,
        )
        # batch_size=None: the sampler already yields whole batches of cuts.
        test_dl = DataLoader(
            testset,
            batch_size=None,
            sampler=sampler,
            num_workers=cfg.data.num_workers,
        )
        test_dls.append(test_dl)
        test_names.append(test_set['name'])

    return test_names, test_dls

def _task_langs_and_refs(supervisions):
    """Return ``(task, src_lang, tgt_lang, ref_texts)`` for one batch.

    The task and languages are read from the first supervision, so a batch is
    assumed to share one task and language pair. For ``transcribe``, the
    references are ``supervisions["text"]`` and ``tgt_lang == src_lang``. For
    ``translate``, they are ``supervisions["text_translated"]`` and the target
    language is ``supervisions["language_translated"][0]``.

    Raises:
        ValueError: if the task is neither ``transcribe`` nor ``translate``.
            The caller would otherwise silently reuse the previous batch's
            references.
    """
    task = supervisions['task'][0]
    src_lang = supervisions["language"][0]
    if task == 'transcribe':
        return task, src_lang, src_lang, supervisions["text"]
    if task == 'translate':
        return (
            task,
            src_lang,
            supervisions["language_translated"][0],
            supervisions["text_translated"],
        )
    raise ValueError(f"Unsupported task: {task!r}")


def _attention_prompt_tokens(model, task, src_lang, tgt_lang, language_aware):
    """Return ``(language_token, task_token)`` ids for attention beam search.

    Returned ids by case:

    - ``transcribe`` with ``language_aware`` false: ``(None, None)``.
    - ``transcribe`` with ``language_aware`` true: both ids are ``<src_lang>``.
    - ``translate``: the language id is ``<src_lang>``. The task id depends on
      ``model.config.translate_mode``: ``<translate_{tgt_lang}>`` for
      ``'src2tgt'``, or ``<tgt_lang>`` for ``'src2tgt_unified'``.

    Raises:
        ValueError: for any other ``translate_mode``. Without this check, a
            stale task id from a previous batch could be reused.
    """
    if task == 'transcribe' and not language_aware:
        return None, None
    language_token = model.special_to_id[f"<{src_lang}>"]
    if task == 'transcribe':
        return language_token, language_token
    translate_mode = model.config.translate_mode
    if translate_mode == 'src2tgt':
        task_token = model.special_to_id[f"<{task}_{tgt_lang}>"]
    elif translate_mode == 'src2tgt_unified':
        task_token = model.special_to_id[f"<{tgt_lang}>"]
    else:
        raise ValueError(f"Unsupported translate_mode: {translate_mode!r}")
    return language_token, task_token


@hydra.main(version_base=None, config_path="configs", config_name="evaluate")
@torch.no_grad()
def main(cfg: DictConfig):
    """Evaluate a checkpoint on every configured test set and save the results.

    Steps:

    1. Build one DataLoader per test set.
    2. Load the tokenizer and the checkpoint (averaging it first when the
       resolved filename starts with ``'averaged'``).
    3. Move the model to ``cuda:0`` if available, else CPU.
    4. For each batch, run the encoder, then decode with
       ``cfg.inference.decoding_method`` (``attention_beam_search`` or
       ``greedy_search``).
    5. Normalize hypotheses and references with ``unified_text_normalize``.
    6. Per test set, write BLEU results (attention beam search on a
       translation set) or ``save_results`` otherwise.

    Runs with gradient tracking disabled (``torch.no_grad``).
    """
    logging.info("\n" + OmegaConf.to_yaml(cfg))

    # initialize dataloader
    set_audio_duration_mismatch_tolerance(0.1)
    test_sets, test_dls = get_test_dataloaders(cfg)

    tokenizer = AutoTokenizer.from_pretrained(**cfg.tokenizer)

    # Initialize model
    epoch = cfg.checkpoint.get("epoch", 0)
    iters = cfg.checkpoint.get("iter", 0)
    avg = cfg.checkpoint.get("avg", 1)
    checkpoint_filename = resolve_checkpoint_filename(
        checkpoint_filename=cfg.checkpoint.get("pretrained_model", None),
        epoch=epoch,
        iter=iters,
        avg=avg,
    )
    if checkpoint_filename.startswith('averaged'):
        # Presumably writes the averaged weights into exp_dir so that
        # from_pretrained below can load them. TODO confirm.
        generate_and_save_averaged_model(cfg.exp_dir, epoch=epoch, iter=iters, avg=avg)
    model = AutoModel.from_pretrained(
        exp_dir=cfg.exp_dir,
        checkpoint_filename=checkpoint_filename,
    )

    device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)
    model.eval()
    num_param = sum(p.numel() for p in model.parameters())
    logging.info(f"Number of model parameters: {num_param}")

    decoding_method = cfg.inference.decoding_method
    for test_set_name, test_dl in zip(test_sets, test_dls):
        num_cuts = 0
        results = defaultdict(list)
        # Task of the most recent batch of THIS test set. The metric choice
        # below uses it, so each test set is assumed to hold a single task.
        task = None

        for batch_idx, batch in enumerate(test_dl):
            supervisions = batch["supervisions"]
            cut_ids = [cut.id for cut in supervisions["cut"]]
            num_cuts += len(cut_ids)

            # at entry, feature is (N, T, C)
            feature = batch["inputs"].to(device)
            feature_lens = supervisions["num_frames"].to(device)
            output = model.forward_encoder(feature, feature_lens)

            task, src_lang, tgt_lang, texts = _task_langs_and_refs(supervisions)

            if decoding_method == "attention_beam_search":
                language_token, task_token = _attention_prompt_tokens(
                    model, task, src_lang, tgt_lang, cfg.inference.language_aware
                )
                decoding_results = attention_beam_search(
                    model=model,
                    encoder_out=output.encoder_out,
                    encoder_out_lens=output.encoder_out_lens,
                    beam_size=cfg.inference.beam_size,
                    language=language_token,
                    task=task_token,
                    sos=model.attention_decoder.sos_id,
                    eos=model.attention_decoder.eos_id,
                    language_topk=cfg.inference.language_topk,
                )
                hyp_tokens = decoding_results.hyps
            elif decoding_method == "greedy_search":
                hyp_tokens = greedy_search_batch(
                    model=model,
                    encoder_out=output.encoder_out,
                    encoder_out_lens=output.encoder_out_lens,
                    blank_penalty=cfg.inference.blank_penalty,
                )
            else:
                raise ValueError(f"{decoding_method}")

            # Ids >= tokenizer.vocab_size are replaced with unk_id, so the
            # tokenizer only decodes in-vocabulary ids.
            vocab_size = tokenizer.vocab_size
            unk_id = tokenizer.unk_id
            hyp_tokens = [
                [t if t < vocab_size else unk_id for t in tokens]
                for tokens in hyp_tokens
            ]
            hyps = tokenizer.decode(hyp_tokens)

            hyps = [unified_text_normalize(hyp, tgt_lang).split() for hyp in hyps]
            texts = [unified_text_normalize(text, tgt_lang).split() for text in texts]

            assert len(hyps) == len(texts)
            # Each result entry is (cut_id, reference_words, hypothesis_words).
            results[decoding_method].extend(zip(cut_ids, texts, hyps))

            if batch_idx % 20 == 1:
                logging.info(f"Processed {num_cuts} cuts already.")

        if task is None:
            # No batches: nothing to score, and no task to pick a metric from.
            logging.warning(f"Test set {test_set_name} produced no batches; skipping.")
            continue

        if decoding_method == "attention_beam_search" and task == 'translate':
            save_bleu_results(cfg, test_set_name, results)
        else:
            save_results(cfg, test_set_name, results)

        logging.info("Done")


if __name__ == "__main__":
    # main() is called without arguments. The @hydra.main decorator builds
    # `cfg` from configs/evaluate (plus any command-line overrides).
    main()
