import logging
import os
import yaml
from pathlib import Path
from functools import partial
from collections import defaultdict
import hydra
import torch
from torch.utils.data import DataLoader
from omegaconf import DictConfig, OmegaConf

from lhotse import CutSet, WhisperFbank, WhisperFbankConfig
from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
    DynamicBucketingSampler,
    K2SpeechRecognitionDataset,
    PrecomputedFeatures,
    OnTheFlyFeatures
)
from auden.models.model_module import build_model
from auden.utils.icefall_checkpoint import load_checkpoint_from_averaged_model, load_model_params
from auden.utils.text_normalization import text_normalization
from auden.utils.icefall_asr_decode import save_results
from auden.tokenizer.tokenizer_module import build_tokenizer
from auden.models.whisper.decoding import DecodingOptions
from auden.peft.lora.utils import inject_lora_to_model


def language_func(cut, language=None):
    """Tag the first supervision of *cut* with *language* and return the cut.

    Only ``cut.supervisions[0]`` is modified (in place); any further
    supervisions are left untouched. The same ``cut`` object is returned so
    the function can be used with ``CutSet.map``.
    """
    first_supervision = cut.supervisions[0]
    first_supervision.language = language
    return cut

def get_test_dataloaders(cfg):
    """Build one evaluation ``DataLoader`` per entry of the test-data config.

    The YAML file at ``cfg.data.test_data_config`` is expected to hold an
    iterable of mappings, each providing the keys ``name``, ``manifest`` and
    ``language``. For every entry the manifest is loaded as a ``CutSet``, the
    entry's language is attached to each cut via ``language_func``, and the
    cuts are served through a non-shuffling ``DynamicBucketingSampler`` with
    Whisper fbank features computed on the fly on CPU.

    Args:
        cfg: Config object; reads ``cfg.data.test_data_config``,
            ``cfg.data.num_filters``, ``cfg.data.max_duration`` and
            ``cfg.data.num_workers``.

    Returns:
        A ``(test_names, test_dls)`` tuple of two parallel lists: the entry
        names and the matching data loaders, in config order.
    """
    # NOTE(review): FullLoader is not PyYAML's safe loader; consider
    # yaml.safe_load if this file can ever come from an untrusted source.
    with open(cfg.data.test_data_config, 'r') as config_file:
        dataset_specs = yaml.load(config_file, Loader=yaml.FullLoader)

    named_loaders = []
    for spec in dataset_specs:
        name = spec['name']
        logging.info(f"Getting {name} cuts")

        # Load the manifest first, then stamp the configured language on every cut.
        cuts = CutSet.from_file(spec['manifest']).map(
            partial(language_func, language=spec['language'])
        )

        fbank_config = WhisperFbankConfig(
            num_filters=cfg.data.num_filters,
            device="cpu",
        )
        feature_strategy = OnTheFlyFeatures(WhisperFbank(fbank_config))
        dataset = K2SpeechRecognitionDataset(
            input_strategy=feature_strategy,
            return_cuts=True,
        )

        bucketing_sampler = DynamicBucketingSampler(
            cuts,
            max_duration=cfg.data.max_duration,
            shuffle=False,
        )

        # batch_size=None: the sampler already yields whole batches.
        loader = DataLoader(
            dataset,
            batch_size=None,
            sampler=bucketing_sampler,
            num_workers=cfg.data.num_workers,
        )
        named_loaders.append((name, loader))

    test_names = [name for name, _ in named_loaders]
    test_dls = [loader for _, loader in named_loaders]
    return test_names, test_dls


def _normalize_to_words(text):
    """Normalize *text* with this script's scoring settings and split it into words."""
    return text_normalization(
        text,
        case='lower',
        remove_symbols=False,
        remove_diacritics=True,
        space_between_cjk=True,
        simplified_chinese=False,
        remove_fillers=False,
    ).split()


def _fit_to_max_frames(feature, max_frames, pad_short):
    """Clamp the last dimension of *feature* to at most *max_frames*.

    When *pad_short* is true, inputs shorter than *max_frames* are first
    zero-padded on the right up to exactly *max_frames*. Inputs longer than
    *max_frames* are always truncated.
    """
    missing = max_frames - feature.shape[2]
    if pad_short and missing > 0:
        # new_zeros allocates directly with feature's dtype and device.
        padding = feature.new_zeros(feature.shape[0], feature.shape[1], missing)
        feature = torch.cat([feature, padding], 2)
    if feature.shape[2] > max_frames:
        feature = feature[:, :, :max_frames]
    return feature


def _load_model_weights(model, cfg):
    """Load the weights selected by ``cfg.inference`` into *model* in place.

    * ``epoch == -1``: nothing is loaded; the model is used as built.
    * ``use_averaged_model``: reuse ``averaged_epoch{epoch}_avg{avg}.pt`` in
      ``cfg.exp_dir`` if it exists, otherwise average via
      ``load_checkpoint_from_averaged_model`` and save the result there.
    * otherwise: load ``epoch-{epoch}.pt`` from ``cfg.exp_dir``.

    All state-dict loads are non-strict.
    """
    if cfg.inference.epoch == -1:
        logging.info("Decode with vanilla whisper model")
        return

    if cfg.inference.use_averaged_model:
        filename = f'{cfg.exp_dir}/averaged_epoch{cfg.inference.epoch}_avg{cfg.inference.avg}.pt'
        if os.path.exists(filename):
            logging.info(f"Using an already averaged model {filename}")
            model.load_state_dict(
                torch.load(filename, map_location=torch.device('cpu')), strict=False)
        else:
            load_checkpoint_from_averaged_model(model, cfg.inference, cfg.exp_dir, strict=False)
            logging.info(f"Saving averaged model to {filename}")
            torch.save(model.state_dict(), filename)
        return

    filename = f'{cfg.exp_dir}/epoch-{cfg.inference.epoch}.pt'
    logging.info(f"Loading model from {filename}")
    checkpoint = torch.load(filename, map_location="cpu")
    if "model" in checkpoint:
        # load_model_params is handed the path rather than the loaded dict, so
        # drop our copy first to avoid holding the checkpoint twice.
        del checkpoint
        load_model_params(filename, model, strict=False)
    else:
        # No "model" key: treat the file as a bare state dict.
        model.load_state_dict(checkpoint, strict=False)


@hydra.main(version_base=None, config_path="configs", config_name="asr_inference")
@torch.no_grad()
def main(cfg: DictConfig):
    """Decode every configured test set and hand the results to ``save_results``.

    Builds the test data loaders, tokenizer and model from *cfg*, optionally
    injects LoRA (when ``cfg.lora.use_lora`` is set), loads the weights chosen
    by ``cfg.inference``, then runs ``model.decode`` batch by batch. For each
    test set, ``(cut_id, reference_words, hypothesis_words)`` triples are
    collected under the ``"greedy_search"`` key and passed to ``save_results``.

    Args:
        cfg: Hydra config. Besides the sub-configs used by the builders, this
            reads ``cfg.exp_dir``, ``cfg.no_language``,
            ``cfg.inference.remove_input_restriction``,
            ``cfg.inference.log_interval`` and the optional
            ``cfg.num_debug_batch`` (stop after that many batches per set).
    """
    logging.info("\n" + OmegaConf.to_yaml(cfg))

    # initialize dataloader
    test_sets, test_dls = get_test_dataloaders(cfg)

    # initialize tokenizer
    tokenizer = build_tokenizer(cfg.tokenizer)

    # initialize model
    model = build_model(cfg.model)

    # (Optional) add lora to model
    if hasattr(cfg, "lora") and cfg.lora.use_lora:
        import json
        with open(cfg.lora.config) as f:
            lora_config = json.load(f)
        logging.info(f"LoRA configuration: {lora_config}")
        inject_lora_to_model(model, lora_config)

    device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")

    _load_model_weights(model, cfg)

    model.eval()
    model.to(device)
    num_param = sum(p.numel() for p in model.parameters())
    logging.info(f"Number of model parameters: {num_param}")
    # result dir
    os.makedirs(f"{cfg.exp_dir}/greedy_search", exist_ok=True)

    # TODO(jiahong): whisper pos_emb has a max_len of 30s
    max_frames = 3000

    for test_set_name, test_dl in zip(test_sets, test_dls):
        num_cuts = 0
        try:
            num_batches = len(test_dl)
        except TypeError:
            # The loader cannot report a length; only used for progress logs.
            num_batches = "?"
        # decoding result
        results = defaultdict(list)

        for batch_idx, batch in enumerate(test_dl):
            if hasattr(cfg, "num_debug_batch") and batch_idx >= cfg.num_debug_batch:
                break

            feature = batch["inputs"].to(device)  # (N, T, C)
            feature = feature.permute(0, 2, 1)  # (N, C, T)
            feature = _fit_to_max_frames(
                feature,
                max_frames,
                pad_short=not cfg.inference.remove_input_restriction,
            )

            # assume each dset has one language
            batch_language = None if cfg.no_language else batch["supervisions"]["lang"][0]
            decode_options = DecodingOptions(
                language=batch_language,
                without_timestamps=True,
                tokenizer=tokenizer,
            )
            decoded_results = model.decode(feature, decode_options)

            hyps = [_normalize_to_words(result.text) for result in decoded_results]
            texts = [_normalize_to_words(text) for text in batch["supervisions"]["text"]]
            cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]

            assert len(hyps) == len(texts), (
                f"got {len(hyps)} hypotheses for {len(texts)} references"
            )
            this_batch = [
                (cut_id, ref_words, hyp_words)
                for cut_id, hyp_words, ref_words in zip(cut_ids, hyps, texts)
            ]
            results["greedy_search"].extend(this_batch)

            # Log the last reference/hypothesis pair of the batch as a sample;
            # skip when the batch is empty so no stale or undefined pair is used.
            if this_batch:
                _, sample_ref, sample_hyp = this_batch[-1]
                logging.info(' '.join(sample_ref))
                logging.info(' '.join(sample_hyp))

            num_cuts += len(texts)
            if batch_idx % cfg.inference.log_interval == 0:
                batch_str = f"{batch_idx}/{num_batches}"

                logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")

        save_results(cfg, test_set_name, results)

# Script entry point: main() is wrapped by @hydra.main, so it is called
# without arguments here.
if __name__ == "__main__":
    main()