import json
import os
from collections import OrderedDict
from pathlib import Path

import accelerate
import numpy as np
import torch
import torch.distributed as dist
import wandb
from configs import dict_2_list
from matplotlib import pyplot as plt
from scipy import stats
from torch.nn import functional as F
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
from utils.cosine_scheduler import get_cosine_schedule_with_warmup


class Engine:
    def __init__(self, cfg, mode: str, local_rank: int, world_size: int):
        """Initialise training state, the ELBO layer prior, meters and recorders.

        Args:
            cfg: Experiment config, accessed as ``cfg.train.*``, ``cfg.model.*``,
                ``cfg.dataset`` and ``cfg.workshop``.
            mode: Run mode, stored verbatim as ``self.mode``.
            local_rank: Rank of this process; only rank 0 prints and calls ``wandb.init``.
            world_size: Number of processes, used when gathering data across ranks.

        Raises:
            ValueError: If ``cfg.model.prior_distribution`` is unknown (24-layer ELBO
                mode only) or ``cfg.train.validation_metric`` is neither ``"loss"``
                nor ``"wer"``.
        """
        self.cfg = cfg
        # Index of each of the five top-k decodings -> its k, as a display label.
        self.topk_dict = {0: "1", 1: "3", 2: "5", 3: "12", 4: "24"}
        self.clip_grad = cfg.train.clip_grad
        self.prior_distribution_name = cfg.model.prior_distribution
        self.kl_gamma = cfg.train.kl_gamma
        if self.clip_grad:
            self.clip_grad_value = cfg.train.clip_grad_value
        self.local_rank = local_rank
        self.world_size = world_size
        self.device = self.cfg.train.device
        self.n_layers_to_use_elbo = self.cfg.model.elbo_use_n_last_outputs
        self.early_stopping = torch.zeros(1).to(self.device)
        self.patience = self.cfg.train.patience
        self.epoch = self.cfg.train.epoch
        self.current_epoch = 0
        self.iteration = 0
        self.model_output_rep = self.cfg.model.output_rep
        if self.model_output_rep == "elbo" and self.n_layers_to_use_elbo == 24:
            self.target_index = 4
            # Fixed prior over layers, pre-built with shape (batch_size, encoder_layers).
            if cfg.model.prior_distribution == "uniform":
                self.target_distribution = (
                    torch.ones(
                        cfg.train.batch_size,
                        cfg.model.encoder_layers,
                        device=self.device,
                    )
                    / cfg.model.encoder_layers
                )
            elif cfg.model.prior_distribution == "geometric":
                self.target_distribution = (
                    torch.tensor(
                        self._init_geometric_distribution(p=cfg.model.p_for_geometric_pmf, k=cfg.model.encoder_layers),
                        device=self.device,
                    )
                    .repeat(cfg.train.batch_size)
                    .reshape(cfg.train.batch_size, cfg.model.encoder_layers)
                )
            elif cfg.model.prior_distribution == "chi2":
                self.target_distribution = self._get_chi2_distribution()
            else:
                # NOTE(review): elbo_loss_func has a "learnable" branch that reads
                # self.target_distribution, but "learnable" is rejected here -- confirm intent.
                raise ValueError(f"Unknown prior distribution: {cfg.model.prior_distribution}")
        elif self.model_output_rep == "weighted_sum_ensemble":
            self.target_index = 4
        else:
            # NOTE(review): "elbo" with elbo_use_n_last_outputs != 24 also lands here and
            # leaves self.target_distribution unset -- confirm that combination is unsupported.
            self.target_index = 0
        if cfg.train.validation_metric == "loss":
            self.best_score = 1e6
        elif cfg.train.validation_metric == "wer":
            # Five entries -- presumably one per top-k decoding in topk_dict.
            self.best_score = [1e6] * 5
        else:
            raise ValueError(f"Validation metric is unknown: {cfg.train.validation_metric}")
        if self.local_rank == 0:
            print(f"Set validation metric: {cfg.train.validation_metric}")
        self.wandb_project = self.cfg.train.wandb_project
        self.wandb_train_step_log_interval = self.cfg.train.wandb_train_step_log_interval
        self.wandb_val_epoch_interval = self.cfg.train.wandb_val_epoch_interval
        self.mode = mode

        self.dataloader_factory = asr.utils.dataset.DataloaderFactory(self.cfg.dataset, self.cfg.train)
        self.calculate_score = asr.utils.metric.compute_metrics
        ### prepare meters
        self.loss_meter = asr.utils.avgmeter.AverageMeter(device="cuda")
        self.kl_loss_meter = asr.utils.avgmeter.AverageMeter(device="cuda")
        self.ctc_loss_meter = asr.utils.avgmeter.AverageMeter(device="cuda")

        # One WER/CER meter per top-k decoding (five, matching topk_dict).
        self.wer_meter = [asr.utils.avgmeter.AverageMeter(device="cuda") for i in range(5)]
        self.cer_meter = [asr.utils.avgmeter.AverageMeter(device="cuda") for i in range(5)]
        self.layerwise_ctc_loss_meter = [
            asr.utils.avgmeter.AverageMeter(device="cuda") for i in range(cfg.model.encoder_layers)
        ]
        self.layerwise_probability_weight = [
            asr.utils.avgmeter.AverageMeter(device="cuda") for i in range(cfg.model.encoder_layers)
        ]

        self.predict_token_recoder = [asr.utils.recoder.ArrayRecorder() for i in range(5)]
        self.predict_word_recoder = [asr.utils.recoder.ArrayRecorder() for i in range(5)]
        self.target_word_recoder = asr.utils.recoder.ArrayRecorder()
        self.target_token_recoder = asr.utils.recoder.ArrayRecorder()
        self.predict_layer_distribution = asr.utils.recoder.ArrayRecorder()

        self.log_plot_dir = Path(self.cfg.workshop) / "distribution_predicted"

        if self.local_rank == 0:
            wandb.init(project=self.wandb_project, mode=self.cfg.train.wandb_mode)
            print("Main pid:", os.getpid())

    def _get_chi2_distribution(self):
        lower_bound = stats.ncx2.ppf(0.01, self.cfg.model.chi2_df, self.cfg.model.chi2_nc)
        upper_bound = stats.ncx2.ppf(0.99, self.cfg.model.chi2_df, self.cfg.model.chi2_nc)
        x, step = np.linspace(lower_bound, upper_bound, self.cfg.model.encoder_layers, retstep=True)
        p = stats.ncx2.pdf(x, self.cfg.model.chi2_df, self.cfg.model.chi2_nc)
        probas = p * step
        probas_sum = sum(probas)
        probas[0] -= probas_sum - 1
        probas = list(reversed(probas))
        tensor_probas = (torch.tensor(probas, device=self.device).repeat(self.cfg.train.batch_size)).reshape(
            self.cfg.train.batch_size, self.cfg.model.encoder_layers
        )
        return tensor_probas

    @staticmethod
    def _init_geometric_distribution(p, k):
        pmf = [(1 - p) ** i * p for i in range(24)]
        if sum(pmf) > 1:
            pmf[-1] -= sum(pmf) - 1
        elif sum(pmf) < 1:
            pmf[-1] += 1 - sum(pmf)
        return list(reversed(pmf))

    @staticmethod
    def kl(s_logits, t_logits, reduction="batchmean"):
        distillation_loss = F.kl_div(torch.log_softmax(s_logits, dim=-1), t_logits, reduction=reduction)
        return distillation_loss

    def elbo_loss_func(
        self,
        prediction_log_probas,
        prediction_log_probas_len,
        layer_distribution_logits,
        y,
        y_len,
        gamma=0.1,
    ):
        ctc_loss = []
        bs, n_layers = layer_distribution_logits.shape
        layer_distribution_probas = F.softmax(layer_distribution_logits, dim=-1)
        layerwise_ctc_loss = []
        layerwise_probability_weight = []
        if self.cfg.model.prior_distribution == "learnable":
            target_distribution = F.softmax(self.target_distribution).repeat(bs).reshape(bs, n_layers)
        else:
            target_distribution = self.target_distribution[:bs, :]
        for i, log_probas in enumerate(prediction_log_probas):
            loss = self.loss_func(log_probas.transpose(0, 1), y, prediction_log_probas_len[i], y_len) / y_len
            if self.cfg.train.prior_weight_ctc_loss:
                ctc_loss.append(target_distribution[:, i] * loss)
            else:
                ctc_loss.append(layer_distribution_probas[:, i] * loss)
            layerwise_ctc_loss.append((torch.mean(loss)))
            layerwise_probability_weight.append(torch.mean(layer_distribution_probas[:, i]))
        assert len(prediction_log_probas) == 24, "Must compute CTC per layer"
        ctc_loss = torch.sum(torch.stack(ctc_loss), dim=0)  # sum per layer, average per sample
        assert len(ctc_loss) == bs, f"Expected to sum ce loss per layer before averaging, got: {len(ctc_loss)}"
        ctc_loss = torch.mean(ctc_loss)
        if self.cfg.train.disable_kl_loss:
            kl_loss = None
        else:
            kl_loss = self.kl(layer_distribution_logits, target_distribution)
        if kl_loss:
            total_loss = gamma * kl_loss + ctc_loss
        else:
            total_loss = ctc_loss
        return total_loss, kl_loss, ctc_loss, layerwise_ctc_loss, layerwise_probability_weight

    def ensemble_loss_function(self, log_probs, log_probs_len, y, y_len):
        assert len(log_probs) == 24, "Must compute CTC per layer"
        n_layers = len(log_probs)
        layerwise_ctc_loss = []
        ctc_loss = torch.empty(n_layers)
        layer_distribution_probas = F.softmax(torch.ones(n_layers), dim=0)
        for i, log_prob in enumerate(log_probs):
            loss = self.loss_func(log_prob.transpose(0, 1), y, log_probs_len[i], y_len)
            ctc_loss[i] = layer_distribution_probas[i] * loss
            layerwise_ctc_loss.append(loss)
        ctc_loss = torch.mean(ctc_loss)
        return ctc_loss, layerwise_ctc_loss

    def weighted_hidden_loss_function(
        self, prediction_log_probas, prediction_log_probas_len, layer_distribution_logits, y, y_len
    ):
        loss = self.loss_func(
            prediction_log_probas[0].transpose(0, 1),  # (N, T, C) -> (T, N, C)
            y,
            prediction_log_probas_len[0],
            y_len,
        )
        avg_layer_distribution_probas = F.softmax(layer_distribution_logits, dim=1).mean(0)
        return loss, avg_layer_distribution_probas

    def prepare_staff(self):
        """Build the model, dataloaders, CTC loss, optimizer, scheduler and loggers.

        We move this part out of the __init__ function to avoid the weird error:
        DataLoader worker (pid xxx) is killed by signal: Aborted
        This error is probably caused by a conflict between lmdb and ddp.

        When ``cfg.train.resume`` is set, model/optimizer/scheduler state, the epoch
        counter, iteration and best score are restored from that checkpoint. The
        method ends with a distributed barrier and then writes the config to JSON.
        """

        ### prepare model, optimizer and scheduler
        # Copy the train-time freeze flags into the model config before building the model.
        self.cfg.model.freeze_cnn = self.cfg.train.freeze_cnn
        self.cfg.model.freeze_upstream = self.cfg.train.freeze_upstream
        model = asr.models.wavlm.WavLMFinetuneWrapper(self.cfg).to(self.device)

        ### prepare dataloader
        self.dataloader_train = self.dataloader_factory.build(
            state="train",
            rank=self.local_rank,
            size=self.cfg.train.ds_size,
            decoder=model.decoder,
            vocab_size=self.cfg.dataset.dictionary_len,
        )
        self.dataloader_val = self.dataloader_factory.build(
            state="val",
            rank=self.local_rank,
            size=self.cfg.train.ds_size,
            decoder=model.decoder,
            vocab_size=self.cfg.dataset.dictionary_len,
        )
        # NOTE(review): self.dataloader_test is only created when more than one test
        # folder is configured; with a single folder the attribute is never set.
        if isinstance(self.cfg.dataset.test_folder, list) and len(self.cfg.dataset.test_folder) > 1:
            self.dataloader_test = [
                self.dataloader_factory.build(
                    state="test",
                    rank=self.local_rank,
                    size=self.cfg.train.ds_size,
                    split=test_folder,
                    decoder=model.decoder,
                    vocab_size=self.cfg.dataset.dictionary_len,
                )
                for test_folder in self.cfg.dataset.test_folder
            ]
        self.total_global_steps = len(self.dataloader_train) * self.cfg.train.epoch
        self.blank = self.dataloader_train.dataset.dictionary["<blank>"]
        self.pad_token_index = self.dataloader_train.dataset.dictionary["<pad>"]
        # ELBO weights CTC losses per sample and per layer, so it needs unreduced losses.
        self.loss_func = torch.nn.CTCLoss(
            blank=self.blank, zero_infinity=True, reduction="none" if self.model_output_rep == "elbo" else "mean"
        )

        if self.cfg.train.freeze_cnn:
            for param in model.wavlm.feature_extractor.parameters():
                param.requires_grad = False
        if self.cfg.train.freeze_upstream:
            print("WavLM upstream layers on FREEZE!")
            for param in model.wavlm.parameters():
                param.requires_grad = False

        # Presumably the "layer*" output modes leave some parameters without gradients -- confirm.
        find_unused_parameters = self.cfg.model.output_rep.startswith("layer")
        accelerator = accelerate.Accelerator(mixed_precision="bf16")
        model = accelerator.prepare(model)
        # NOTE(review): accelerator.prepare may already wrap the model for distributed
        # training; confirm the additional DDP wrapper below is intended.
        self.model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[self.local_rank],
            find_unused_parameters=find_unused_parameters,
        )

        # Only parameters still trainable after the freezing above are optimised.
        self.optimizer = torch.optim.AdamW(
            params=filter(lambda x: x.requires_grad, self.model.parameters()),
            lr=self.cfg.train.lr,
            weight_decay=self.cfg.train.weight_decay,
        )

        if self.local_rank == 0:
            # NOTE(review): AdamW is always used regardless of cfg.train.optimizer.
            print(f"Optimizer: {self.cfg.train.optimizer}")

        # CosineAnnealingLR with Warm-up
        self.scheduler = get_cosine_schedule_with_warmup(
            optimizer=self.optimizer,
            num_warmup_steps=self.cfg.train.warmup_steps,
            num_training_steps=self.total_global_steps,
        )

        if self.cfg.train.resume is not None:
            ckpt = torch.load(self.cfg.train.resume, map_location=self.device)
            self.model.module.load_state_dict(ckpt["model"])
            self.optimizer.load_state_dict(ckpt["optimizer"])
            self.scheduler.load_state_dict(ckpt["scheduler"])
            # Continue with the epoch after the one stored in the checkpoint.
            self.current_epoch = ckpt["epoch"] + 1
            self.iteration = ckpt["iteration"]
            self.best_score = ckpt["best_score"]
            if self.local_rank == 0:
                print(f"Resuming from {self.cfg.train.resume}")
                print(f"Best score received is {self.best_score}")
            del ckpt

        ### prepare logger
        if self.local_rank == 0:
            self.logger_train = asr.utils.logger.create_logger(self.cfg.workshop, name="train")
            self.logger_train.info(f"workshop: {self.cfg.workshop}")
            self.logger_train.info(f"seed: {self.cfg.train.seed}")
            self.logger_train.info(f"pid: {os.getpid()}")

            self.logger_val = asr.utils.logger.create_logger(self.cfg.workshop, name="val")
            self.logger_val.info(f"workshop: {self.cfg.workshop}")
            self.logger_val.info(f"seed: {self.cfg.train.seed}")
            self.logger_val.info(f"pid: {os.getpid()}")
        else:
            self.logger_train = None
            self.logger_val = None
        torch.distributed.barrier()
        self.config_2_json()

    def config_2_json(self, jsonfile=None):
        self.jsonfile = os.path.join(self.cfg.workshop, "config.json") if jsonfile is None else jsonfile
        with open(self.jsonfile, "w") as f:
            json.dump(dict(self.cfg), f, indent=2)

    def json_2_config(self, jsonfile=None):
        """Merge a JSON config file back into ``self.cfg``.

        Args:
            jsonfile: Path to read. If omitted, the path remembered in
                ``self.jsonfile`` (e.g. by ``config_2_json``) is used.
        """
        if jsonfile is not None:
            self.jsonfile = jsonfile
        assert hasattr(self, "jsonfile"), "Please provide the .json file first."
        with open(self.jsonfile, "r") as in_fp:
            loaded_dict = json.load(in_fp)
            as_list = dict_2_list(loaded_dict)
            self.cfg.merge_from_list(as_list)

    def reset_meters(self):
        self.loss_meter.reset()
        self.kl_loss_meter.reset()
        self.ctc_loss_meter.reset()
        for meter in self.wer_meter:
            meter.reset()
        for meter in self.cer_meter:
            meter.reset()
        for meter in self.layerwise_ctc_loss_meter:
            meter.reset()
        for meter in self.layerwise_probability_weight:
            meter.reset()

    def reset_recoders(self):
        for token_recoder, word_recorder in zip(self.predict_token_recoder, self.predict_word_recoder):
            token_recoder.reset()
            word_recorder.reset()
        self.target_word_recoder.reset()
        self.target_token_recoder.reset()
        self.predict_layer_distribution.reset()
        # self.filenames_recoder.reset()
        # self.waveform_recoder.reset()

    def gather_distributed_data(self, gather_data, mode="prediction"):
        """Collect ``gather_data`` from all ``self.world_size`` ranks.

        * Tensors: ``dist.all_gather`` into same-shaped buffers, then concatenated
          along dim 0.
        * Other objects: gathered with ``all_gather_object`` (torch's, or the
          project fallback). In ``"prediction"`` mode each rank's data is a list of
          batches and is flattened one extra level; in any other mode each rank's
          list is concatenated as-is.
        * If ``gather_data[0]`` is None, no collective is issued and None is returned.
          NOTE(review): if ranks disagree on this, the others will block in the
          collective -- confirm all ranks always take the same branch.
        """
        if isinstance(gather_data, torch.Tensor):
            rank_tensors = [torch.zeros_like(gather_data) for _ in range(self.world_size)]
            dist.all_gather(rank_tensors, gather_data, async_op=False)
            return torch.cat(rank_tensors)

        if gather_data[0] is None:
            return None

        per_rank = [None] * self.world_size
        if hasattr(dist, "all_gather_object"):
            dist.all_gather_object(per_rank, gather_data)
        else:
            asr.utils.distributed.all_gather_object(per_rank, gather_data, self.world_size)

        merged = []
        for rank_data in per_rank:
            if mode == "prediction":
                for batch_data in rank_data:
                    merged.extend(batch_data)
            else:
                merged.extend(rank_data)
        return merged

    def _finetune_step(self, data):
        torch.cuda.empty_cache()
        waveform = [torch.from_numpy(i) for i in data[0]]
        waveform = pad_sequence(
            waveform,
            batch_first=True,
            padding_value=self.pad_token_index,
        ).to(self.device)
        batch_size = len(waveform)
        padding_mask = torch.full((batch_size, max(data[1])), fill_value=False, dtype=torch.bool).to(self.device)
        padding_mask = self._get_padding_mask(data[1], padding_mask).to(self.device)
        y = [torch.IntTensor(label) for label in data[2]]
        filenames = [file_name for file_name in data[3]]
        # print(f"Rank: {self.local_rank}, {filenames}")
        y_len = torch.IntTensor([len(label) for label in y]).to(self.device)
        y = pad_sequence(
            y,
            batch_first=True,
            padding_value=self.pad_token_index,
        ).to(self.device)
        self.optimizer.zero_grad()
        # if self.iteration > 5:
        #     with torch.profiler.profile(activities=[
        #         torch.profiler.ProfilerActivity.CPU,
        #         torch.profiler.ProfilerActivity.CUDA,
        #     ]) as p:
        #         (
        #             log_probs,
        #             log_probs_len,
        #             pred_tokens_batch,
        #             pred_words_batch,
        #             target_tokens_batch,
        #             target_words_batch,
        #             layer_distribution_logits,
        #         ) = self.model(waveform, padding_mask, y, inference=False)
        #     print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
        #     assert False
        # else:
        (
            log_probs,
            log_probs_len,
            pred_tokens_batch,
            pred_words_batch,
            target_tokens_batch,
            target_words_batch,
            layer_distribution_logits,
        ) = self.model(waveform, padding_mask, y, inference=False)
        # start_time = time.time()
        if self.cfg.model.output_rep not in ["elbo", "weighted_sum_ensemble", "weighted_hiddens"]:
            loss = self.loss_func(
                log_probs[0].transpose(0, 1),  # (N, T, C) -> (T, N, C)
                y,
                log_probs_len[0],
                y_len,
            )
            kl_loss, ctc_loss, layerwise_ctc_loss, layerwise_probability_weight = None, None, None, None
        elif self.cfg.model.output_rep == "weighted_sum_ensemble":
            loss, layerwise_ctc_loss = self.ensemble_loss_function(log_probs, log_probs_len, y, y_len)
            kl_loss, ctc_loss, layerwise_probability_weight = None, None, None
        elif self.cfg.model.output_rep == "elbo":
            loss, kl_loss, ctc_loss, layerwise_ctc_loss, layerwise_probability_weight = self.elbo_loss_func(
                log_probs,
                log_probs_len,
                layer_distribution_logits,
                y,
                y_len,
                gamma=self.kl_gamma,
            )
        elif self.cfg.model.output_rep == "weighted_hiddens":
            loss, layerwise_probability_weight = self.weighted_hidden_loss_function(
                log_probs, log_probs_len, layer_distribution_logits, y, y_len
            )
            kl_loss, ctc_loss, layerwise_ctc_loss = None, None, None
        # torch.cuda.synchronize()
        # end_time = time.time()
        # shuffle_time = end_time - start_time
        # print(f"Loss computation time: {shuffle_time}")
        loss = loss / self.cfg.train.accumulate_each_n_steps
        loss.backward()
        if self.clip_grad:
            grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_grad_value)
        else:
            grad_norm = None

        if self.iteration % self.cfg.train.accumulate_each_n_steps == 0:
            if self.clip_grad:
                grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_grad_value)
            else:
                grad_norm = None

            self.optimizer.step()
            self.optimizer.zero_grad()
            self.scheduler.step()
        # print(f"TT: {target_tokens_batch}")  # list of n strings (n = bs)
        # print(f"TW: {target_words_batch}")  # list of n lists of m strings (n = bs) (m = target words number)
        # print(f"PT: {pred_tokens_batch}")  # [list of n strings (n = bs)]
        # print(f"PW: {pred_words_batch}")  # [list of n lists of m strings (n = bs) (m = target words number)]

        # if layer_distribution_logits is None:
        #     top_1_prediction = (pred_tokens_batch, pred_words_batch)
        #     self.predict_token_recoder[0].record(top_1_prediction[0])
        #     self.predict_word_recoder[0].record(top_1_prediction[1])
        # else:
        #     pred = (log_probs, log_probs_len)
        #     # TOP 1 Prediction return: pred_tokens_batch, pred_words_batch
        #     top_1_prediction = self.collect_top_k_predictions(layer_distribution_logits, 1, batch_size, pred)
        #     # TOP 3 Prediction
        #     top_3_prediction = self.collect_top_k_predictions(layer_distribution_logits, 3, batch_size, pred)
        #     # TOP 5 Prediction
        #     top_5_prediction = self.collect_top_k_predictions(layer_distribution_logits, 5, batch_size, pred)
        #     # TOP 12 Prediction
        #     top_12_prediction = self.collect_top_k_predictions(layer_distribution_logits, 12, batch_size, pred)
        #     # TOP 24 Prediction
        #     top_24_prediction = self.collect_top_k_predictions(layer_distribution_logits, 24, batch_size, pred)
        #     for i, pred in enumerate(
        #         [
        #             top_1_prediction,
        #             top_3_prediction,
        #             top_5_prediction,
        #             top_12_prediction,
        #             top_24_prediction,
        #         ]
        #     ):
        #         assert pred is not None, f"Expected to get output got prediction: {pred} for iteration: {i}"
        #         self.predict_token_recoder[i].record(pred[0])
        #         self.predict_word_recoder[i].record(pred[1])
        #
        # self.target_word_recoder.record(target_words_batch)
        # self.target_token_recoder.record(target_tokens_batch)

        self.loss_meter.update(loss.item())
        if kl_loss:
            self.kl_loss_meter.update(kl_loss.item())
        if ctc_loss:
            self.ctc_loss_meter.update(ctc_loss.item())
        current_lr = self.optimizer.param_groups[0]["lr"]
        log_dict = {}
        if self.iteration % self.wandb_train_step_log_interval == 0 and self.local_rank == 0:
            log_dict["train_step/total_loss"] = loss.item()
            log_dict["LR"] = current_lr
            if self.model_output_rep == "elbo":
                if kl_loss:
                    log_dict["train_step/kl_loss"] = kl_loss.item()
                log_dict["train_step/ctc_loss"] = ctc_loss.item()
                for i in range(len(layerwise_ctc_loss)):
                    log_dict[f"train_step/{i}_layer_ctc_loss"] = layerwise_ctc_loss[i]
            elif self.model_output_rep == "weighted_sum_ensemble":
                for i in range(len(layerwise_ctc_loss)):
                    log_dict[f"train_step/{i}_layer_ctc_loss"] = layerwise_ctc_loss[i]
            if grad_norm:
                log_dict["grad_norm"] = grad_norm
            if layerwise_probability_weight is not None:
                for i in range(len(layerwise_probability_weight)):
                    log_dict[f"train_step/{i}_layer_probability_weight"] = layerwise_probability_weight[i]
            wandb.log(log_dict, step=self.iteration)

        pbar_train_dic = OrderedDict()
        pbar_train_dic["iter"] = self.iteration
        pbar_train_dic["lr"] = current_lr
        pbar_train_dic["loss"] = f"{self.loss_meter.avg:.5f}"
        if self.model_output_rep == "elbo":
            if kl_loss:
                pbar_train_dic["kl_loss"] = f"{self.kl_loss_meter.avg:.5f}"
            pbar_train_dic["ctc_loss"] = f"{self.ctc_loss_meter.avg:.5f}"
        assert self.loss_meter.avg is not None, "Loss turned to be None"
        return pbar_train_dic

    def collect_topk_ensemble(self, k, pred, layer_distribution_logits):
        log_probs, log_probs_len = pred  # log_probs: [n_layers, bs, seq_len, hid dim], log_probs_len: [n_layers, bs]
        batch_size, seq_len, hid_dim = log_probs[0].shape
        top_k_probas = F.softmax(torch.ones(k), dim=0)
        top_k_proba_predictions = torch.zeros((batch_size, seq_len, hid_dim)).cuda()
        for j in range(k):
            top_k_proba_predictions += log_probs[j] * top_k_probas[j]
        pred_tokens_batch, pred_words_batch = self.model.module.decode(
            top_k_proba_predictions.float().contiguous().cpu(), log_probs_len[0]
        )
        return pred_tokens_batch, pred_words_batch

    def collect_top_k_predictions(self, k, pred, layer_distribution_logits):
        log_probs, log_probs_len = pred  # log_probs: [n_layers, bs, seq_len, hid dim], log_probs_len: [n_layers, bs]
        batch_size, seq_len, hid_dim = log_probs[0].shape
        best_topk_indices = torch.topk(input=layer_distribution_logits, dim=1, k=k).indices
        best_topk_values = F.softmax(torch.topk(input=layer_distribution_logits, dim=1, k=k).values, dim=1)
        top_k_proba_predictions = torch.zeros((batch_size, seq_len, hid_dim)).cuda()
        for i in range(batch_size):
            for j in range(k):
                layer = best_topk_indices[i, j]
                weight = best_topk_values[i, j]
                top_k_proba_predictions[i, :, :] += log_probs[layer][i] * weight
        pred_tokens_batch, pred_words_batch = self.model.module.decode(
            top_k_proba_predictions.float().contiguous().cpu(), log_probs_len[0]
        )
        return pred_tokens_batch, pred_words_batch

    @staticmethod
    def _get_padding_mask(lens, padding_mask):
        for i, length in enumerate(lens):
            pad_length = padding_mask.shape[0] - length
            padding_mask[i, -pad_length:] = True
        return padding_mask

    @torch.no_grad()
    def _finetune_val_step(self, data, ith_layer_inference):
        torch.cuda.empty_cache()
        original_waveform = [torch.from_numpy(i) for i in data[0]]
        waveform = pad_sequence(
            original_waveform,
            batch_first=True,
            padding_value=self.pad_token_index,
        ).to(self.device)
        batch_size = len(waveform)
        padding_mask = torch.full((batch_size, max(data[1])), fill_value=False, dtype=torch.bool).to(self.device)
        padding_mask = self._get_padding_mask(data[1], padding_mask).to(self.device)
        y = [torch.IntTensor(label) for label in data[2]]
        y_len = torch.IntTensor([len(label) for label in y]).to(self.device)
        # filenames = data[3]
        y = pad_sequence(
            y,
            batch_first=True,
            padding_value=self.pad_token_index,
        ).to(self.device)
        (
            log_probs,
            log_probs_len,
            pred_tokens_batch,
            pred_words_batch,
            target_tokens_batch,
            target_words_batch,
            layer_distribution_logits,
        ) = (
            self.model(waveform, padding_mask, y, inference=True)
            if (ith_layer_inference is None)
            else (self.model.module.ith_layer_inference(waveform, ith_layer_inference, padding_mask, y))
        )
        # print(f"pred_words_batch: {pred_words_batch}")
        # print(f"target_tokens_batch: {pred_tokens_batch}")
        if self.model_output_rep in ["elbo", "weighted_sum_ensemble"]:
            top_k_func = (
                self.collect_top_k_predictions if self.model_output_rep == "elbo" else self.collect_topk_ensemble
            )
            # print("Log probs", len(log_probs))  # 24
            # print("Log probs len", len(log_probs_len))  # 24
            #
            # print("Log probs [0] shape", log_probs[0].shape)  # bs, seq_len, 32 (hid dim)
            # print("Log probs len [0] shape", log_probs_len[0].shape)  # bs
            #
            # print("Log probs [0]", log_probs[0])  # 3
            # print("Log probs len [0]", log_probs_len[0])  # 1
            pred = (log_probs, log_probs_len)
            # TOP 1 Prediction return: pred_tokens_batch, pred_words_batch
            top_1_prediction = top_k_func(k=1, pred=pred, layer_distribution_logits=layer_distribution_logits)
            # print(f"top_1_prediction[0][0] len: {len(top_1_prediction[0][0])}")

            # TOP 3 Prediction
            top_3_prediction = top_k_func(k=3, pred=pred, layer_distribution_logits=layer_distribution_logits)

            # print(f"top_3_prediction len: {len(top_3_prediction)}")
            # print(f"top_3_prediction[0] len: {len(top_3_prediction[0])}")
            # print(f"top_3_prediction[0][0] len: {len(top_3_prediction[0][0])}")
            # TOP 5 Prediction
            top_5_prediction = top_k_func(k=5, pred=pred, layer_distribution_logits=layer_distribution_logits)
            # TOP 12 Prediction
            top_12_prediction = top_k_func(k=12, pred=pred, layer_distribution_logits=layer_distribution_logits)
            # TOP 24 Prediction
            top_24_prediction = top_k_func(k=24, pred=pred, layer_distribution_logits=layer_distribution_logits)
        else:
            top_1_prediction = (pred_tokens_batch, pred_words_batch)
            top_3_prediction, top_5_prediction, top_12_prediction, top_24_prediction = (
                None,
                None,
                None,
                None,
            )
        # print(f"Layer distribution: {layer_distribution_logits}")
        if self.cfg.model.output_rep not in ["elbo", "weighted_sum_ensemble", "weighted_hiddens"]:
            loss = self.loss_func(
                log_probs[0].transpose(0, 1),  # (N, T, C) -> (T, N, C)
                y,
                log_probs_len[0],
                y_len,
            )
            kl_loss, ctc_loss, layerwise_ctc_loss, layerwise_probability_weight = None, None, None, None
        elif self.cfg.model.output_rep == "weighted_sum_ensemble":
            loss, layerwise_ctc_loss = self.ensemble_loss_function(log_probs, log_probs_len, y, y_len)
            kl_loss, ctc_loss, layerwise_probability_weight = None, None, None
        elif self.cfg.model.output_rep == "elbo":
            loss, kl_loss, ctc_loss, layerwise_ctc_loss, layerwise_probability_weight = self.elbo_loss_func(
                log_probs,
                log_probs_len,
                layer_distribution_logits,
                y,
                y_len,
                gamma=self.kl_gamma,
            )
        elif self.cfg.model.output_rep == "weighted_hiddens":
            loss, layerwise_probability_weight = self.weighted_hidden_loss_function(
                log_probs, log_probs_len, layer_distribution_logits, y, y_len
            )
            kl_loss, ctc_loss, layerwise_ctc_loss = None, None, None

        if layerwise_probability_weight is not None:
            for i in range(len(layerwise_probability_weight)):
                self.layerwise_probability_weight[i].update(layerwise_probability_weight[i])

        if self.model_output_rep in ["elbo", "weighted_sum_ensemble"]:
            for i, pred in enumerate(
                [
                    top_1_prediction,
                    top_3_prediction,
                    top_5_prediction,
                    top_12_prediction,
                    top_24_prediction,
                ]
            ):
                if pred is not None:
                    self.predict_token_recoder[i].record(pred[0])
                    self.predict_word_recoder[i].record(pred[1])
                else:
                    print(f"Prediction for top_{self.topk_dict[i]} is None!")
                    assert False
            for i in range(len(layerwise_ctc_loss)):
                self.layerwise_ctc_loss_meter[i].update(layerwise_ctc_loss[i])
        else:
            if pred_tokens_batch is not None and pred_words_batch is not None:
                self.predict_token_recoder[0].record(top_1_prediction[0])
                self.predict_word_recoder[0].record(top_1_prediction[1])
            else:
                print(f"Prediction for top_{self.topk_dict[0]} is None!")
                assert False

        self.target_word_recoder.record(target_words_batch)
        self.target_token_recoder.record(target_tokens_batch)
        if self.model_output_rep in ["elbo", "weighted_hiddens"]:
            self.predict_layer_distribution.record(F.softmax(layer_distribution_logits, dim=-1))
        # self.filenames_recoder.record(filenames)
        # self.waveform_recoder.record(original_waveform)

        self.loss_meter.update(loss.item())
        pbar_val_dic = OrderedDict()
        pbar_val_dic["loss"] = f"{self.loss_meter.avg:.5f}"
        if ctc_loss:
            self.ctc_loss_meter.update(ctc_loss.item())
            pbar_val_dic["loss_ctc"] = f"{self.ctc_loss_meter.avg:.5f}"
        if kl_loss:
            self.kl_loss_meter.update(kl_loss.item())
            pbar_val_dic["loss_kl"] = f"{self.kl_loss_meter.avg:.5f}"
        return pbar_val_dic

    def _get_plot_distribution(self, distribution: torch.Tensor, plot_prior: bool) -> Path:
        """Plot the per-layer posterior (mean with std error bars) and save it as a PNG.

        Args:
            distribution: Stacked layer-distribution predictions. Statistics are taken over
                dim 0, so presumably shaped (num_samples, encoder_layers) -- TODO confirm.
            plot_prior: If True, also draw ``self.target_distribution`` (the prior) for comparison.

        Returns:
            Path of the saved figure: ``<log_plot_dir>/epoch_<current_epoch>.png``.
        """
        n_layers = self.cfg.model.encoder_layers
        x = list(range(1, n_layers + 1))
        posterior_mean = torch.mean(distribution, dim=0).tolist()
        posterior_std = torch.std(distribution, dim=0).tolist()

        prior_mean = None
        if plot_prior:
            # detach(): a learnable prior may require grad; we only need its values here.
            target_distribution = self.target_distribution.detach()
            if self.cfg.model.prior_distribution == "learnable":
                # The learnable prior is held as logits; normalise over the layer (last) axis.
                target_distribution = F.softmax(target_distribution, dim=-1)
            # A batched prior (e.g. the uniform one built in __init__) repeats the same row for
            # every sample, so the first row is representative. Use the (possibly softmaxed)
            # tensor here -- previously the raw logits were plotted for a 2-D learnable prior.
            if target_distribution.dim() > 1:
                target_distribution = target_distribution[0, :]
            prior_mean = target_distribution.tolist()

        fig, ax = plt.subplots(figsize=(12, 4))
        try:
            ax.errorbar(x, posterior_mean, yerr=posterior_std, linestyle="None", marker="^")
            if prior_mean is not None:
                ax.errorbar(x, prior_mean, yerr=[0] * n_layers, linestyle="None", marker="*")
            # Label every posterior point with its mean, slightly above the marker.
            for layer, mean in zip(x, posterior_mean):
                ax.annotate(
                    f"{mean:.2f}",
                    (layer, mean),
                    textcoords="offset points",
                    xytext=(0, 10),
                    ha="center",
                )
            ax.set_xticks(range(1, n_layers + 1, 1))
            plot_filename = Path(self.log_plot_dir) / f"epoch_{self.current_epoch}.png"
            fig.savefig(plot_filename)
        finally:
            # Release the figure; otherwise pyplot keeps one open figure per validation epoch.
            plt.close(fig)
        return plot_filename

    def train_epoch(self):
        """Run one training epoch and, on rank 0, log the epoch-averaged losses to wandb.

        Each batch is handed to ``_finetune_step``. The returned dict is shown in the progress
        bar. Afterwards the loss meters are synchronised across ranks, all ranks meet at a
        barrier, and rank 0 logs ``train_epoch/*`` values at step ``self.iteration``. CER/WER are
        not computed for the training split.
        """
        torch.cuda.empty_cache()
        self.dataloader_train.set_epoch(self.current_epoch)
        is_main_rank = self.local_rank == 0
        if is_main_rank:
            print(f"-------- {self.cfg.workshop} --------")

        progress = tqdm(self.dataloader_train, disable=not is_main_rank, dynamic_ncols=True)
        progress.set_description(f"TrainEpoch-{self.current_epoch}/{self.epoch}")

        self.reset_meters()
        self.reset_recoders()
        self.model.train()
        self.optimizer.zero_grad()
        for batch in progress:
            self.iteration += 1
            progress.set_postfix(self._finetune_step(batch))

        # Collective calls: every rank must reach these in the same order.
        for meter in (self.loss_meter, self.kl_loss_meter, self.ctc_loss_meter):
            meter.sync_distributed()
        torch.distributed.barrier()

        if not is_main_rank:
            return

        log_dict = {"train_epoch/total_loss": self.loss_meter.avg}
        if self.model_output_rep == "elbo":
            if not self.cfg.train.disable_kl_loss:
                log_dict["train_epoch/kl_loss"] = self.kl_loss_meter.avg
            log_dict["train_epoch/ctc_loss"] = self.ctc_loss_meter.avg
        wandb.log(log_dict, step=self.iteration)

    def evaluate(self, save_model=True, ith_layer_inference=None):
        """Run one validation pass, log metrics and update checkpoint / early-stopping state.

        Every rank runs the forward passes and takes part in the collective gathers and meter
        syncs (in identical order). Only rank 0 computes CER/WER, logs to wandb and saves
        checkpoints. The early-stopping counter is then broadcast from rank 0 so all ranks
        agree on it.

        Args:
            save_model: If True, compare against ``self.best_score``, update the early-stopping
                counter and save a best checkpoint on improvement.
            ith_layer_inference: Forwarded unchanged to ``_finetune_val_step``.

        Raises:
            AssertionError: If no validation dataloader is configured.
        """
        if self.dataloader_val is None:
            # Explicit check instead of ``assert``: it survives ``python -O``, and the previous
            # assert message was the return value of ``logger.info`` (i.e. None).
            if self.logger_val is not None:
                self.logger_val.info("Validation dataloader is None")
            raise AssertionError("Validation dataloader is None")
        if self.local_rank == 0:
            print("-------- Validation --------")
        discrip_str = f"Epoch-{self.current_epoch}/{self.epoch}"
        pbar_val = tqdm(self.dataloader_val, disable=self.local_rank != 0, dynamic_ncols=True)
        pbar_val.set_description("Validation" + discrip_str)
        self.reset_meters()
        self.reset_recoders()
        self.model.eval()
        for data in pbar_val:
            pbar_val_dic = self._finetune_val_step(data, ith_layer_inference=ith_layer_inference)
            pbar_val.set_postfix(pbar_val_dic)

        # ---- Collective section: all ranks must issue these calls in the same order. ----
        epoch_token_target = self.gather_distributed_data(self.target_token_recoder.data, mode="target")
        epoch_word_target = self.gather_distributed_data(self.target_word_recoder.data, mode="target")
        if self.model_output_rep in ["elbo", "weighted_hiddens"]:
            epoch_layer_dist_preds = self.gather_distributed_data(
                self.predict_layer_distribution.data, mode="distribution"
            )
        print("Gathered all data on validation!")
        self.loss_meter.sync_distributed()
        self.kl_loss_meter.sync_distributed()
        self.ctc_loss_meter.sync_distributed()
        # These modes record one prediction stream per top-k entry of self.topk_dict.
        is_multi_output = self.model_output_rep in ["elbo", "weighted_sum_ensemble"]
        if is_multi_output:
            n_topk = len(self.topk_dict)
            epoch_token_preds = [
                self.gather_distributed_data(recoder.data) for recoder in self.predict_token_recoder[:n_topk]
            ]
            epoch_word_preds = [
                self.gather_distributed_data(recoder.data) for recoder in self.predict_word_recoder[:n_topk]
            ]
            for i in range(len(self.layerwise_ctc_loss_meter)):
                self.layerwise_ctc_loss_meter[i].sync_distributed()
                if self.model_output_rep == "elbo":
                    self.layerwise_probability_weight[i].sync_distributed()
        else:
            epoch_token_preds = [self.gather_distributed_data(self.predict_token_recoder[0].data)]
            epoch_word_preds = [self.gather_distributed_data(self.predict_word_recoder[0].data)]
        torch.distributed.barrier()

        if self.local_rank == 0:
            epoch_loss = self.loss_meter.avg
            cer, wer = [], []
            log_dict = {}
            if is_multi_output:
                for i, pw in enumerate(epoch_word_preds):
                    _cer, _wer = self.calculate_score(epoch_token_preds[i], pw, epoch_token_target, epoch_word_target)
                    cer.append(_cer)
                    wer.append(_wer)
                    log_dict[f"val/cer{self.topk_dict[i]}"] = _cer
                    log_dict[f"val/wer{self.topk_dict[i]}"] = _wer
                for i, meter in enumerate(self.layerwise_ctc_loss_meter):
                    log_dict[f"val/ctc_loss_layer_{i}"] = meter.avg
                    if self.model_output_rep == "elbo":
                        log_dict[f"val/probability_weight_layer_{i}"] = self.layerwise_probability_weight[i].avg
            else:
                _cer, _wer = self.calculate_score(
                    epoch_token_preds[0],
                    epoch_word_preds[0],
                    epoch_token_target,
                    epoch_word_target,
                )
                cer.append(_cer)
                wer.append(_wer)
                log_dict["val/cer1"] = _cer
                log_dict["val/wer1"] = _wer

            if self.model_output_rep in ["elbo", "weighted_hiddens"]:
                # Put all gathered tensors on one device before stacking. The plot only needs
                # mean/std, so CPU is enough and avoids assuming a "cuda:0" device exists.
                stacked_dist_preds = torch.stack([t.to("cpu") for t in epoch_layer_dist_preds], dim=0)
                plot_path = self._get_plot_distribution(
                    stacked_dist_preds, plot_prior=self.model_output_rep == "elbo"
                )
                log_dict[f"Validation predictions/epoch_{self.current_epoch}"] = wandb.Image(str(plot_path))

            if save_model:
                if isinstance(self.best_score, list):
                    # Per-top-k WER tracking (lower is better). Only the target top-k entry drives
                    # early stopping and checkpointing; every entry's running best is updated.
                    for i, _wer in enumerate(wer):
                        if i == self.target_index:
                            if self.best_score[i] <= _wer:
                                self.early_stopping += 1
                                print(f"Best: {self.best_score[i]}, index {i}")
                                print(f"Current: {_wer}")
                            else:
                                self.early_stopping = torch.zeros(1).to(self.device)
                            print(f"Early stopping is: {int(self.early_stopping.item())} / {self.patience}")
                            print("Started model saving...")
                            self.model_save(
                                best_score=min(_wer, self.best_score[i]), is_best=_wer < self.best_score[i], idx=i
                            )
                            print("Ended model saving...")
                        self.best_score[i] = min(self.best_score[i], _wer)
                        log_dict[f"best/wer_top{self.topk_dict[i]}"] = self.best_score[i]
                else:
                    # Scalar tracking uses the validation loss (lower is better).
                    if self.best_score <= epoch_loss:
                        self.early_stopping += 1
                        print(f"Best: {self.best_score}")
                        print(f"Current: {epoch_loss}")
                    else:
                        self.early_stopping = torch.zeros(1).to(self.device)
                    print(f"Early stopping is: {int(self.early_stopping.item())} / {self.patience}")
                    print("Started model saving...")
                    self.model_save(
                        best_score=min(self.best_score, epoch_loss),
                        is_best=epoch_loss < self.best_score,
                        idx=self.target_index,
                    )
                    print("Ended model saving...")
                    self.best_score = min(self.best_score, epoch_loss)
                    log_dict["best/total_loss_val"] = self.best_score

            log_dict["val/total_loss"] = epoch_loss
            if self.model_output_rep == "elbo":
                log_dict["val/ctc_loss"] = self.ctc_loss_meter.avg
                if not self.cfg.train.disable_kl_loss:
                    log_dict["val/kl_loss"] = self.kl_loss_meter.avg

            # Side-by-side target vs. predicted transcriptions for the target top-k stream.
            columns = ["Target Transcription", "Predicted Transcription"]
            predicted_transcription = epoch_word_preds[self.target_index]
            data = [
                [" ".join(tt), " ".join(pt)] for tt, pt in zip(epoch_word_target, predicted_transcription)
            ]
            log_dict["val_examples"] = wandb.Table(data=data, columns=columns)
            wandb.log(log_dict, step=self.iteration)
            self.logger_val.info(
                f"Testing epoch: {self.current_epoch}, "
                f"wer: {wer[self.target_index]:.5f}, "
                f"cer: {cer[self.target_index]:.5f}, "
            )
        # Rank 0 owns the counter; share it so every rank makes the same stop decision in run().
        dist.broadcast(self.early_stopping, src=0)
        print(
            f"Broadcast early stopping value to rank {self.local_rank}: "
            f"{int(self.early_stopping.item())} / {self.patience}"
        )

    @torch.no_grad()
    def test(self, test_dataloader, split_name, ith_layer_inference=None):
        """Load the best validation checkpoint, evaluate it on one split and log to wandb.

        All ranks run inference and take part in the gathers/syncs. Rank 0 computes CER/WER
        and logs them under ``f"{split_name}/..."`` together with an examples table.

        Args:
            test_dataloader: Dataloader for the split to evaluate.
            split_name: Prefix used for all logged wandb keys.
            ith_layer_inference: Forwarded unchanged to ``_finetune_val_step``.
        """
        if self.local_rank == 0:
            print("-------- Testing --------")
        # model_save() writes the best checkpoint as model_best_val_top_{topk_dict[idx]}.pt and
        # evaluate() always saves with idx == self.target_index, so load exactly that file.
        filename = f"model_best_val_top_{self.topk_dict[self.target_index]}.pt"
        best_ckpt_path = os.path.join(self.cfg.ckpt_save_path, filename)
        second_best_ckpt_path = os.path.join(self.cfg.ckpt_save_path, filename.replace("best", "second_best"))
        # NOTE(review): the checkpoint also pickles ``cfg``; with torch versions whose
        # ``torch.load`` defaults to ``weights_only=True`` this may need ``weights_only=False``.
        checkpoint = torch.load(best_ckpt_path, map_location=self.device)
        self.model.module.load_state_dict(checkpoint["model"])
        print(f"Loaded weights for the model with the WER score: {checkpoint['best_score']}")
        # The second-best copy only guards against a best checkpoint corrupted mid-save. Once the
        # best one has loaded it is no longer needed. missing_ok avoids a crash when several
        # ranks (or repeated test() calls) try to remove it.
        Path(second_best_ckpt_path).unlink(missing_ok=True)
        self.model.eval()
        discrip_str = f"Epoch-{self.current_epoch}/{self.epoch}"
        pbar_test = tqdm(
            test_dataloader,
            disable=self.local_rank != 0,
            dynamic_ncols=True,
        )
        pbar_test.set_description("Testing" + discrip_str)
        self.reset_meters()
        self.reset_recoders()
        for data in pbar_test:
            pbar_test_dic = self._finetune_val_step(data, ith_layer_inference)
            pbar_test.set_postfix(pbar_test_dic)

        # ---- Collective section: all ranks must issue these calls in the same order. ----
        epoch_token_target = self.gather_distributed_data(self.target_token_recoder.data, mode="target")
        epoch_word_target = self.gather_distributed_data(self.target_word_recoder.data, mode="target")
        is_multi_output = self.model_output_rep in ["elbo", "weighted_sum_ensemble"]
        if is_multi_output:
            n_topk = len(self.topk_dict)
            epoch_token_preds = [
                self.gather_distributed_data(recoder.data) for recoder in self.predict_token_recoder[:n_topk]
            ]
            epoch_word_preds = [
                self.gather_distributed_data(recoder.data) for recoder in self.predict_word_recoder[:n_topk]
            ]
            if self.model_output_rep == "elbo":
                self.kl_loss_meter.sync_distributed()
                self.ctc_loss_meter.sync_distributed()
        else:
            epoch_token_preds = [self.gather_distributed_data(self.predict_token_recoder[0].data)]
            epoch_word_preds = [self.gather_distributed_data(self.predict_word_recoder[0].data)]
        self.loss_meter.sync_distributed()
        torch.distributed.barrier()

        if self.local_rank == 0:
            log_dict = {}
            log_dict[f"{split_name}/total_loss"] = self.loss_meter.avg
            if self.model_output_rep == "elbo":
                # kl_loss is only logged when it is enabled; logging it unconditionally raised
                # NameError whenever cfg.train.disable_kl_loss was True.
                if not self.cfg.train.disable_kl_loss:
                    log_dict[f"{split_name}/kl_loss"] = self.kl_loss_meter.avg
                log_dict[f"{split_name}/ctc_loss"] = self.ctc_loss_meter.avg
            cer, wer = [], []
            if is_multi_output:
                for i, pw in enumerate(epoch_word_preds):
                    _cer, _wer = self.calculate_score(epoch_token_preds[i], pw, epoch_token_target, epoch_word_target)
                    cer.append(_cer)
                    wer.append(_wer)
                    log_dict[f"{split_name}/cer{self.topk_dict[i]}"] = _cer
                    log_dict[f"{split_name}/wer{self.topk_dict[i]}"] = _wer
            else:
                _cer, _wer = self.calculate_score(
                    epoch_token_preds[0],
                    epoch_word_preds[0],
                    epoch_token_target,
                    epoch_word_target,
                )
                cer.append(_cer)
                wer.append(_wer)
                log_dict[f"{split_name}/cer1"] = _cer
                log_dict[f"{split_name}/wer1"] = _wer

            # Side-by-side target vs. predicted transcriptions for the target top-k stream.
            columns = ["Target Transcription", "Predicted Transcription"]
            predicted_transcription = epoch_word_preds[self.target_index]
            data = [
                [" ".join(tt), " ".join(pt)] for tt, pt in zip(epoch_word_target, predicted_transcription)
            ]
            log_dict[f"{split_name}_examples"] = wandb.Table(data=data, columns=columns)
            wandb.log(log_dict, step=self.iteration)

    def model_save(self, best_score, is_best=False, idx=0, filename="last_checkpoint.pt"):
        save_dict = {
            "cfg": self.cfg,
            "epoch": self.current_epoch,
            "iteration": self.iteration,
            "model": self.model.module.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "scheduler": self.scheduler.state_dict(),
            "best_score": best_score,
        }
        if is_best:
            k = self.topk_dict[idx]
            best_ckpt_filepath = f"{self.cfg.ckpt_save_path}/model_best_val_top_{k}.pt"
            second_best_ckpt_filepath = f"{self.cfg.ckpt_save_path}/model_second_best_val_top_{k}.pt"
            if Path(best_ckpt_filepath).exists():
                os.rename(best_ckpt_filepath, second_best_ckpt_filepath)
            torch.save(save_dict, best_ckpt_filepath)

    def run(self):
        """Train for up to ``self.epoch`` epochs with periodic validation, then test and clean up.

        Validation runs on epochs divisible by ``self.wandb_val_epoch_interval``. Training stops
        early once the early-stopping counter exceeds ``self.patience``. If
        ``self.dataloader_test`` is a list, each loader is passed to ``test`` using the matching
        name from ``cfg.dataset.test_folder``. Finally ``cleanup`` is called.
        """
        self.prepare_staff()
        while self.current_epoch < self.epoch:
            self.train_epoch()
            is_validation_epoch = self.current_epoch % self.wandb_val_epoch_interval == 0
            if is_validation_epoch:
                self.evaluate()
            self.current_epoch += 1
            if not self.early_stopping > self.patience:
                continue
            # The counter is broadcast from rank 0 in evaluate(), so all ranks take this branch
            # together and meet at the barrier.
            dist.barrier()
            print(f"Early stopping (patience: {self.early_stopping})")
            break

        # NOTE(review): have_test_set only gates this message; the test loop below runs whenever
        # dataloader_test is a list. Confirm that is intended.
        if self.cfg.dataset.have_test_set:
            print("Having test set!")
        test_loaders = self.dataloader_test
        if isinstance(test_loaders, list):
            split_names = self.cfg.dataset.test_folder
            for idx, loader in enumerate(test_loaders):
                self.test(loader, split_name=split_names[idx])
        self.cleanup()

    def cleanup(self):
        if self.logger_train is not None:
            asr.utils.logger.close_logger(self.logger_train)
        if self.logger_val is not None:
            asr.utils.logger.close_logger(self.logger_val)
        torch.cuda.empty_cache()

        self.current_epoch = 0
        self.iteration = 0
        self.best_score = 1e6
