

"""Top interface of 3 network JSA training.
    S2P: CTC based speech to phone model.
    G2P: CTC based character to phone model, used to generate proposals.
    P2G: CTC based phone to BPE model.
"""

__all__ = ["AMTrainer", "build_model", "_parser", "main"]

import re
import os
from ..shared.manager import Manager
from ..shared import coreutils
from ..shared import encoder as model_zoo
from ..shared.data import JSASpeechDataset, JSAsortedPadCollateASR
from ..shared.tokenizer import load

import argparse
import Levenshtein
from typing import *
import ctc_align
from ctcdecode import CTCBeamDecoder
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
import torch.distributed as dist
import jiwer
import cat.shared.tokenizer as tokenizer
# NOTE:
#   1/4 subsampling is the default for the Conformer model.
#   For other subsampling ratios, you need to modify the value below.
#   Commonly, a relatively larger value can be used to leave some margin.
SUBSAMPLING = 4

def check_label_len_for_ctc(
    tupled_mat_label: Tuple[torch.FloatTensor, torch.LongTensor]
):
    """Tell whether a (features, label) pair is long enough for CTC/CRF.

    The pair passes when the size of the first dimension of the feature
    tensor, floor-divided by ``SUBSAMPLING``, is strictly greater than the
    size of the first dimension of the label tensor.
    """
    feats = tupled_mat_label[0]
    n_frames_after_subsampling = feats.shape[0] // SUBSAMPLING
    label = tupled_mat_label[1]
    return n_frames_after_subsampling > label.shape[0]


def filter_hook(dataset):
    """Dataset hook that filters out samples too short for CTC/CRF.

    Returns whatever ``dataset.select`` returns when called with
    ``check_label_len_for_ctc`` as the predicate.
    """
    keep_fn = check_label_len_for_ctc
    return dataset.select(keep_fn)


def main_worker(gpu: int, ngpus_per_node: int, args: argparse.Namespace, **mkwargs):
    """Per-process training entry point.

    Seeds the RNGs, binds this process to ``gpu``, joins the process group,
    fills in default ``Manager`` keyword arguments that the caller did not
    provide, then builds the ``Manager`` and runs training.

    Args:
        gpu: local GPU index of this process.
        ngpus_per_node: number of GPUs per node, used to derive the global rank.
        args: parsed command-line arguments; ``gpu`` and ``rank`` are
            overwritten in place.
        **mkwargs: keyword arguments forwarded to ``Manager``.
    """
    coreutils.set_random_seed(args.seed)
    args.gpu = gpu
    # global rank = incoming rank * GPUs per node + local GPU index
    args.rank = args.rank * ngpus_per_node + gpu
    torch.cuda.set_device(args.gpu)

    dist.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
    )

    # Defaults for anything the caller did not override.
    # TODO (Sardar): switch "T_dataset" to a dedicated dataset for P2G training.
    mkwargs.setdefault("T_dataset", JSASpeechDataset)

    # Built lazily so the collate object is only created when it is needed.
    if "collate_fn" not in mkwargs:
        mkwargs["collate_fn"] = JSAsortedPadCollateASR(flatten_target=False)

    mkwargs.setdefault("func_build_model", build_model)
    mkwargs.setdefault("_wds_hook", filter_hook)

    # Only plug in the error-rate evaluation when it was requested and the
    # caller has not supplied an evaluation function of their own.
    wants_error_rate = getattr(args, "eval_error_rate", False)
    if "func_eval" not in mkwargs and wants_error_rate:
        mkwargs["func_eval"] = custom_evaluate

    mkwargs["args"] = args
    manager = Manager(**mkwargs)

    if args.ld is None:
        n_utts = manager.trainloader.dl.dataset.__len__()
        coreutils.distprint(
            f"  total {n_utts} utterances are used in training.", args.gpu
        )

    # training
    manager.run(args)

def unique(x, dim=-1):
    """Deduplicate ``x`` along ``dim`` and locate an occurrence of each value.

    Returns:
        A 3-tuple ``(values, inverse, index)`` where ``values`` and
        ``inverse`` come from ``torch.unique(..., return_inverse=True)`` and
        ``index`` holds, for every unique value, a position in ``x`` at which
        that value occurs.  Positions are scattered in reverse order so the
        earliest occurrence is written last; which write wins for duplicated
        indices depends on ``scatter_``'s behaviour.
    """
    values, inverse = torch.unique(x, sorted=False, return_inverse=True, dim=dim)
    n_items = inverse.size(dim)
    positions = torch.arange(n_items, dtype=inverse.dtype, device=inverse.device)

    # Walk both tensors backwards so the earliest position is scattered last.
    rev_inverse = inverse.flip([dim])
    rev_positions = positions.flip([dim])

    index = rev_inverse.new_empty(values.size(dim))
    index.scatter_(dim, rev_inverse, rev_positions)
    return (values, inverse, index)

def subsetIndex(alist, blist):
    """Return the positions in ``alist`` whose elements also occur in ``blist``.

    Positions are reported in increasing order; duplicates in ``alist`` are
    each reported.
    """
    return [pos for pos, item in enumerate(alist) if item in blist]

class AMTrainer(nn.Module):
    def __init__(
        self,
        s2p_encoder: model_zoo.AbsEncoder,
        phn_decoder: CTCBeamDecoder,
        p2g_encoder: model_zoo.AbsEncoder,
        p2g_decoder: CTCBeamDecoder,
        g2p_encoder: model_zoo.AbsEncoder,
        bpe_tokenizer: tokenizer,
        n_samples: int,
        cache_enabled: bool = False,
        add_supervised: bool = False,
        supervised_trans: str = None

    ):
        """Bundle the three JSA networks, their beam searchers and options.

        Args:
            s2p_encoder: speech-to-phone encoder.
            phn_decoder: beam searcher over the S2P output; stored as
                ``self.phn_searcher`` and used through ``.decode``.
            p2g_encoder: phone-to-BPE encoder.
            p2g_decoder: beam searcher over the P2G output, used through
                ``.decode``.
            g2p_encoder: character-to-phone encoder used as the proposal.
            bpe_tokenizer: argument handed to ``load`` from the shared
                tokenizer module; the loaded object's ``decode`` is used
                when computing WER.
            n_samples: number of proposals drawn per utterance in training.
            cache_enabled: keep the last kept sample of every utterance in
                ``self.zlist`` (uid -> list of ids) across calls.
            add_supervised: overwrite samples of listed utterances with
                their reference phone labels.
            supervised_trans: path of a text file whose first
                whitespace-separated field on each line is an utterance id;
                required when ``add_supervised`` is true.

        Raises:
            ValueError: ``add_supervised`` is set but no path was given.
            FileNotFoundError: the given path is not an existing file.
        """
        super().__init__()

        self.s2p_encoder = s2p_encoder
        self.phn_searcher = phn_decoder
        self.p2g_encoder = p2g_encoder
        self.p2g_decoder = p2g_decoder
        self.g2p_encoder = g2p_encoder
        self.dtype = torch.float32
        self.n_samples = n_samples
        self.bos_id = 0
        self.cache_enabled = cache_enabled
        self.add_supervised = add_supervised
        self.supervised_trans = supervised_trans
        if self.cache_enabled:
            # uid -> last kept phone sequence (plain list of ids)
            self.zlist = {}

        # reduction='none' keeps one loss per utterance; zero_infinity turns
        # infeasible alignments into a loss of exactly 0.
        self.ctc_loss = nn.CTCLoss(reduction='none', zero_infinity=True)
        self.bpe_tokenizer = load(bpe_tokenizer)

        # Always defined so membership tests never hit a missing attribute.
        self.supervised_uid = []
        if self.add_supervised:
            # Explicit checks instead of `assert`, which is stripped under -O
            # and would let a None path fail inside os.path.isfile.
            if supervised_trans is None:
                raise ValueError(
                    "add_supervised=True requires 'supervised_trans' to be set."
                )
            if not os.path.isfile(supervised_trans):
                raise FileNotFoundError(
                    f"Supervised trans file is not found: {supervised_trans}"
                )
            with open(supervised_trans, 'r', encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        # a blank line carries no utterance id
                        continue
                    uid = re.split(r'\t| ', line, maxsplit=1)[0]
                    self.supervised_uid.append(uid)

    def forward(self, x, lx, y, ly, uids, y_pid, ly_pid, y_char, ly_char):
        """Run one JSA sampling/training step, or decode and score in eval mode.

        Args:
            x, lx: inputs and lengths forwarded to ``self.s2p_encoder``.
            y, ly: target ids and lengths used as CTC targets on the P2G
                encoder output and as the WER reference.
            uids: per-utterance ids; used as cache keys and to find
                supervised utterances.
            y_pid, ly_pid: reference phone ids and lengths; used for PER and
                as replacement labels for supervised utterances.
            y_char, ly_char: inputs and lengths forwarded to
                ``self.g2p_encoder``.

        Returns:
            Training mode, a 10-tuple:
                0. summed total loss divided by the accumulated count of
                   valid (non-zero-loss) utterances,
                1-3. G2P, S2P and P2G loss accumulators divided by
                   ``n_samples``,
                4. accept counter divided by ``n_samples * batch_size``,
                5. accumulated PER of the kept samples / ``n_samples``,
                6. accumulated PER of the proposals / ``n_samples``,
                7. NLL estimate from the accumulated ``exp(p_old)``,
                8. summed per-utterance unique-sample count divided by
                   ``n_samples * batch_size``,
                9. number of non-repeated proposal rounds / ``n_samples``.
            Eval mode, a 5-tuple ``(mean S2P loss, mean P2G loss, PER, WER,
            NLL)`` where NLL is estimated from 100 G2P samples.
        """
        # s2p_encoder forward
        logits_s2p_enc, logits_lens_s2p_enc = self.s2p_encoder(x, lx)
        logits_s2p_enc = torch.log_softmax(logits_s2p_enc, dim=-1)

        device = x.device
        batch_size = ly.shape[0]
        ly = ly.to(torch.int)
        acc_NLL = torch.zeros([batch_size], dtype=torch.float64, device=device)
        if self.training:
            acc_cnt = torch.zeros([1], dtype=torch.int32)
            acc_g2p_loss = torch.zeros([1], dtype=torch.float32, device=device)
            acc_s2p_loss = torch.zeros([1], dtype=torch.float32, device=device)
            acc_p2g_loss = torch.zeros([1], dtype=torch.float32, device=device)
            total_loss = torch.zeros([batch_size], dtype=torch.float32, device=device)
            acc_p2g_per = torch.zeros([1], dtype=torch.float32)
            acc_ac_per = torch.zeros([1], dtype=torch.float32)
            acc_inf_lens = torch.zeros([1], dtype=torch.int32, device=device)
            acc_NLL_lens = torch.zeros([batch_size], dtype=torch.int32, device=device)
            # nn.CTCLoss takes time-major log-probs
            logits_s2p_enc = logits_s2p_enc.transpose(0, 1)
            ones = torch.ones([batch_size], dtype=torch.int32, device=device)
            block_uniq = torch.zeros([1], dtype=torch.int32, device=device)

            # get z_old from cache
            if self.cache_enabled:
                z_old = [torch.tensor(self.zlist.get(uid, []),dtype=torch.int32, device=device) for uid in uids]
                zlens_old = torch.tensor([z.shape[0] for z in z_old],dtype=torch.int32, device=device)
                z_old_in_batch, zlens_old_in_batch = self.validate_zlen_and_pad(z_old, zlens_old)
                # weight of the cached sample is computed lazily in the loop
                p_old = None
            else:
                z_old = [torch.tensor([],dtype=torch.int32, device=device) for uid in uids]
                zlens_old = torch.zeros([batch_size], dtype=torch.int32, device=device)
                p_old = torch.zeros([batch_size], dtype=torch.float32, device=device)

            # sampling
            logits_g2p_enc, logits_lens_g2p_enc = self.g2p_encoder(y_char, ly_char)
            logits_g2p_enc = torch.log_softmax(logits_g2p_enc, dim=-1)
            samples, sample_lens = self._sample(logits_g2p_enc.detach().exp(), logits_lens_g2p_enc)
            if self.cache_enabled:
                unique_sentence = torch.ones([batch_size*self.n_samples], dtype=torch.float32, device=device)
            else:
                # count distinct proposals per utterance
                unique_sentence = []
                for i in range(batch_size):
                    line = set()
                    for j in range(self.n_samples):
                        line.add(tuple(samples[i][j][:sample_lens[i][j]].tolist()))
                    unique_sentence.append(len(line))
                unique_sentence = torch.tensor(unique_sentence, dtype=torch.float32, device=device)
            # iterate sample-major: one batch of proposals per loop round
            samples = samples.transpose(0,1).cuda()
            sample_lens = sample_lens.transpose(0,1).cuda()
            logits_g2p_enc = logits_g2p_enc.transpose(0, 1)

            isFirstBeam = True
            # NOTE(review): the loss/PER variables used after the if/else below
            # are only bound inside the non-repeated branch; if the very first
            # round is judged "repeated" they would be unbound -- confirm this
            # cannot happen for the batches in use.
            for bantched_sample, zlens_new in zip(samples, sample_lens):
                z_new = [bantched_sample[batch][:zlens_new[batch]] for batch in range(batch_size)]
                if not self.repeated(z_old, z_new, uids):
                    z_new_in_batch, zlens_new_in_batch = self.validate_zlen_and_pad(z_new, zlens_new)
                    with torch.no_grad():
                        # g2p forward and calculate loss
                        g2p_loss_new = self.ctc_loss(logits_g2p_enc, z_new_in_batch, logits_lens_g2p_enc.to(torch.int).cpu(), zlens_new_in_batch.cpu()) / zlens_new_in_batch

                        # s2p forward and calculate loss
                        s2p_loss_new = self.ctc_loss(logits_s2p_enc, z_new_in_batch, logits_lens_s2p_enc.to(torch.int).cpu(), zlens_new_in_batch.cpu()) / zlens_new_in_batch

                        # p2g forward and calculate loss
                        logits_p2g_enc, logits_lens_p2g_enc = self.p2g_encoder(z_new_in_batch, zlens_new_in_batch)
                        logits_p2g_enc = torch.log_softmax(logits_p2g_enc, dim=-1)
                        p2g_loss_new = self.ctc_loss(logits_p2g_enc.transpose(0, 1).to(torch.float32), y, logits_lens_p2g_enc.to(torch.int).cpu(), ly.cpu()) / ly

                        # calculate MH ratio: length-normalised log weight
                        # (proposal loss minus the two model losses)
                        p_new = g2p_loss_new - s2p_loss_new - p2g_loss_new
                        if isFirstBeam and len(torch.nonzero(zlens_old == 0)) > (batch_size//4):
                            # more than a quarter of the batch has no previous
                            # sample: take the first proposal for everyone
                            z_old = z_new       # List[tensor,tensor,...]
                            zlens_old = zlens_new
                            p_old = p_new  # per-utterance tensor on the model device
                            isFirstBeam = False
                            acc_cnt += torch.tensor(batch_size, dtype=torch.int32)
                        else:
                            if p_old is None:
                                z_old_in_batch, zlens_old_in_batch = self.validate_zlen_and_pad(z_old, zlens_old)
                                g2p_loss_old = self.ctc_loss(logits_g2p_enc, z_old_in_batch, logits_lens_g2p_enc.to(torch.int).cpu(), zlens_old_in_batch.cpu()) / zlens_old_in_batch
                                s2p_loss_old = self.ctc_loss(logits_s2p_enc, z_old_in_batch, logits_lens_s2p_enc.to(torch.int).cpu(), zlens_old_in_batch.cpu()) / zlens_old_in_batch
                                logits_p2g_enc_old, logits_lens_p2g_enc_old = self.p2g_encoder(z_old_in_batch.cuda(), zlens_old_in_batch.cuda())
                                logits_p2g_enc_old = torch.log_softmax(logits_p2g_enc_old, dim=-1)
                                p2g_loss_old = self.ctc_loss(logits_p2g_enc_old.transpose(0, 1).to(torch.float32), y, logits_lens_p2g_enc_old.to(torch.int).cpu(), ly.cpu()) / ly
                                p_old = g2p_loss_old - s2p_loss_old - p2g_loss_old
                            # True where any of the three losses is exactly 0,
                            # i.e. was zeroed by zero_infinity (invalid proposal)
                            inf_new = ~(~(g2p_loss_new == 0) * ~(s2p_loss_new == 0) * ~(p2g_loss_new == 0))
                            accpet_index = self.accept_reject(p_old, p_new, zlens_old, inf_new)
                            # `accpet_index` holds positions, not flags: test the
                            # element count.  `.any()` is False for tensor([0])
                            # and would drop an accept of utterance 0 alone.
                            if accpet_index.numel() > 0:
                                for i in accpet_index:
                                    z_old[i] = z_new[i]
                                zlens_old[accpet_index] = zlens_new[accpet_index]
                                p_old[accpet_index] = p_new[accpet_index]
                                acc_cnt += torch.tensor(accpet_index.shape[0], dtype=torch.int32)
                    if self.add_supervised:
                        z_old, zlens_old = self.replace_supervised(z_old, zlens_old, uids, y_pid, ly_pid)
                    # losses of the kept samples, this time with gradients
                    z_old_in_batch, zlens_old_in_batch = self.validate_zlen_and_pad(z_old, zlens_old)
                    g2p_loss = self.ctc_loss(logits_g2p_enc, z_old_in_batch, logits_lens_g2p_enc.to(torch.int).cpu(), zlens_old_in_batch.cpu()) / zlens_old_in_batch
                    s2p_loss = self.ctc_loss(logits_s2p_enc, z_old_in_batch, logits_lens_s2p_enc.to(torch.int).cpu(), zlens_old_in_batch.to(torch.int).cpu()) / zlens_old_in_batch
                    logits_p2g, logits_lens_p2g = self.p2g_encoder(z_old_in_batch.cuda(), zlens_old_in_batch)
                    logits_p2g = torch.log_softmax(logits_p2g, dim=-1).transpose(0, 1)
                    p2g_loss = self.ctc_loss(logits_p2g, y, logits_lens_p2g.to(torch.int).cpu(), ly.cpu()) / ly
                    # despite the name, True marks utterances whose three
                    # losses are all non-zero (valid)
                    inf_old = ~(g2p_loss == 0) * ~(s2p_loss == 0) * ~(p2g_loss == 0)
                    g2p_loss = g2p_loss * inf_old
                    s2p_loss = s2p_loss * inf_old
                    p2g_loss = p2g_loss * inf_old
                    ac_per, snt_per = self.get_wer("PER", z_old_in_batch, y_pid, zlens_old_in_batch, ly_pid)
                    per_g2p, snt_per_new = self.get_wer("PER", z_new_in_batch, y_pid, zlens_new_in_batch, ly_pid)
                    # here True again means "valid proposal"
                    inf_new = ~(g2p_loss_new == 0) * ~(s2p_loss_new == 0) * ~(p2g_loss_new == 0)
                    per_g2p = torch.mean(snt_per_new[inf_new.cpu()])
                    block_uniq += 1
                else:
                    # proposal identical to the current sample everywhere:
                    # counted as accepted, previous losses are reused below
                    acc_cnt += torch.tensor(batch_size, dtype=torch.int32)


                total_loss += g2p_loss
                total_loss += s2p_loss
                total_loss += p2g_loss
                acc_g2p_loss += torch.mean(g2p_loss.data[inf_old])
                acc_s2p_loss += torch.mean(s2p_loss.data[inf_old])
                acc_p2g_loss += torch.mean(p2g_loss.data[inf_old])
                acc_inf_lens += len(torch.nonzero(inf_old))
                acc_ac_per += torch.mean(snt_per[inf_old.cpu()])
                acc_p2g_per += per_g2p

                inf_NLL = ~torch.isinf(p_old)
                acc_NLL_lens[inf_NLL] += ones[inf_NLL]
                acc_NLL[inf_NLL] += torch.exp(p_old[inf_NLL])

            # store z_old into cache
            if self.cache_enabled:
                for i, uid in enumerate(uids):
                    self.zlist[uid] = z_old[i].tolist()
            NLL = torch.mean(-torch.log(acc_NLL / acc_NLL_lens))
            return (
                torch.sum(total_loss) / acc_inf_lens,
                torch.sum(acc_g2p_loss) / self.n_samples,
                torch.sum(acc_s2p_loss) / self.n_samples,
                torch.mean(acc_p2g_loss) / self.n_samples,
                acc_cnt / (self.n_samples * batch_size),
                acc_ac_per / self.n_samples,
                acc_p2g_per / self.n_samples,
                NLL,
                sum(unique_sentence) / (self.n_samples * batch_size),
                block_uniq / self.n_samples,
            )

        else:
            # ctc beamSearch decoding for testing
            beam_results, _, _, result_lens = self.phn_searcher.decode(logits_s2p_enc, logits_lens_s2p_enc)
            # best beam, beam_results: (batch, beam, T) -> (batch, T)
            beam_results = beam_results.transpose(0,1)[0]
            # result_lens: (batch, beam) -> (batch)
            lens = result_lens.transpose(0,1)[0]
            z_new = [beam_results[batch][:lens[batch]] for batch in range(batch_size)]
            z_new_in_batch, lens_in_batch = self.validate_zlen_and_pad(z_new, lens)
            # calculate s2p and p2g loss
            s2p_loss = self.ctc_loss(logits_s2p_enc.transpose(0, 1), z_new_in_batch, logits_lens_s2p_enc.to(torch.int).cpu(), lens_in_batch.to(torch.int).cpu()) / lens_in_batch.cuda()
            logits_p2g, logits_lens_p2g = self.p2g_encoder(z_new_in_batch.cuda(), lens_in_batch.cuda())
            logits_p2g = torch.log_softmax(logits_p2g, dim=-1)
            p2g_loss = self.ctc_loss(logits_p2g.transpose(0, 1).to(torch.float32), y, logits_lens_p2g.to(torch.int).cpu(), ly.cpu()) / ly

            # calculate PER from s2p and WER from p2g
            per, snt_per = self.get_wer("PER", z_new_in_batch, y_pid, lens_in_batch, ly_pid)
            y_beam_results, _, _, y_result_lens = self.p2g_decoder.decode(logits_p2g, logits_lens_p2g)
            y_new = y_beam_results.transpose(0,1)[0]
            ylens_new = y_result_lens.transpose(0,1)[0]
            wer, snt_wer = self.get_wer("WER", y_new, y, ylens_new, ly)

            # calculate NLL use 100 samples
            logits_g2p_enc, logits_lens_g2p_enc = self.g2p_encoder(y_char, ly_char)
            logits_g2p_enc = torch.log_softmax(logits_g2p_enc, dim=-1)
            samples, sample_lens = self._sample(logits_g2p_enc.detach().exp(), logits_lens_g2p_enc, n_samples=100)
            samples = samples.transpose(0,1).cuda()
            sample_lens = sample_lens.transpose(0,1).cuda()
            logits_g2p_enc = logits_g2p_enc.transpose(0, 1)

            for results, lens in zip(samples, sample_lens):
                z_sample = [results[batch][:lens[batch]] for batch in range(batch_size)]
                z_sample_in_batch, zlens_sample_in_batch = self.validate_zlen_and_pad(z_sample, lens)

                s2p_loss_sample = self.ctc_loss(logits_s2p_enc.transpose(0, 1), z_sample_in_batch, logits_lens_s2p_enc.to(torch.int).cpu(), zlens_sample_in_batch.to(torch.int).cpu()) / zlens_sample_in_batch.cuda()
                logits_p2g, logits_lens_p2g = self.p2g_encoder(z_sample_in_batch.cuda(), zlens_sample_in_batch.cuda())
                logits_p2g = torch.log_softmax(logits_p2g, dim=-1).transpose(0, 1)
                p2g_loss_sample = self.ctc_loss(logits_p2g, y, logits_lens_p2g.cpu(), ly.cpu()) / ly
                g2p_loss_sample = self.ctc_loss(logits_g2p_enc, z_sample_in_batch, logits_lens_g2p_enc.to(torch.int).cpu(), zlens_sample_in_batch.cpu()) / zlens_sample_in_batch

                p = g2p_loss_sample - s2p_loss_sample - p2g_loss_sample
                acc_NLL += torch.exp(p)
            # 100 matches the n_samples passed to _sample above
            NLL = torch.mean(-torch.log(acc_NLL/100))

            return torch.mean(s2p_loss).cuda(), torch.mean(p2g_loss), per, wer, NLL
        
    def _sample(self, probs, lx, n_samples=None):
        """Draw label paths frame by frame from ``probs`` and align them.

        Args:
            probs: probabilities of shape ``(N, T, V)``.
            lx: lengths of shape ``(N,)``, repeated once per drawn path.
            n_samples: number of paths per item; falls back to
                ``self.n_samples`` when falsy.

        Returns:
            ``(ys, ly)`` as produced by ``ctc_align.align_``, reshaped to
            ``(N, K, T)`` and ``(N, K)``.
        """
        batch, frames, vocab = probs.shape
        num_draws = self.n_samples if not n_samples else n_samples
        # (N*T, V) -> (N*T, K) -> (N, T, K)
        per_frame = torch.multinomial(probs.view(-1, vocab), num_draws, replacement=True)
        per_frame = per_frame.view(batch, frames, num_draws)
        # (N, T, K) -> (N, K, T) -> (N*K, T)
        paths = per_frame.transpose(1, 2).contiguous().view(-1, frames)
        # (N, ) -> (N, 1) -> (N, K) -> (N*K, )
        path_lens = lx.unsqueeze(1).repeat(1, num_draws).contiguous().view(-1)
        ys, ly = ctc_align.align_(paths, path_lens)
        return ys.view(batch, num_draws, frames), ly.view(batch, num_draws)

    def accept_reject(self, old, new, zlens_old, inf_new):
        """Decide per utterance whether the new sample replaces the old one.

        A proposal is accepted when ``exp(new) / exp(old)`` exceeds a uniform
        random number in ``[0, 1)``.  Utterances without a previous sample are
        always accepted; proposals flagged in ``inf_new`` are always rejected
        (the rejection is applied last, so it wins over the forced accept).

        Args:
            old, new: per-utterance log weights of the current and proposed
                samples.
            zlens_old: lengths of the current samples; 0 marks "no sample".
            inf_new: boolean mask, True where the proposal is invalid.

        Returns:
            1-D CPU tensor with the positions of the accepted utterances.
        """
        e_new = torch.exp(new).cpu()
        e_old = torch.exp(old).cpu()
        MH_ratio = torch.div(e_new, e_old)
        rand = torch.rand(MH_ratio.size()[0])
        yes_no = MH_ratio > rand
        # `yes_no` lives on the CPU while the masks come from the caller's
        # device; move them explicitly instead of relying on implicit
        # cross-device index conversion.
        no_previous = (zlens_old == 0).to(yes_no.device)
        invalid_new = inf_new.to(yes_no.device)
        # Always accept when cache[uid] is empty
        yes_no[no_previous] = True
        # Always reject when loss is inf
        yes_no[invalid_new] = False
        accpet_index = torch.nonzero(yes_no).squeeze(dim=1)
        return accpet_index
    
    def repeated(self, z_old, z_new, uids):
        """Tell whether every proposal equals the current sample.

        Always False when the cache is enabled.  Utterances listed as
        supervised are skipped when ``add_supervised`` is on.  Otherwise the
        result is True only if, for every remaining utterance, the old and
        new sequences have the same length and identical entries.
        """
        if self.cache_enabled:
            return False
        for pos, candidate in enumerate(z_new):
            if self.add_supervised and uids[pos] in self.supervised_uid:
                continue
            current = z_old[pos]
            if len(current) != len(candidate):
                return False
            # any differing position means the proposal is new
            if bool((current != candidate).any()):
                return False
        return True

    def replace_supervised(self, z, zlens, uids, labels_pid, ly_pid):
        """Overwrite the samples of supervised utterances with their labels.

        For every utterance whose id is in ``self.supervised_uid``, ``z[i]``
        becomes ``labels_pid[i]`` cut to ``ly_pid[i]`` entries and
        ``zlens[i]`` becomes ``ly_pid[i]``.  Both containers are modified in
        place and returned.
        """
        for pos, uid in enumerate(uids):
            if uid not in self.supervised_uid:
                continue
            ref_len = ly_pid[pos]
            z[pos] = labels_pid[pos][:ref_len]
            zlens[pos] = ref_len
        return z, zlens
    
    def validate_zlen_and_pad(self, zlists, zlens):
        """Pad a list of sequences into a batch, patching empty ones.

        Works on copies, so ``zlists`` and ``zlens`` are left untouched.
        Every zero-length entry is replaced by the one-element sequence
        ``[1]`` with length 1 before padding.

        Returns:
            ``(padded, lens)``: the batch-first, zero-padded tensor and the
            patched lengths.
        """
        seqs = zlists.copy()
        seq_lens = zlens.clone()
        empty_mask = seq_lens == 0
        if empty_mask.any():
            # Empty entries occur for utterances that have no sample yet,
            # e.g. ids not present in the cache.
            for pos in torch.nonzero(empty_mask).squeeze(dim=1):
                seq_lens[pos] = 1
                seqs[pos] = torch.ones([1], dtype=torch.int32, device=zlens.device)

        padded = pad_sequence(seqs, batch_first=True, padding_value=0)
        return padded, seq_lens
    
    def clean_unpickable_objs(self):
        """Do nothing; this trainer has no cleanup to perform here."""
        return None

    def get_wer(
        self, type, xs: torch.Tensor, ys: torch.Tensor, lx: torch.Tensor, ly: torch.Tensor
    ):
        """Compute the error rate of hypotheses ``xs`` against references ``ys``.

        Args:
            type: ``"PER"`` compares id sequences with ``cal_wer``;
                ``"WER"`` decodes both sides with ``self.bpe_tokenizer`` and
                scores the text with ``jiwer``.
            xs, lx: padded hypothesis ids and their valid lengths.
            ys, ly: padded reference ids and their valid lengths.

        Returns:
            ``(overall, per_sentence)``: a float16 tensor of shape ``[1]``
            with total errors over total reference tokens, and a float16
            tensor with one rate per sentence.  A reference count of 0 is
            treated as 1 so empty references do not divide by zero.

        Raises:
            TypeError: ``type`` is neither ``"PER"`` nor ``"WER"``.
        """
        if type not in ("PER", "WER"):
            raise TypeError(f"type {type} is illegal!")

        acc_err = 0.
        acc_cnt = 0
        snt_per = []
        for x, xlen, y, ylen in zip(xs, lx, ys, ly):
            # Only the first `len` entries are meaningful: anything past the
            # valid length is padding (or leftover beam-search buffer) and
            # must not be scored in either mode.
            hyp = x[:xlen].tolist()
            ref = y[:ylen].tolist()
            if type == "PER":
                err, cnt = cal_wer([ref], [hyp])
            else:
                hyp_text = [self.bpe_tokenizer.decode(hyp)]
                ref_text = [self.bpe_tokenizer.decode(ref)]
                measure = jiwer.compute_measures(ref_text, hyp_text)
                cnt = measure['hits'] + measure['substitutions'] + measure['deletions']
                err = measure['substitutions'] + measure['deletions'] + measure['insertions']
            acc_err += err
            acc_cnt += cnt
            snt_per.append(err / max(cnt, 1))
        overall = acc_err / max(acc_cnt, 1)
        return torch.tensor([overall], dtype=torch.float16), torch.tensor(snt_per, dtype=torch.float16)
    
def cal_wer(gt: List[List[int]], hy: List[List[int]]) -> Tuple[int, int]:
    """Sum edit distances between paired reference and hypothesis id lists.

    Each id is mapped to the character with that code point so the pair can
    be compared with ``Levenshtein.distance``.

    Returns:
        ``(errors, reference_tokens)`` accumulated over all pairs.
    """
    assert len(gt) == len(hy)
    total_err = 0
    total_ref = 0
    for ref_ids, hyp_ids in zip(gt, hy):
        hyp_str = "".join(map(chr, hyp_ids))
        ref_str = "".join(map(chr, ref_ids))
        total_err += Levenshtein.distance(hyp_str, ref_str)
        total_ref += len(ref_ids)
    return (total_err, total_ref)

@torch.no_grad()
def custom_evaluate(testloader, args: argparse.Namespace, manager: Manager) -> float:
    """Evaluate token error rate over ``testloader`` across all processes.

    Each process accumulates an error count and a token count over its
    batches; the pairs are gathered on rank 0, which computes the global
    rate, logs it under ``loss/dev-token-error-rate`` and broadcasts it.

    Returns:
        The error rate computed on rank 0, identical on every rank.
    """
    model = manager.model
    cnt_tokens = 0
    cnt_err = 0
    n_proc = dist.get_world_size()

    for minibatch in tqdm(
        testloader,
        desc=f"Epoch: {manager.epoch} | eval",
        unit="batch",
        disable=(args.gpu != 0),
        leave=False,
    ):
        # only the first four fields of the batch are used here
        feats, ilens, labels, olens = minibatch[:4]
        feats = feats.cuda(args.gpu, non_blocking=True)
        ilens = ilens.cuda(args.gpu, non_blocking=True)
        labels = labels.cuda(args.gpu, non_blocking=True)
        olens = olens.cuda(args.gpu, non_blocking=True)

        # NOTE(review): AMTrainer.get_wer in this file is declared as
        # get_wer(type, xs, ys, lx, ly) and returns two rate tensors, whereas
        # this call passes four positionals and unpacks an
        # (error count, token count) pair. This looks incompatible when the
        # model comes from this file's build_model -- confirm before using
        # --eval-error-rate.
        part_cnt_err, part_cnt_sum = model.module.get_wer(feats, labels, ilens, olens)
        cnt_err += part_cnt_err
        cnt_tokens += part_cnt_sum

    # collect every rank's (errors, tokens) pair on rank 0
    gather_obj = [None for _ in range(n_proc)]
    dist.gather_object(
        (cnt_err, cnt_tokens), gather_obj if args.rank == 0 else None, dst=0
    )
    dist.barrier()
    if args.rank == 0:
        l_err, l_sum = list(zip(*gather_obj))
        wer = sum(l_err) / sum(l_sum)
        manager.writer.add_scalar("loss/dev-token-error-rate", wer, manager.step)
        scatter_list = [wer]
    else:
        # placeholder that broadcast_object_list overwrites
        scatter_list = [None]

    dist.broadcast_object_list(scatter_list, src=0)
    return scatter_list[0]


def build_beamdecoder(cfg: dict) -> CTCBeamDecoder:
    """Create a ``CTCBeamDecoder`` from a configuration mapping.

    Recognised keys:
        num_classes: vocabulary size (required).
        kenlm: language-model path; its presence switches on token-based
            decoding with stringified ids as labels.
        beam_size: beam width, default 16.
        alpha: default 1.0.
        beta: default 0.0.
        num_processes: default 6.
    """

    assert "num_classes" in cfg, "number of vocab size is required."

    use_lm = "kenlm" in cfg
    vocab_size = cfg["num_classes"]
    if use_lm:
        # ids as strings, with the first two slots renamed
        labels = list(map(str, range(vocab_size)))
        labels[0] = "<s>"
        labels[1] = "<unk>"
    else:
        labels = [""] * vocab_size

    decoder_kwargs = dict(
        labels=labels,
        model_path=cfg.get("kenlm", None),
        beam_width=cfg.get("beam_size", 16),
        alpha=cfg.get("alpha", 1.0),
        beta=cfg.get("beta", 0.0),
        num_processes=cfg.get("num_processes", 6),
        log_probs_input=True,
        is_token_based=use_lm,
    )
    return CTCBeamDecoder(**decoder_kwargs)

def build_model(
    cfg: dict,
    args: Optional[Union[argparse.Namespace, dict]] = None,
    dist: bool = True,
    wrapper: bool = True,
) -> Union[nn.parallel.DistributedDataParallel, AMTrainer, model_zoo.AbsEncoder]:
    """Build the three encoders, the beam searchers and the ``AMTrainer``.

    Configuration keys read from ``cfg``:
        s2p_encoder / g2p_encoder:
            type: class name looked up in the encoder module
            kwargs: constructor arguments; ``n_classes`` is popped and
                passed as ``num_classes``
            freeze: optional, disables gradients for that encoder
            init_model: optional checkpoint path
        p2g_encoder:
            type, kwargs (must contain ``num_classes``), freeze, init_model
        sampler: keyword arguments forwarded to ``AMTrainer``
        beamDecoder:
            type, n_classes, beam_width, log_probs_input, num_processes
        trainer: optional; a ``decoder`` sub-dict is replaced by the
            decoder built from it

    Note that ``cfg`` is modified in place (``trainer`` is created when
    missing, ``n_classes`` is popped from the encoder kwargs).

    Args:
        cfg: the configuration described above.
        args: namespace or dict providing ``gpu``; required when ``dist``.
        dist: when False, return the bare ``AMTrainer``.
        wrapper: not used in this function.

    Returns:
        The ``AMTrainer``, wrapped in ``DistributedDataParallel`` when
        ``dist`` is true.
    """
    if "trainer" not in cfg:
        cfg["trainer"] = {}

    assert "s2p_encoder" in cfg
    s2p_netconfigs = cfg["s2p_encoder"]
    s2p_net_kwargs = s2p_netconfigs["kwargs"]  # type:dict

    # the encoder constructor expects `num_classes`, the config uses `n_classes`
    n_classes = s2p_net_kwargs.pop("n_classes")
    s2p_encoder = getattr(model_zoo, s2p_netconfigs["type"])(
        num_classes = n_classes, **s2p_net_kwargs
    )  # type: model_zoo.AbsEncoder

    # initialize beam searcher
    if "decoder" in cfg["trainer"]:
        cfg["trainer"]["decoder"] = build_beamdecoder(cfg["trainer"]["decoder"])
    
    assert "sampler" in cfg
    sampler_cfg = cfg["sampler"]

    assert "p2g_encoder" in cfg
    p2g_enc_configs = cfg["p2g_encoder"]
    p2g_enc_kwargs = p2g_enc_configs["kwargs"]  # type:dict
    p2g_encoder = getattr(model_zoo, p2g_enc_configs["type"])(**p2g_enc_kwargs)  # type: model_zoo.AbsEncoder

    assert "g2p_encoder" in cfg
    g2p_enc_configs = cfg["g2p_encoder"]
    g2p_enc_kwargs = g2p_enc_configs["kwargs"]  # type:dict
    n_classes = g2p_enc_kwargs.pop("n_classes")
    g2p_encoder = getattr(model_zoo, g2p_enc_configs["type"])(
        num_classes = n_classes, **g2p_enc_kwargs)  # type: model_zoo.AbsEncoder
    
    assert "beamDecoder" in cfg
    beamDecoder_cfg = cfg["beamDecoder"]
    # NOTE(review): `eval` executes the configured "type" string as Python
    # code; only use configuration files from trusted sources.
    phn_searcher = eval(beamDecoder_cfg["type"])(
        [""] * beamDecoder_cfg["n_classes"],
        beam_width=beamDecoder_cfg["beam_width"],
        log_probs_input=beamDecoder_cfg["log_probs_input"],
        num_processes=beamDecoder_cfg["num_processes"]
    )
    # same searcher settings, sized by the P2G output vocabulary
    p2g_decoder = eval(beamDecoder_cfg["type"])(
            [""] * p2g_enc_kwargs["num_classes"],
            beam_width=beamDecoder_cfg["beam_width"],
            log_probs_input=beamDecoder_cfg["log_probs_input"],
            num_processes=beamDecoder_cfg["num_processes"]
        )
    
    if s2p_netconfigs.get("freeze", False):
        s2p_encoder.requires_grad_(False)
    if p2g_enc_configs.get("freeze", False):
        p2g_encoder.requires_grad_(False)
    if g2p_enc_configs.get("freeze", False):
        g2p_encoder.requires_grad_(False)

    model = AMTrainer(s2p_encoder,
                      phn_searcher,
                      p2g_encoder,
                      p2g_decoder,
                      g2p_encoder,
                      **sampler_cfg)
    # model = AMTrainer(am_model, **cfg["trainer"])
    if not dist:
        return model

    assert args is not None, f"You must tell the GPU id to build a DDP model."
    if isinstance(args, argparse.Namespace):
        args = vars(args)
    elif not isinstance(args, dict):
        raise ValueError(f"unsupport type of args: {type(args)}")
    
    # make batchnorm synced across all processes
    model = coreutils.convert_syncBatchNorm(model)

    model.cuda(args["gpu"])
    model = torch.nn.parallel.DistributedDataParallel(model, 
                                                      device_ids=[args["gpu"]],
                                                      find_unused_parameters=True)
                                                      
    # Merge the per-encoder initial checkpoints into one state dict; keys
    # are renamed from "encoder." to the matching AMTrainer attribute.
    init_checkpoint = OrderedDict()
    if "init_model" in s2p_netconfigs:
        coreutils.distprint(f"> initialize s2p_encoder from: {s2p_netconfigs['init_model']}", args["gpu"])
        s2p_enc_checkpoint = torch.load(
            s2p_netconfigs["init_model"], 
            map_location=f"cuda:{args['gpu']}"
        )["model"]  # type: OrderedDict
        s2p_enc_checkpoint = translate_checkpoint(s2p_enc_checkpoint, "encoder", "s2p_encoder")
        init_checkpoint.update(s2p_enc_checkpoint)
        del s2p_enc_checkpoint

    if "init_model" in p2g_enc_configs:
        coreutils.distprint(f"> initialize p2g_encoder from: {p2g_enc_configs['init_model']}", args["gpu"])
        p2g_enc_checkpoint = torch.load(
            p2g_enc_configs["init_model"], 
            map_location=f"cuda:{args['gpu']}"
        )["model"]  # type: OrderedDict
        p2g_enc_checkpoint = translate_checkpoint(p2g_enc_checkpoint, "encoder", "p2g_encoder")
        init_checkpoint.update(p2g_enc_checkpoint)
        del p2g_enc_checkpoint

    if "init_model" in g2p_enc_configs:
        coreutils.distprint(f"> initialize g2p_encoder from: {g2p_enc_configs['init_model']}", args["gpu"])
        g2p_enc_checkpoint = torch.load(
            g2p_enc_configs["init_model"], 
            map_location=f"cuda:{args['gpu']}"
        )["model"]  # type: OrderedDict
        g2p_enc_checkpoint = translate_checkpoint(g2p_enc_checkpoint, "encoder", "g2p_encoder")
        init_checkpoint.update(g2p_enc_checkpoint)
        del g2p_enc_checkpoint

    # NOTE(review): load_state_dict is called with its default arguments on
    # the DDP-wrapped model, so the merged checkpoint presumably has to
    # cover every parameter under the wrapped key names -- verify when only
    # some encoders provide `init_model`.
    if len(init_checkpoint) != 0:
        model.load_state_dict(init_checkpoint)
        del init_checkpoint

    return model

# FIXME: following codes will be removed soon or later
########## COMPATIBLE ###########
# fmt: off
def translate_checkpoint(state_dict: OrderedDict, old_string: str, new_string: str) -> OrderedDict:
    """Select and rename checkpoint entries for loading into another module.

    Only entries whose key contains ``old_string + '.'`` are kept; in each
    kept key the first occurrence of that text is replaced by
    ``new_string + '.'``.  All other entries are dropped.  The input mapping
    is not modified and the original order is preserved.
    """
    old_prefix = f"{old_string}."
    new_prefix = f"{new_string}."
    return OrderedDict(
        (key.replace(old_prefix, new_prefix, 1), value)
        for key, value in state_dict.items()
        if old_prefix in key
    )
# fmt: on
#################################

def _parser():
    """Return the trainer argument parser extended with ``--eval-error-rate``."""
    arg_parser = coreutils.basic_trainer_parser("CTC trainer.")
    help_text = (
        "Use token error rate for evaluation instead of CTC loss (default). "
        "If specified, you should setup 'decoder' in 'trainer' configuration."
    )
    arg_parser.add_argument("--eval-error-rate", action="store_true", help=help_text)
    return arg_parser


def main(args: argparse.Namespace = None):
    """Entry point: parse arguments if none are given, then spawn the workers.

    Args:
        args: pre-parsed arguments; when None they are read from the
            command line with the parser returned by ``_parser``.
    """
    if args is None:
        args = _parser().parse_args()

    coreutils.setup_path(args)
    coreutils.main_spawner(args, main_worker)


if __name__ == "__main__":
    # Running this file directly only prints a usage hint; training is not
    # started from here.
    usage_note = (
        "NOTE:\n"
        "    since we import the build_model() function in cat.ctc,\n"
        "    we should avoid calling `python -m cat.ctc.train`, instead\n"
        "    running `python -m cat.ctc`"
    )
    print(usage_note)
