import json
import math
import os
from collections import OrderedDict
from pathlib import Path

import accelerate
import numpy as np
import pandas as pd
import torch
import torch.distributed as dist
import torch.optim.lr_scheduler as lr_scheduler
import wandb
from configs import dict_2_list
from matplotlib import pyplot as plt
from modules.aamsoftmax import AAMSoftmax
from scipy import stats
from torch.nn import functional as F
from tqdm import tqdm
from utils.metric import calculate_score_classification, get_eer, get_min_c


class Engine:
    def __init__(self, cfg, mode: str, local_rank: int, world_size: int):
        """Create the training engine: hyper-parameters, loss, metric meters and wandb run.

        Dataloaders, model, optimizer and scheduler are built later by ``prepare_staff``.

        Args:
            cfg: experiment config; its ``train``, ``model`` and ``dataset`` sections are read.
            mode: run-mode string, stored on ``self.mode`` and forwarded to the dataloader factory.
            local_rank: rank of this process; only rank 0 initialises wandb and prints the pid.
            world_size: total number of distributed processes.
        """
        train_cfg = cfg.train
        model_cfg = cfg.model
        self.cfg = cfg
        self.mode = mode
        self.local_rank = local_rank
        self.world_size = world_size

        # Optimisation / schedule settings.
        self.device = train_cfg.device
        self.EPOCH = train_cfg.EPOCH
        self.current_epoch = 0
        self.iteration = 0
        self.kl_gamma = train_cfg.kl_gamma
        self.clip_grad = train_cfg.clip_grad
        if self.clip_grad:
            # Only read when gradient clipping is enabled.
            self.clip_grad_value = train_cfg.clip_grad_value

        # Index -> label used in metric/checkpoint names for the validation score variants.
        self.topk_dict = {0: "1", 1: "24_weight_scores", 2: "24_weight_hiddens"}
        self.prior_distribution_name = model_cfg.prior_distribution
        self.model_output_rep = model_cfg.output_rep
        is_elbo = self.model_output_rep == "elbo"
        # Score variant used for early stopping / resume bookkeeping.
        self.target_idx = 1 if is_elbo else 0
        if is_elbo:
            prior_name = model_cfg.prior_distribution
            if prior_name == "uniform":
                n_layers = model_cfg.encoder_layers
                self.target_distribution = (
                    torch.ones(train_cfg.batch_size, n_layers, device=self.device) / n_layers
                )
            elif prior_name == "chi2":
                self.target_distribution = self._get_chi2_distribution()
            else:
                # NOTE(review): target_distribution is left unset here, so elbo_loss_func
                # will fail later unless something else assigns it — confirm intent.
                print(f"Unknown prior distribution: {model_cfg.prior_distribution}")

        # Best-so-far validation metrics (lower is better), one slot per topk_dict entry.
        self.best_score = [1e4] * 3
        self.best_min_c = [1e4] * 3
        self.lowest_loss = None
        self.lowest_loss_val = None
        # Counter of evaluations without improvement; a tensor so it can be broadcast.
        self.early_stopping = torch.zeros(1).to(self.device)

        # wandb logging cadence.
        self.wandb_project = train_cfg.wandb_project
        self.wandb_train_step_log_interval = train_cfg.wandb_train_step_log_interval
        self.wandb_val_epoch_interval = train_cfg.wandb_val_epoch_interval

        self.dataloader_factory = sv.utils.dataset.DataloaderFactory(cfg.dataset, mode)
        self.log_plot_dir = Path(cfg.workshop) / "distribution_predicted"

        # "elbo" weights the classification loss per sample, so it must stay unreduced.
        ce_reduction = "none" if is_elbo else "mean"
        if train_cfg.objective == "ce":
            self.loss_func = torch.nn.CrossEntropyLoss(reduction=ce_reduction)
        else:
            self.loss_func = AAMSoftmax(
                speaker_embedding_dim=model_cfg.projector_dim,
                n_speakers=model_cfg.num_classes,
                reduction=ce_reduction,
                output_rep=model_cfg.output_rep,
                encoder_layers=model_cfg.encoder_layers,
                margin=train_cfg.margin,
                scale=train_cfg.scale,
            ).to(self.device)
        self.score_fn = torch.nn.CosineSimilarity(dim=-1)

        ### prepare meters
        # NOTE(review): 24 heads are hard-coded for "elbo" rather than read from
        # cfg.model.encoder_layers — confirm they always agree.
        n_heads = 24 if is_elbo else 1
        self.loss_meter = sv.utils.avgmeter.AverageMeter(device=self.device)
        self.kl_loss_meter = sv.utils.avgmeter.AverageMeter(device=self.device)
        self.ce_loss_meter = sv.utils.avgmeter.AverageMeter(device=self.device)
        self.acc_meter = [sv.utils.avgmeter.AverageMeter(device=self.device) for _ in range(n_heads)]
        self.predict_recoder = [
            sv.utils.recoder.TensorRecorder(device=self.device, dtype=torch.int64) for _ in range(n_heads)
        ]
        self.label_recoder = sv.utils.recoder.TensorRecorder(device=self.device, dtype=torch.int64)
        self.predict_layer_distribution = sv.utils.recoder.ArrayRecorder()

        if local_rank == 0:
            wandb.init(project=self.wandb_project, mode=train_cfg.wandb_mode)
            print("Main pid:", os.getpid())

    def _get_chi2_distribution(self):
        lower_bound = stats.ncx2.ppf(0.01, self.cfg.model.chi2_df, self.cfg.model.chi2_nc)
        upper_bound = stats.ncx2.ppf(0.99, self.cfg.model.chi2_df, self.cfg.model.chi2_nc)
        x, step = np.linspace(lower_bound, upper_bound, self.cfg.model.encoder_layers, retstep=True)
        p = stats.ncx2.pdf(x, self.cfg.model.chi2_df, self.cfg.model.chi2_nc)
        probas = p * step
        probas_sum = sum(probas)
        probas[0] -= probas_sum - 1
        probas = list(reversed(probas))
        tensor_probas = (torch.tensor(probas, device=self.device).repeat(self.cfg.train.batch_size)).reshape(
            self.cfg.train.batch_size, self.cfg.model.encoder_layers
        )
        return tensor_probas

    def kl_with_temperature(self, s_logits, t_logits, temperature=1, two_sided: bool = False, reduction="batchmean"):
        if t_logits.size(-1) > 1:
            if two_sided:
                distillation_loss = (
                    self.kl_with_temperature(s_logits, t_logits, temperature, reduction=reduction)
                    + self.kl_with_temperature(t_logits, s_logits, temperature, reduction=reduction)
                ) / 2
            else:
                distillation_loss = F.kl_div(
                    torch.log_softmax(s_logits / temperature, dim=-1),
                    t_logits,
                    reduction=reduction,
                )  # * temperature**2
        else:
            distillation_loss = F.mse_loss(s_logits, t_logits, reduction=reduction)
        return distillation_loss

    def elbo_loss_func(self, predictions, layer_distribution_logits, y):
        ce_loss = []
        all_predicted_classes = []
        bs = layer_distribution_logits.shape[0]
        layer_distribution_probas = F.softmax(layer_distribution_logits, dim=-1)
        for i, prediction in enumerate(predictions):
            if self.cfg.train.objective == "ce":
                loss = self.loss_func(prediction, y)
                predicted_classes = [torch.argmax(prediction, dim=1)]
            else:
                # print(f"prediction.shape: {prediction.shape}")
                loss, predicted_classes = self.loss_func(prediction, y, idx=i)
            all_predicted_classes.extend(predicted_classes)
            # print(f"layer_distribution_probas.shape: {layer_distribution_probas.shape}")
            # print(f"i: {i}")
            ce_loss.append(layer_distribution_probas[:, i] * loss)
        assert len(predictions) == 24, "Must compute CE per layer"
        ce_loss = torch.sum(torch.stack(ce_loss), dim=0)  # sum per layer, average per sample
        assert len(ce_loss) == bs, f"Expected to sum ce loss per layer before averaging, got: {len(ce_loss)}"
        ce_loss = torch.mean(ce_loss)
        target_distribution = self.target_distribution[:bs, :]
        kl_loss = self.kl_with_temperature(layer_distribution_logits, target_distribution, two_sided=False)
        total_loss = self.kl_gamma * kl_loss + ce_loss
        return total_loss, kl_loss, ce_loss, all_predicted_classes

    def prepare_staff(self):
        """Build dataloaders, model, optimizer and LR scheduler, and optionally restore weights.

        This is kept out of ``__init__`` to avoid the error
        "DataLoader worker (pid xxx) is killed by signal: Aborted",
        which is probably caused by a conflict between lmdb and DDP.
        It also creates the rank-0 train/val loggers and writes the config to JSON.
        """
        ### prepare dataloader
        self.dataloader_train = self.dataloader_factory.build(state="train", bs=self.cfg.train.batch_size)
        self.dataloader_val = self.dataloader_factory.build(state="val", bs=self.cfg.train.val_batch_size)
        self.dataloader_test = self.dataloader_factory.build(state="test", bs=self.cfg.train.val_batch_size)

        ### prepare model
        # Copy the freeze flags into the model config before building the wrapper.
        self.cfg.model.freeze_cnn = self.cfg.train.freeze_cnn
        self.cfg.model.freeze_upstream = self.cfg.train.freeze_upstream
        model = sv.models.wavlm.WavLMFinetuneWrapper(self.cfg.model, self.cfg.train).to(self.device)

        if self.cfg.train.freeze_cnn:
            for param in model.wavlm.feature_extractor.parameters():
                param.requires_grad = False
        if self.cfg.train.freeze_upstream:
            print("WavLM upstream layers on FREEZE!")
            for param in model.wavlm.parameters():
                param.requires_grad = False

        # DDP's find_unused_parameters is enabled only for "layer*" output representations.
        find_unused_parameters = self.cfg.model.output_rep.startswith("layer")
        accelerator = accelerate.Accelerator(mixed_precision="bf16")
        model = accelerator.prepare(model)
        self.model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[self.local_rank], find_unused_parameters=find_unused_parameters
        )

        ### prepare optimizer: the loss head gets its own parameter group and hyper-parameters
        params = [
            {"params": self.model.parameters(), "lr": self.cfg.train.lr, "weight_decay": self.cfg.train.weight_decay},
            {"params": self.loss_func.parameters(), "lr": 1e-1, "weight_decay": 0.0, "betas": (0.7, 0.999)},
        ]
        self.optimizer = torch.optim.AdamW(params)

        if self.local_rank == 0:
            # NOTE(review): AdamW is always used; cfg.train.optimizer is only printed.
            print(f"Optimizer: {self.cfg.train.optimizer}")

        ### CosineAnnealingLR with warm-up (warm-up only in "_pretrain" mode)
        warmup_epoch = int(self.cfg.train.warmup_epoch * self.EPOCH) if self.mode == "_pretrain" else 0
        lr_max = self.cfg.train.lr
        lr_min = self.cfg.train.lr * 0.01
        T_max = self.EPOCH
        # Avoid ZeroDivisionError when warm-up covers every epoch.
        cosine_span = max(T_max - warmup_epoch, 1)

        def lr_lambda(epoch):
            """Return the LR multiplier (relative to cfg.train.lr) applied at ``epoch``."""
            if epoch < warmup_epoch:
                return (epoch + 1) / warmup_epoch
            progress = (epoch - warmup_epoch) / cosine_span
            cosine_lr = lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(progress * math.pi))
            return cosine_lr / self.cfg.train.lr

        self.scheduler = lr_scheduler.LambdaLR(optimizer=self.optimizer, lr_lambda=lr_lambda)

        if self.cfg.train.load_model is not None:
            # Weights only; optimizer/scheduler state is not restored.
            ckpt = torch.load(self.cfg.train.load_model, map_location=self.device)
            self.model.module.load_state_dict(ckpt["model"])
            if self.local_rank == 0:
                print(f"Loading model from {self.cfg.train.load_model}")
            del ckpt

        if self.cfg.train.resume is not None:
            ckpt = torch.load(self.cfg.train.resume, map_location=self.device)
            self.model.module.load_state_dict(ckpt["model"])
            self.optimizer.load_state_dict(ckpt["optimizer"])
            self.scheduler.load_state_dict(ckpt["scheduler"])
            # Advance the schedule past the epoch stored in the checkpoint.
            self.scheduler.step()
            self.current_epoch = ckpt["epoch"] + 1
            self.iteration = ckpt["iteration"]
            self.best_score[self.target_idx] = ckpt["best_score"]
            if self.local_rank == 0:
                print(f"Resuming from {self.cfg.train.resume}")
            del ckpt

        ### prepare writer and logger
        if self.local_rank == 0:
            self.logger_train = sv.utils.logger.create_logger(self.cfg.workshop, name="train")
            self.logger_train.info(f"workshop: {self.cfg.workshop}")
            self.logger_train.info(f"seed: {self.cfg.train.seed}")
            self.logger_train.info(f"pid: {os.getpid()}")

            self.logger_val = sv.utils.logger.create_logger(self.cfg.workshop, name="val")
            self.logger_val.info(f"workshop: {self.cfg.workshop}")
            self.logger_val.info(f"seed: {self.cfg.train.seed}")
            self.logger_val.info(f"pid: {os.getpid()}")
        else:
            self.logger_train = None
            self.logger_val = None

        self.config_2_json()

    def config_2_json(self, jsonfile=None):
        self.jsonfile = os.path.join(self.cfg.workshop, "config.json") if jsonfile is None else jsonfile
        with open(self.jsonfile, "w") as f:
            json.dump(dict(self.cfg), f, indent=2)

    def json_2_config(self, jsonfile=None):
        if jsonfile is not None:
            self.jsonfile = jsonfile
        assert hasattr(self, "jsonfile"), "Please provide the .json file first."
        with open(self.jsonfile, "r") as f:
            data = json.load(f)
            self.cfg.merge_from_list(dict_2_list(data))

    def reset_meters(self):
        self.loss_meter.reset()
        self.kl_loss_meter.reset()
        self.ce_loss_meter.reset()
        for meter in self.acc_meter:
            meter.reset()

    def reset_recoders(self):
        for meter in self.predict_recoder:
            meter.reset()
        self.label_recoder.reset()
        self.predict_layer_distribution.reset()

    def gather_distributed_data(self, gather_data):
        if isinstance(gather_data, torch.Tensor):
            _output = [torch.zeros_like(gather_data) for _ in range(self.world_size)]
            dist.all_gather(_output, gather_data, async_op=False)
            output = torch.cat(_output)
        else:
            if gather_data[0] is not None:
                _output = [None for _ in range(self.world_size)]
                if hasattr(dist, "all_gather_object"):
                    dist.all_gather_object(_output, gather_data)
                else:
                    sv.utils.distributed.all_gather_object(_output, gather_data, self.world_size)
                output = []
                for lst in _output:
                    output.extend(lst)
            else:
                output = None
        return output

    def _finetune_step(self, data):
        """Run one optimisation step on a training batch and update meters/recorders.

        Args:
            data: batch dict whose "waveform", "padding_mask" and "label" entries are
                sequences of tensors; each is concatenated along dim 0 here.

        Returns:
            OrderedDict of progress-bar fields (iteration, lr, running accuracies and losses).
        """
        waveform = torch.cat(data["waveform"], dim=0).to(self.device)
        padding_mask = torch.cat(data["padding_mask"], dim=0).to(self.device)
        y = torch.cat(data["label"], dim=0).to(self.device)
        batch_size = y.shape[0]

        self.optimizer.zero_grad()
        pred, layer_distribution = self.model(waveform, padding_mask)
        if self.cfg.model.output_rep != "elbo":
            # Single-head objective on the first prediction only.
            if self.cfg.train.objective == "ce":
                loss = self.loss_func(pred[0], y)
                y_pred = [torch.argmax(pred[0], dim=1)]
            else:
                loss, y_pred = self.loss_func(pred[0], y)
            kl_loss, ce_loss = None, None
        else:
            loss, kl_loss, ce_loss, y_pred = self.elbo_loss_func(pred, layer_distribution, y)
        loss.backward()
        # NOTE(review): only the model's parameters are clipped, not the loss head's.
        if self.clip_grad:
            grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_grad_value)
        else:
            grad_norm = None
        self.optimizer.step()

        for i, yp in enumerate(y_pred):
            self.predict_recoder[i].record(yp)
        self.label_recoder.record(y)

        accuracy, f1_weighted, precision, recall = [], [], [], []
        for yp in y_pred:
            acc, f1, prec, recc = calculate_score_classification(y.cpu(), yp.cpu())
            accuracy.append(acc)
            f1_weighted.append(f1)
            precision.append(prec)
            recall.append(recc)

        self.loss_meter.update(loss.item())
        # Compare with None: a loss tensor equal to 0 is falsy and was previously skipped.
        if kl_loss is not None:
            self.kl_loss_meter.update(kl_loss.item())
        if ce_loss is not None:
            self.ce_loss_meter.update(ce_loss.item())
        for i, ac in enumerate(accuracy):
            self.acc_meter[i].update(ac, batch_size)

        if self.iteration % self.wandb_train_step_log_interval == 0 and self.local_rank == 0:
            log_dict = {
                "train_step/total_loss": loss.item(),
                "train_step/total_samples": self.iteration * batch_size * self.world_size,
            }
            if self.model_output_rep == "elbo":
                log_dict["train_step/kl_loss"] = kl_loss.item()
                log_dict["train_step/ce_loss"] = ce_loss.item()
            for i, ac in enumerate(accuracy):
                log_dict[f"train_accuracy/{i}_layer"] = ac
                log_dict[f"train_f1_weighted/{i}_layer"] = f1_weighted[i]
                log_dict[f"train_precision/{i}_layer"] = precision[i]
                log_dict[f"train_recall/{i}_layer"] = recall[i]
            if grad_norm is not None:
                log_dict["grad_norm"] = float(grad_norm)
            wandb.log(log_dict, step=self.iteration)

        pbar_train_dic = OrderedDict()
        pbar_train_dic["iter"] = self.iteration
        pbar_train_dic["lr"] = self.optimizer.param_groups[0]["lr"]
        for i in range(len(accuracy)):
            pbar_train_dic[f"acc{i}"] = f"{self.acc_meter[i].avg:.5f}"
        pbar_train_dic["loss"] = f"{self.loss_meter.avg:.5f}"
        if self.model_output_rep == "elbo":
            pbar_train_dic["kl_loss"] = f"{self.kl_loss_meter.avg:.5f}"
            pbar_train_dic[f"{self.cfg.train.objective}_loss"] = f"{self.ce_loss_meter.avg:.5f}"
        assert self.loss_meter.avg is not None, "Loss turned to be None"
        return pbar_train_dic

    def _get_weighted_hidden(self, layer_distribution, batch_size, hiddens):
        # best_topk_indices = torch.topk(input=layer_distribution, dim=1, k=k).indices
        best_topk_values = F.softmax(layer_distribution, dim=1)
        top_k_proba_predictions = torch.zeros((batch_size, self.cfg.model.projector_dim)).to(self.device)
        for i in range(batch_size):
            for j in range(24):
                weight = best_topk_values[i, j]
                top_k_proba_predictions[i, :] += hiddens[j][i] * weight
        return top_k_proba_predictions

    @staticmethod
    def _get_topk_values_indices(layer_distribution):
        layer_distribution1, layer_distribution2 = layer_distribution
        # best_topk_indices1 = torch.topk(input=layer_distribution1, dim=1, k=k).indices
        best_topk_values1 = F.softmax(layer_distribution1, dim=1)
        # best_topk_indices2 = torch.topk(input=layer_distribution2, dim=1, k=k).indices
        best_topk_values2 = F.softmax(layer_distribution2, dim=1)
        return best_topk_values1, best_topk_values2

    def _get_weighted_score(self, layer_distribution, batch_size, hiddens):
        hiddens1, hiddens2 = hiddens
        best_topk_values1, best_topk_values2 = self._get_topk_values_indices(layer_distribution)
        top_k_scores = torch.zeros(batch_size, device=self.device)
        for i in range(batch_size):
            for j in range(24):
                weight1 = best_topk_values1[i, j]
                weight2 = best_topk_values2[i, j]
                top_k_scores[i] += self.score_fn(hiddens1[j][i], hiddens2[j][i]) * (weight1 + weight2) / 2
        return top_k_scores

    def collect_top_k_predictions(self, layer_distribution, batch_size, hiddens):
        layer_distribution1, layer_distribution2 = layer_distribution
        hidden1, hidden2 = hiddens
        weighted_hidden1 = self._get_weighted_hidden(layer_distribution1, batch_size, hidden1)
        weighted_hidden2 = self._get_weighted_hidden(layer_distribution2, batch_size, hidden2)
        score = self.score_fn(weighted_hidden1, weighted_hidden2)
        return score

    @torch.no_grad()
    def _finetune_val_step(self, data, ith_layer_inference):
        waveform1 = data["waveform1"].to(self.device)
        waveform2 = data["waveform2"].to(self.device)
        padding_mask1 = data["padding_mask1"].to(self.device)
        padding_mask2 = data["padding_mask2"].to(self.device)
        y = data["label"].to(self.device)
        batch_size = y.shape[0]
        if ith_layer_inference is not None:
            hiddens1, layer_distribution1_logits = self.model.module.layer_inference(
                waveform1, padding_mask1, ith_layer_inference
            )
            hiddens2, layer_distribution2_logits = self.model.module.layer_inference(
                waveform2, padding_mask2, ith_layer_inference
            )
        else:
            hiddens1, layer_distribution_logits1 = self.model(waveform1, padding_mask1, inference=True)
            hiddens2, layer_distribution_logits2 = self.model(waveform2, padding_mask2, inference=True)
        if self.model_output_rep == "elbo":
            best_layer1 = torch.argmax(layer_distribution_logits1, dim=1)
            choosen_hiddens1 = torch.empty((batch_size, self.cfg.model.projector_dim), device=self.device)
            for batch, layer in enumerate(best_layer1):
                choosen_hiddens1[batch, :] = hiddens1[layer][batch]
            best_layer2 = torch.argmax(layer_distribution_logits2, dim=1)
            choosen_hiddens2 = torch.empty((batch_size, self.cfg.model.projector_dim), device=self.device)
            for batch, layer in enumerate(best_layer2):
                choosen_hiddens2[batch, :] = hiddens2[layer][batch]
            # TOP 1 Prediction
            score = self.score_fn(choosen_hiddens1, choosen_hiddens2)
            # TOP 3 Prediction
            # top_3_score = self._get_weighted_score(
            #     (layer_distribution_logits1, layer_distribution_logits2), 3, batch_size, (hiddens1, hiddens2)
            # )
            # # TOP 5 Prediction
            # top_5_score = self._get_weighted_score(
            #     (layer_distribution_logits1, layer_distribution_logits2), 5, batch_size, (hiddens1, hiddens2)
            # )
            # # TOP 12 Prediction
            # top_12_score = self._get_weighted_score(
            #     (layer_distribution_logits1, layer_distribution_logits2), 12, batch_size, (hiddens1, hiddens2)
            # )
            # TOP 24 Prediction
            top_24_score = self._get_weighted_score(
                (layer_distribution_logits1, layer_distribution_logits2), batch_size, (hiddens1, hiddens2)
            )
            top_24_score_weight_hiddens = self.collect_top_k_predictions(
                (layer_distribution_logits1, layer_distribution_logits2), batch_size, (hiddens1, hiddens2)
            )

        else:
            score = self.score_fn(hiddens1[0], hiddens2[0])
            top_3_score, top_24_score, top_24_score_weight_hiddens = None, None, None
        if self.model_output_rep == "elbo":
            for i, s in enumerate([score, top_24_score, top_24_score_weight_hiddens]):
                if s is not None:
                    self.predict_recoder[i].record(s)
                else:
                    print(f"Prediction for top_{self.topk_dict[i]} is None!")
                    assert False
        else:
            for i, s in enumerate([score]):
                if s is not None:
                    self.predict_recoder[i].record(s)
                else:
                    print(f"Prediction for top_{self.topk_dict[i]} is None!")
                    assert False

        self.label_recoder.record(y)
        if self.model_output_rep in ["elbo", "weighted_hiddens"]:
            layer_distribution_probas1 = F.softmax(layer_distribution_logits1, dim=1)
            layer_distribution_probas2 = F.softmax(layer_distribution_logits2, dim=1)
            self.predict_layer_distribution.record(layer_distribution_probas1.detach().cpu())
            self.predict_layer_distribution.record(layer_distribution_probas2.detach().cpu())
        return

    def train_epoch(self):
        """Train for one epoch, then gather predictions and log epoch-level metrics.

        Every rank iterates its ``dataloader_train`` and takes part in the distributed
        gathers and meter syncs. Only rank 0 computes the epoch metrics and logs them
        to wandb.
        """
        torch.cuda.empty_cache()
        self.dataloader_train.set_epoch(self.current_epoch)
        is_main_rank = self.local_rank == 0
        if is_main_rank:
            print(f"-------- {self.cfg.workshop} --------")
        progress = tqdm(self.dataloader_train, disable=not is_main_rank, dynamic_ncols=True)
        progress.set_description("Train" + f"Epoch-{self.current_epoch}/{self.EPOCH}")
        self.reset_meters()
        self.reset_recoders()

        self.model.train()
        for batch in progress:
            self.iteration += 1
            progress.set_postfix(self._finetune_step(batch))

        epoch_preds = [self.gather_distributed_data(rec.data).cpu() for rec in self.predict_recoder]
        epoch_labels = self.gather_distributed_data(self.label_recoder.data).cpu()
        is_elbo = self.model_output_rep == "elbo"
        self.loss_meter.sync_distributed()
        if is_elbo:
            self.kl_loss_meter.sync_distributed()
            self.ce_loss_meter.sync_distributed()
        if not is_main_rank:
            return

        log_dict = {}
        # NOTE(review): this passes (predictions, labels) while _finetune_step passes
        # (labels, predictions) to calculate_score_classification. Accuracy is symmetric,
        # but precision/recall/F1 may not be — confirm the expected argument order.
        for layer_no, pred in enumerate(epoch_preds, start=1):
            acc, f1, prec, rec = calculate_score_classification(pred, epoch_labels)
            log_dict[f"train_epoch_accuracy/{layer_no}_layer"] = acc
            log_dict[f"train_epoch_f1/{layer_no}_layer"] = f1
            log_dict[f"train_epoch_precision/{layer_no}_layer"] = prec
            log_dict[f"train_epoch_recall/{layer_no}_layer"] = rec
        log_dict["train_epoch/total_loss"] = self.loss_meter.avg
        if is_elbo:
            log_dict["train_epoch/kl_loss"] = self.kl_loss_meter.avg
            log_dict["train_epoch/ce_loss"] = self.ce_loss_meter.avg
        log_dict["LR"] = self.optimizer.param_groups[0]["lr"]
        wandb.log(log_dict, step=self.iteration)

    def evaluate(self, ith_layer_inference, save_model=True):
        """Score the validation split, log EER/min_c, track best scores and early stopping.

        Args:
            ith_layer_inference: forwarded to ``_finetune_val_step`` (None -> full forward pass).
            save_model: on rank 0, call ``model_save`` once per score variant with its
                "is best" flag.

        Side effects:
            On rank 0, updates ``best_score``, ``best_min_c`` and ``early_stopping`` and
            logs to wandb. ``early_stopping`` is then broadcast from rank 0 to all ranks.
        """
        # Plain-string message: the former ``self.logger_val.info(...)`` message crashed
        # on ranks whose logger is None (and evaluated to None anyway).
        assert self.dataloader_val is not None, "Validation dataloader is None"
        if self.local_rank == 0:
            print("-------- Validation --------")
        discrip_str = f"Epoch-{self.current_epoch}/{self.EPOCH}"
        pbar_val = tqdm(self.dataloader_val, disable=self.local_rank != 0, dynamic_ncols=True)
        pbar_val.set_description("Validation" + discrip_str)
        self.reset_meters()
        self.reset_recoders()
        self.model.eval()
        for data in pbar_val:
            self._finetune_val_step(data, ith_layer_inference=ith_layer_inference)
        epoch_preds = [self.gather_distributed_data(recoder.data).cpu() for recoder in self.predict_recoder]
        if self.model_output_rep == "elbo":
            # Only the first len(topk_dict) recorders hold validation scores.
            epoch_preds = epoch_preds[: len(self.topk_dict)]
        plot_layer_dist = self.model_output_rep in ["elbo", "weighted_hiddens"]
        if plot_layer_dist:
            epoch_layer_dist_preds = self.gather_distributed_data(self.predict_layer_distribution.data)
        epoch_labels = self.gather_distributed_data(self.label_recoder.data).cpu()
        log_dict = {}
        if self.local_rank == 0:
            print(f"Computing Validation metrics on: {len(epoch_labels)} samples.")
            labels_np = epoch_labels.numpy()
            eers, min_cs = [], []
            for epoch_pred in epoch_preds:
                eers.append(get_eer(scores=epoch_pred, labels=labels_np))
                min_cs.append(get_min_c(scores=epoch_pred, labels=labels_np))

            is_best = []
            for i, eer in enumerate(eers):
                if i == self.target_idx:
                    # Count consecutive evaluations without an EER improvement.
                    if eer >= self.best_score[i]:
                        self.early_stopping += 1
                    else:
                        self.early_stopping = torch.zeros(1).to(self.device)
                    print(f"Early stopping index: {self.early_stopping} / 3")
                log_dict[f"val/eer_top{self.topk_dict[i]}"] = eer
                log_dict[f"val/min_c_top{self.topk_dict[i]}"] = min_cs[i]
                is_best.append(eer < self.best_score[i])
                self.best_score[i] = min(self.best_score[i], eer)
                self.best_min_c[i] = min(self.best_min_c[i], min_cs[i])
                log_dict[f"best/eer_top{self.topk_dict[i]}"] = self.best_score[i]
                log_dict[f"best/min_c_top{self.topk_dict[i]}"] = self.best_min_c[i]

            # Plot once per evaluation. It used to be rebuilt (and the same file
            # overwritten) for every score variant inside the loop above.
            if plot_layer_dist and eers:
                # NOTE(review): np.stack needs every gathered item to have the same shape —
                # confirm ArrayRecorder.data guarantees this (e.g. for a short last batch).
                epoch_layer_dist_preds = np.stack(epoch_layer_dist_preds, axis=0)
                plot_path = self._get_plot_distribution(
                    epoch_layer_dist_preds, plot_prior=self.model_output_rep == "elbo"
                )
                log_dict[f"Validation predictions/epoch_{self.current_epoch}"] = wandb.Image(str(plot_path))

            wandb.log(log_dict, step=self.iteration)
            if save_model:
                print("Started model saving...")
                for i, best in enumerate(is_best):
                    self.model_save(best, idx=i)
                print("Ended model saving...")
        dist.broadcast(self.early_stopping, src=0)
        print(f"Broadcast early stopping value to rank {self.local_rank}: {self.early_stopping} / 3")

    def _get_plot_distribution(self, distribution: np.array, plot_prior: bool):
        """Plot the mean ± std of predicted layer distributions and save it as a PNG.

        Args:
            distribution: predicted layer probabilities; mean and std are taken over axis 0.
            plot_prior: also draw the prior ``self.target_distribution`` (its first row
                when 2-D; soft-maxed first for the "learnable" prior).

        Returns:
            Path of the saved image ``<log_plot_dir>/epoch_<current_epoch>.png``.
        """
        x = list(range(1, self.cfg.model.encoder_layers + 1))
        posterior_mean = np.mean(distribution, axis=0)
        posterior_std = np.std(distribution, axis=0)
        prior_mean = None
        if plot_prior:
            target_distribution = self.target_distribution
            if self.cfg.model.prior_distribution == "learnable":
                # Explicit dim (implicit-dim softmax is deprecated). The soft-maxed value
                # is what gets plotted; it used to be computed and then ignored.
                target_distribution = F.softmax(target_distribution, dim=-1)
            target_distribution = torch.as_tensor(target_distribution).detach().cpu()
            if target_distribution.dim() == 1:
                prior_mean = target_distribution.tolist()
            else:
                prior_mean = target_distribution[0, :].tolist()

        fig, ax = plt.subplots(figsize=(12, 4))
        try:
            ax.errorbar(x, posterior_mean, yerr=posterior_std, linestyle="None", marker="^")
            if prior_mean is not None:
                ax.errorbar(x, prior_mean, yerr=[0] * self.cfg.model.encoder_layers, linestyle="None", marker="*")
            for i, xi in enumerate(x):
                ax.annotate(
                    f"{posterior_mean[i]:.2f}",
                    (xi, posterior_mean[i]),
                    textcoords="offset points",
                    xytext=(0, 10),
                    ha="center",
                )
            ax.set_xticks(range(min(x), max(x) + 1, 1))
            plot_dir = Path(self.log_plot_dir)
            # The directory is not created anywhere else before saving.
            plot_dir.mkdir(parents=True, exist_ok=True)
            plot_filename = plot_dir / f"epoch_{self.current_epoch}.png"
            fig.savefig(plot_filename)
        finally:
            # pyplot keeps every figure alive until it is closed.
            plt.close(fig)
        return plot_filename

    @torch.no_grad()
    def _test(self, filename, ith_layer_inference):
        """Evaluate one checkpoint on the test split and export its scores.

        The checkpoint is ``<cfg.ckpt_save_path>/<filename>``, unless ``cfg.train.resume``
        is set, in which case that file is loaded instead. On rank 0, EER and min_c for
        each score variant are logged to wandb. All scores plus labels are written to
        ``test_predictions.csv`` in ``cfg.workshop``, or next to the resumed checkpoint.

        Args:
            filename: checkpoint file name inside ``cfg.ckpt_save_path``.
            ith_layer_inference: forwarded to ``_finetune_val_step``.
        """
        resume_path = self.cfg.train.resume
        ckpt_path = os.path.join(self.cfg.ckpt_save_path, filename) if resume_path is None else resume_path
        checkpoint = torch.load(ckpt_path, map_location=self.device)
        self.model.module.load_state_dict(checkpoint["model"])
        print(f"Loaded weights for the model with the EER score: {checkpoint['best_score']}")
        self.model.eval()

        progress = tqdm(self.dataloader_test, disable=self.local_rank != 0, dynamic_ncols=True)
        progress.set_description("Testing" + f"Epoch-{self.current_epoch}/{self.EPOCH}")
        self.reset_meters()
        self.reset_recoders()
        for batch in progress:
            self._finetune_val_step(batch, ith_layer_inference=ith_layer_inference)

        test_preds = [self.gather_distributed_data(rec.data).cpu() for rec in self.predict_recoder]
        if self.model_output_rep == "elbo":
            # Only the first len(topk_dict) recorders hold scores.
            test_preds = test_preds[: len(self.topk_dict.keys())]
        test_labels = self.gather_distributed_data(self.label_recoder.data).cpu()
        if self.local_rank != 0:
            return

        print(f"Computing Test metrics on: {len(test_labels)} samples.")
        log_dict = {}
        df_dict = {"label": test_labels}
        for i, pred in enumerate(test_preds):
            tag = f"top{self.topk_dict[i]}"
            log_dict[f"test/eer_{tag}"] = get_eer(scores=pred, labels=test_labels.numpy())
            log_dict[f"test/min_c_{tag}"] = get_min_c(scores=pred, labels=test_labels.numpy())
            df_dict[tag] = pred
        wandb.log(log_dict, step=self.iteration)

        df = pd.DataFrame(df_dict).reset_index(drop=True)
        out_dir = self.cfg.workshop if resume_path is None else str(Path(resume_path).parent)
        df.to_csv(os.path.join(out_dir, "test_predictions.csv"), index=False)

    @torch.no_grad()
    def test(self, ith_layer_inference=None):
        if self.model_output_rep != "elbo":
            filename = "model_best_val_top_1.pt"
            self._test(filename, ith_layer_inference=ith_layer_inference)
        else:
            for filename in ["model_best_val_top_24_weight_scores.pt"]:  # , "model_best_val_top_24_weight_hiddens.pt"
                self._test(filename, ith_layer_inference=ith_layer_inference)

    def model_save(self, is_best=False, idx=0, filename="last_checkpoint.pt"):
        save_dict = {
            "cfg": self.cfg,
            "epoch": self.current_epoch,
            "iteration": self.iteration,
            "model": self.model.module.state_dict(),  # save DDP model
            "optimizer": self.optimizer.state_dict(),
            "scheduler": self.scheduler.state_dict(),
        }
        if self.mode == "_finetune":
            save_dict["best_score"] = self.best_score[idx]
        else:
            save_dict["lowest_loss"] = self.lowest_loss_val
        if is_best:
            k = self.topk_dict[idx]
            torch.save(save_dict, os.path.join(self.cfg.ckpt_save_path, f"model_best_val_top_{k}.pt"))

    @torch.no_grad()
    def collect_predictions(self):
        """Dump per-pair validation scores and layer weights to CSV (rank 0 only).

        For every sample in ``self.dataloader_val``, one row is recorded with:

        - the first waveform, as a Python list;
        - its label;
        - the score returned by ``_get_weighted_score``, stored as ``score_elbo24``;
        - the softmax (over dim 1) of the first waveform's layer-distribution
          logits, one ``layer_weight_{i}`` column per encoder layer.

        The rows are written to ``{cfg.workshop}/logits_and_weights.csv``.
        Ranks other than 0 return immediately.
        """
        if self.local_rank != 0:
            return
        self.model.eval()
        # Only rank 0 executes this loop. Call the unwrapped network (the rest of
        # this class uses self.model.module) because a DDP wrapper's forward may
        # issue collectives, e.g. buffer broadcast, that other ranks never join.
        net = getattr(self.model, "module", self.model)

        pbar_val = tqdm(self.dataloader_val, dynamic_ncols=True)
        pbar_val.set_description("Collecting eval predictions")
        weight_cols = [f"layer_weight_{i}" for i in range(self.cfg.model.encoder_layers)]
        result_df = {"audio": [], "label": [], "score_elbo24": []}
        result_df.update({col: [] for col in weight_cols})

        for data in pbar_val:
            waveform1 = data["waveform1"].to(self.device)
            waveform2 = data["waveform2"].to(self.device)
            padding_mask1 = data["padding_mask1"].to(self.device)
            padding_mask2 = data["padding_mask2"].to(self.device)
            y = data["label"].to(self.device)
            batch_size = y.shape[0]
            hiddens1, layer_distribution_logits1 = net(waveform1, padding_mask1, inference=True)
            hiddens2, layer_distribution_logits2 = net(waveform2, padding_mask2, inference=True)
            top_24_score = self._get_weighted_score(
                (layer_distribution_logits1, layer_distribution_logits2), batch_size, (hiddens1, hiddens2)
            )
            # Convert once per batch instead of one device sync per element.
            probas_rows = F.softmax(layer_distribution_logits1, dim=1).tolist()
            for b_idx in range(batch_size):
                result_df["audio"].append(waveform1[b_idx].squeeze(0).tolist())
                result_df["label"].append(int(y[b_idx]))
                result_df["score_elbo24"].append(float(top_24_score[b_idx]))
                row = probas_rows[b_idx]
                for i, col in enumerate(weight_cols):
                    result_df[col].append(float(row[i]))

        df = pd.DataFrame(result_df)
        df.to_csv(f"{self.cfg.workshop}/logits_and_weights.csv", index=False)

    def run(self):
        self.prepare_staff()
        while self.current_epoch < self.EPOCH:
            self.train_epoch()
            self.scheduler.step()
            if self.current_epoch % self.wandb_val_epoch_interval == 0:
                self.evaluate(save_model=self.cfg.train.save_best_model, ith_layer_inference=None)
            self.current_epoch += 1
            if self.early_stopping > 2:
                dist.barrier()
                print(f"Early stopping (patience: {self.early_stopping})")
                break
        if self.cfg.dataset.have_test_set:
            print("Having test set!")
            self.test()
        self.cleanup()

    def cleanup(self):
        if self.logger_train is not None:
            sv.utils.logger.close_logger(self.logger_train)
        if self.logger_val is not None:
            sv.utils.logger.close_logger(self.logger_val)
        torch.cuda.empty_cache()

        self.current_epoch = 0
        self.iteration = 0
        self.best_score = 0 if self.mode == "_finetune" else None
        self.lowest_loss = 1e4 if self.mode == "_pretrain" else None
        self.lowest_loss_val = 1e4 if self.mode == "_pretrain" else None

        if not self.cfg.train.save_best:
            if hasattr(self, "ckpt_save_file") and os.path.exists(self.ckpt_save_file):
                os.remove(self.ckpt_save_file)
            if hasattr(self, "ckpt_best_file") and os.path.exists(self.ckpt_best_file):
                os.remove(self.ckpt_best_file)
