import json
import math
import os
import shutil
from collections import OrderedDict
from pathlib import Path

import accelerate
import numpy as np
import pandas as pd
import torch
import torch.distributed as dist
import torch.optim.lr_scheduler as lr_scheduler
import wandb
from ser.configs import dict_2_list
from matplotlib import pyplot as plt
from scipy import stats
from sklearn.metrics import accuracy_score
from torch.nn import functional as F
from tqdm import tqdm


class Engine:
    def __init__(self, cfg, mode: str, local_rank: int, world_size: int) -> None:
        """Store the configuration and create loss, meters, recorders and the layer prior.

        Args:
            cfg: Project configuration; the ``train``, ``model`` and ``dataset`` sections
                plus ``workshop`` are read here.
            mode: Forwarded to the dataloader factory. Class weights for the cross-entropy
                are only applied when it equals ``"_finetune"`` and ``cfg.train.ce_weights``
                is set.
            local_rank: Rank of this process; rank 0 initialises wandb.
            world_size: Number of processes, used later when gathering tensors.
        """
        self.cfg = cfg
        self.clip_grad = cfg.train.clip_grad
        self.prior_distribution_name = cfg.model.prior_distribution
        self.kl_gamma = cfg.train.kl_gamma
        # The clipping threshold is only read when clipping is enabled.
        if self.clip_grad:
            self.clip_grad_value = cfg.train.clip_grad_value
        self.local_rank = local_rank
        self.world_size = world_size
        self.device = self.cfg.train.device
        # Kept as a tensor so it can be broadcast from rank 0 in evaluate().
        self.early_stopping = torch.zeros(1).to(self.device)
        self.epoch = self.cfg.train.epoch
        self.current_epoch = 0
        self.iteration = 0
        self.model_output_rep = self.cfg.model.output_rep
        if self.model_output_rep == "elbo":
            # Index into topk_dict of the metric used for checkpointing/early stopping
            # (entry 4 is the "24" key).
            self.target_index = 4
            self.topk_dict = {0: "1", 1: "3", 2: "5", 3: "12", 4: "24"}
            # Every prior below is materialised as a (batch_size, encoder_layers) tensor
            # whose rows are identical.
            if cfg.model.prior_distribution == "uniform":
                self.target_distribution = (
                        torch.ones(cfg.train.batch_size, cfg.model.encoder_layers, device=self.device)
                        / cfg.model.encoder_layers
                )
            elif cfg.model.prior_distribution == "geometric":
                self.target_distribution = (
                    torch.tensor(
                        self._init_geometric_distribution(p=cfg.model.p_for_geometric_pmf, k=cfg.model.encoder_layers),
                        device=self.device,
                    )
                    .repeat(cfg.train.batch_size)
                    .reshape(cfg.train.batch_size, cfg.model.encoder_layers)
                )
            elif cfg.model.prior_distribution == "chi2":
                self.target_distribution = self._get_chi2_distribution()
            else:
                # NOTE(review): this only prints; self.target_distribution stays unset here.
                # _get_plot_distribution mentions a "learnable" prior, so it is presumably
                # assigned elsewhere for that case - confirm.
                print(f"Unknown prior distribution: {cfg.model.prior_distribution}")
        else:
            self.target_index = 0
            self.topk_dict = {0: "1"}
        self.best_score = [0] * 5
        self.best_scores = [0] * self.cfg.dataset.num_classes
        self.wandb_project = self.cfg.train.wandb_project
        self.wandb_train_step_log_interval = self.cfg.train.wandb_train_step_log_interval
        self.wandb_val_epoch_interval = self.cfg.train.wandb_val_epoch_interval
        self.mode = mode

        # NOTE(review): `ser` is used as a bare package name below, but the visible imports
        # only contain `from ser.configs import dict_2_list` - confirm `ser` is bound.
        self.dataloader_factory = ser.wavlm.utils.dataset.DataloaderFactory(self.cfg.dataset, mode)
        weights = None
        self.classes_dict = None
        if mode == "_finetune" and self.cfg.train.ce_weights is not None:
            # NOTE(review): weights and meters are placed on the literal "cuda" device,
            # not on self.device - confirm this is intended.
            weights = torch.FloatTensor(self.cfg.train.ce_weights).to("cuda")
            self.classes_dict = dict(zip(cfg.dataset.classes[: cfg.dataset.num_classes], cfg.dataset.text_classes))
        # The ELBO loss weights per-sample CE by the layer posterior, so it needs the
        # unreduced loss; every other mode uses the batch mean directly.
        ce_reduction = "none" if self.model_output_rep == "elbo" else "mean"
        self.loss_func = torch.nn.CrossEntropyLoss(weight=weights, reduction=ce_reduction)
        self.calculate_score = ser.wavlm.utils.metric.calculate_score_classification
        ### prepare meters
        self.loss_meter = ser.wavlm.utils.avgmeter.AverageMeter(device="cuda")
        self.kl_loss_meter = ser.wavlm.utils.avgmeter.AverageMeter(device="cuda")
        self.ce_loss_meter = ser.wavlm.utils.avgmeter.AverageMeter(device="cuda")

        # One accuracy meter / prediction recorder per encoder layer in ELBO mode,
        # a single one otherwise.
        if self.model_output_rep == "elbo":
            n = self.cfg.model.encoder_layers
        else:
            n = 1
        self.acc_meter = [ser.wavlm.utils.avgmeter.AverageMeter(device="cuda") for i in range(n)]
        self.predict_recoder = [ser.wavlm.utils.recoder.TensorRecorder(device="cuda", dtype=torch.int64) for i in
                                range(n)]
        self.label_recoder = ser.wavlm.utils.recoder.TensorRecorder(device="cuda", dtype=torch.int64)
        self.predict_layer_distribution = ser.wavlm.utils.recoder.ArrayRecorder()
        # Directory where _get_plot_distribution saves the per-epoch plots.
        self.log_plot_dir = Path(self.cfg.workshop) / "distribution_predicted"
        if self.local_rank == 0:
            wandb.init(project=self.wandb_project, mode=self.cfg.train.wandb_mode)
            # wandb.mark_preempting()
            print("Main pid:", os.getpid())

    @staticmethod
    def _init_geometric_distribution(p, k):
        pmf = [(1 - p) ** i * p for i in range(24)]
        if sum(pmf) > 1:
            pmf[-1] -= sum(pmf) - 1
        elif sum(pmf) < 1:
            pmf[-1] += 1 - sum(pmf)
        return list(reversed(pmf))

    def _get_chi2_distribution(self):
        lower_bound = stats.ncx2.ppf(0.01, self.cfg.model.chi2_df, self.cfg.model.chi2_nc)
        upper_bound = stats.ncx2.ppf(0.99, self.cfg.model.chi2_df, self.cfg.model.chi2_nc)
        x, step = np.linspace(lower_bound, upper_bound, self.cfg.model.encoder_layers, retstep=True)
        p = stats.ncx2.pdf(x, self.cfg.model.chi2_df, self.cfg.model.chi2_nc)
        probas = p * step
        probas_sum = sum(probas)
        probas[0] -= probas_sum - 1
        probas = list(reversed(probas))
        tensor_probas = (torch.tensor(probas, device=self.device).repeat(self.cfg.train.batch_size)).reshape(
            self.cfg.train.batch_size, self.cfg.model.encoder_layers
        )
        return tensor_probas

    def kl_with_temperature(self, s_logits, t_logits, temperature=1, two_sided: bool = False, reduction="batchmean"):
        if t_logits.size(-1) > 1:
            if two_sided:
                distillation_loss = (
                                            self.kl_with_temperature(s_logits, t_logits, temperature,
                                                                     reduction=reduction)
                                            + self.kl_with_temperature(t_logits, s_logits, temperature,
                                                                       reduction=reduction)
                                    ) / 2
            else:
                distillation_loss = F.kl_div(
                    torch.log_softmax(s_logits / temperature, dim=-1),
                    t_logits,
                    reduction=reduction,
                )  # * temperature**2
        else:
            distillation_loss = F.mse_loss(s_logits, t_logits, reduction=reduction)
        return distillation_loss

    def elbo_loss_func(self, predictions, layer_distribution_logits, y, gamma=0.1, target_prior_dist="uniform"):
        ce_loss = []
        bs = layer_distribution_logits.shape[0]
        layer_distribution_probas = F.softmax(layer_distribution_logits, dim=-1)
        for i, prediction in enumerate(predictions):
            ce_loss.append(layer_distribution_probas[:, i] * self.loss_func(prediction, y))
        assert len(predictions) == 24, "Must compute CE per layer"
        ce_loss = torch.sum(torch.stack(ce_loss), dim=0)  # sum per layer, average per sample
        assert len(ce_loss) == bs, f"Expected to sum ce loss per layer before averaging, got: {len(ce_loss)}"
        ce_loss = torch.mean(ce_loss)
        target_distribution = self.target_distribution[:bs, :]
        # print(f"layer_distribution_logits.shape: {layer_distribution_logits.shape}, target_distribution: {target_distribution.shape}")
        kl_loss = self.kl_with_temperature(layer_distribution_logits, target_distribution, two_sided=False)
        total_loss = gamma * kl_loss + ce_loss
        return total_loss, kl_loss, ce_loss, layer_distribution_probas

    def prepare_staff(self):
        """Build dataloaders, model, optimizer, scheduler and loggers, then dump the config.

        We move this part out of the __init__ function to avoid the weird error:
        DataLoader worker (pid xxx) is killed by signal: Aborted
        This error is probably caused by a conflict between lmdb and ddp.

        Raises:
            ValueError: if ``cfg.model.backbone`` is neither "wavlm" nor "data2vec".
        """
        ### prepare dataloader
        self.dataloader_train = self.dataloader_factory.build(
            size=self.cfg.train.ds_size, state="train", bs=self.cfg.train.batch_size
        )
        self.dataloader_val = self.dataloader_factory.build(
            size=self.cfg.train.ds_size, state="val", bs=self.cfg.train.val_batch_size
        )
        self.dataloader_test = self.dataloader_factory.build(
            size=self.cfg.train.ds_size, state="test", bs=self.cfg.train.val_batch_size
        )

        ### prepare model, optimizer and scheduler
        self.cfg.model.freeze_cnn = self.cfg.train.freeze_cnn
        self.cfg.model.freeze_upstream = self.cfg.train.freeze_upstream
        backbone = self.cfg.model.backbone
        if backbone == "wavlm":
            model = ser.wavlm.models.wavlm.WavLMFinetuneWrapper(self.cfg.model).to(self.device)
            if self.cfg.train.freeze_cnn:
                for param in model.wavlm.feature_extractor.parameters():
                    param.requires_grad = False
            if self.cfg.train.freeze_upstream:
                print("WavLM upstream layers on FREEZE!")
                for param in model.wavlm.parameters():
                    param.requires_grad = False
        elif backbone == "data2vec":
            model = ser.wavlm.models.wavlm.Data2VecFinetuneWrapper(self.cfg.model).to(self.device)
            ### This mask must not be initialized at all
            model.data2vec.model.masked_spec_embed.requires_grad = False
            if self.cfg.train.freeze_cnn:
                for param in model.data2vec.model.feature_extractor.parameters():
                    param.requires_grad = False
            # NOTE(review): the wavlm branch reads cfg.train.freeze_upstream while this one
            # reads cfg.model.freeze_backbone - confirm the two flags are meant to differ.
            if self.cfg.model.freeze_backbone:
                for param in model.data2vec.model.parameters():
                    param.requires_grad = False
        else:
            # Fail here with a clear message instead of hitting an unbound `model` below.
            raise ValueError(f"Backbone is not specified! Expected wavlm or data2vec, got: {backbone}")

        find_unused_parameters = self.cfg.model.output_rep.startswith("layer")
        accelerator = accelerate.Accelerator(mixed_precision="bf16")
        model = accelerator.prepare(model)
        self.model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[self.local_rank], find_unused_parameters=find_unused_parameters
        )
        # Only parameters left trainable by the freezing logic above are optimised.
        self.optimizer = torch.optim.AdamW(
            params=filter(lambda x: x.requires_grad, self.model.parameters()),
            lr=self.cfg.train.lr,
            weight_decay=self.cfg.train.weight_decay,
        )

        if self.local_rank == 0:
            # NOTE(review): the optimizer is always AdamW; this only echoes the config value.
            print(f"Optimizer: {self.cfg.train.optimizer}")

        # CosineAnnealingLR with Warm-up
        warmup_epoch = 0
        lr_max = self.cfg.train.lr
        lr_min = self.cfg.train.lr * 0.01
        T_max = self.epoch

        def lr_lambda(epoch):
            """Multiplicative LR factor: linear warm-up, then cosine decay from lr_max to lr_min."""
            if epoch < warmup_epoch:
                return (epoch + 1) / warmup_epoch
            cosine = 1.0 + math.cos((epoch - warmup_epoch) / (T_max - warmup_epoch) * math.pi)
            # LambdaLR expects a factor relative to the base LR, hence the division.
            return (lr_min + 0.5 * (lr_max - lr_min) * cosine) / self.cfg.train.lr

        self.scheduler = lr_scheduler.LambdaLR(optimizer=self.optimizer, lr_lambda=lr_lambda)

        if self.cfg.train.resume is not None:
            ckpt = torch.load(self.cfg.train.resume, map_location=self.device)
            self.model.module.load_state_dict(ckpt["model"])
            self.optimizer.load_state_dict(ckpt["optimizer"])
            self.scheduler.load_state_dict(ckpt["scheduler"])
            self.scheduler.step()
            self.current_epoch = ckpt["epoch"] + 1
            self.iteration = ckpt["iteration"]
            self.best_score[self.target_index] = ckpt["best_score"]
            if self.local_rank == 0:
                print(f"Resuming from {self.cfg.train.resume}")
            del ckpt

        ### prepare writer and logger
        if self.local_rank == 0:
            # wandb.watch(self.model, log="all", log_graph=True)
            self.logger_train = ser.wavlm.utils.logger.create_logger(self.cfg.workshop, name="train")
            self.logger_train.info(f"workshop: {self.cfg.workshop}")
            self.logger_train.info(f"seed: {self.cfg.train.seed}")
            self.logger_train.info(f"pid: {os.getpid()}")

            self.logger_val = ser.wavlm.utils.logger.create_logger(self.cfg.workshop, name="val")
            self.logger_val.info(f"workshop: {self.cfg.workshop}")
            self.logger_val.info(f"seed: {self.cfg.train.seed}")
            self.logger_val.info(f"pid: {os.getpid()}")
        else:
            self.logger_train = None
            self.logger_val = None

        self.config_2_json()

    def config_2_json(self, jsonfile=None):
        self.jsonfile = os.path.join(self.cfg.workshop, "config.json") if jsonfile is None else jsonfile
        with open(self.jsonfile, "w") as f:
            json.dump(dict(self.cfg), f, indent=2)

    def json_2_config(self, jsonfile=None):
        """Load a JSON file and merge its content into ``self.cfg``.

        Args:
            jsonfile: Path to read; when omitted, the path stored in ``self.jsonfile`` is
                used, which must have been set before.
        """
        if jsonfile is not None:
            self.jsonfile = jsonfile
        assert hasattr(self, "jsonfile"), "Please provide the .json file first."
        with open(self.jsonfile, "r") as handle:
            loaded = json.load(handle)
        # The loaded dict is converted by dict_2_list before being merged into the config.
        self.cfg.merge_from_list(dict_2_list(loaded))

    def reset_meters(self):
        """Reset the total/KL/CE loss meters and every accuracy meter."""
        all_meters = (self.loss_meter, self.kl_loss_meter, self.ce_loss_meter, *self.acc_meter)
        for current in all_meters:
            current.reset()

    def reset_recoders(self):
        """Reset the prediction recorders, the label recorder and the layer-distribution recorder."""
        all_recorders = [*self.predict_recoder, self.label_recoder, self.predict_layer_distribution]
        for recorder in all_recorders:
            recorder.reset()

    def gather_distributed_data(self, gather_data):
        """Collect ``gather_data`` from every process and merge the pieces.

        Args:
            gather_data: Either a tensor or an indexable sequence of Python objects.

        Returns:
            For tensors, the concatenation of all ranks' tensors. For sequences, one flat
            list with the items of every rank, or ``None`` when the first local item is
            ``None``.
        """
        if isinstance(gather_data, torch.Tensor):
            buffers = [torch.zeros_like(gather_data) for _ in range(self.world_size)]
            dist.all_gather(buffers, gather_data, async_op=False)
            return torch.cat(buffers)

        if gather_data[0] is None:
            return None

        per_rank = [None] * self.world_size
        if hasattr(dist, "all_gather_object"):
            dist.all_gather_object(per_rank, gather_data)
        else:
            # Project fallback used when torch.distributed lacks all_gather_object.
            ser.wavlm.utils.distributed.all_gather_object(per_rank, gather_data, self.world_size)
        return [item for chunk in per_rank for item in chunk]

    def _finetune_step(self, data):
        """Run one training step: forward, backward, optional optimizer update, bookkeeping.

        The loss is divided by ``cfg.train.accumulate_each_n_steps`` and gradients are
        accumulated; the optimizer only steps when ``self.iteration`` is a multiple of that
        value. Predictions and labels are recorded for the epoch metrics, and on rank 0 the
        step metrics are sent to wandb every ``wandb_train_step_log_interval`` iterations.

        Args:
            data: Batch mapping with the keys "waveform", "padding_mask" and "emotion".

        Returns:
            OrderedDict with the values shown in the progress bar.
        """
        waveform = data["waveform"].to(self.device)
        padding_mask = data["padding_mask"].to(self.device)
        y = data["emotion"].to(self.device)
        batch_size = y.shape[0]
        pred, layer_distribution_logits = self.model(waveform, padding_mask)
        if self.cfg.model.output_rep == "elbo":
            loss, kl_loss, ce_loss, layer_distribution_probas = self.elbo_loss_func(
                pred, layer_distribution_logits, y, gamma=self.kl_gamma, target_prior_dist=self.prior_distribution_name
            )
            batch_mean_layer_distribution_probas = layer_distribution_probas.mean(0)
        else:
            loss = self.loss_func(pred[0], y)
            kl_loss, ce_loss = None, None
            if self.cfg.model.output_rep == "weighted_hiddens":
                layer_distribution_probas = F.softmax(layer_distribution_logits, dim=1)
                batch_mean_layer_distribution_probas = layer_distribution_probas.mean(0)
            else:
                batch_mean_layer_distribution_probas = None

        # Scale so that the accumulated gradient matches the mean over the accumulation window.
        loss = loss / self.cfg.train.accumulate_each_n_steps
        loss.backward()

        is_update_step = self.iteration % self.cfg.train.accumulate_each_n_steps == 0
        grad_norm = None
        if is_update_step:
            if self.clip_grad:
                grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_grad_value)
            self.optimizer.step()
            self.optimizer.zero_grad()

        y_pred = [torch.argmax(p, dim=1) for p in pred]
        for i, yp in enumerate(y_pred):
            self.predict_recoder[i].record(yp)
        self.label_recoder.record(y)

        accuracy = [accuracy_score(y.cpu(), yp.cpu()) for yp in y_pred]

        self.loss_meter.update(loss.item())
        # Compare against None explicitly: a loss tensor equal to 0.0 is falsy but valid.
        if kl_loss is not None:
            self.kl_loss_meter.update(kl_loss.item())
        if ce_loss is not None:
            self.ce_loss_meter.update(ce_loss.item())
        for i, ac in enumerate(accuracy):
            self.acc_meter[i].update(ac, batch_size)

        if self.iteration % self.wandb_train_step_log_interval == 0 and self.local_rank == 0:
            log_dict = {}
            log_dict["train_step/total_loss"] = loss.item()
            log_dict["train_step/total_samples"] = self.iteration * batch_size * self.world_size
            if self.model_output_rep == "elbo":
                log_dict["train_step/kl_loss"] = kl_loss.item()
                log_dict["train_step/ce_loss"] = ce_loss.item()
            for i, ac in enumerate(accuracy):
                log_dict[f"train_accuracy/{i}_layer"] = ac
            if is_update_step and grad_norm is not None:
                log_dict["grad_norm"] = grad_norm
            if batch_mean_layer_distribution_probas is not None:
                for i, layer_weight in enumerate(batch_mean_layer_distribution_probas):
                    log_dict[f"train_step/{i}_layer_probability_weight"] = layer_weight
            wandb.log(log_dict, step=self.iteration)

        pbar_train_dic = OrderedDict()
        pbar_train_dic["iter"] = self.iteration
        pbar_train_dic["lr"] = self.optimizer.param_groups[0]["lr"]
        for i in range(len(accuracy)):
            pbar_train_dic[f"acc{i}"] = f"{self.acc_meter[i].avg:.5f}"
        pbar_train_dic["loss"] = f"{self.loss_meter.avg:.5f}"
        if self.model_output_rep == "elbo":
            pbar_train_dic["kl_loss"] = f"{self.kl_loss_meter.avg:.5f}"
            pbar_train_dic["ce_loss"] = f"{self.ce_loss_meter.avg:.5f}"
        assert self.loss_meter.avg is not None, "Loss turned to be None"
        return pbar_train_dic

    def collect_top_k_predictions(self, layer_distribution, k, batch_size, pred):
        best_topk_indices = torch.topk(input=layer_distribution, dim=1, k=k).indices
        best_topk_values = F.softmax(torch.topk(input=layer_distribution, dim=1, k=k).values, dim=1)
        top_k_proba_predictions = torch.zeros((batch_size, len(self.cfg.dataset.classes))).cuda()
        for i in range(batch_size):
            for j in range(k):
                layer = best_topk_indices[i, j]
                weight = best_topk_values[i, j]
                top_k_proba_predictions[i, :] += torch.log_softmax(pred[layer][i], dim=0) * weight
        top_k_prediction = torch.argmax(top_k_proba_predictions, dim=1)
        return top_k_prediction

    @torch.no_grad()
    def _finetune_val_step(self, data, ith_layer_inference):
        """Run one evaluation step and record predictions, labels and meters.

        In "elbo" mode five predictions are recorded per batch: the output of the
        highest-scoring layer (top-1) and the weighted top-3/5/12/24 combinations. In every
        other mode only the argmax of ``pred[0]`` is recorded.

        Args:
            data: Batch mapping with the keys "waveform", "padding_mask" and "emotion".
            ith_layer_inference: When not None, ``self.model.module.layer_inference`` is
                called with this value instead of the regular forward pass.

        Returns:
            OrderedDict with the values shown in the progress bar.
        """
        waveform = data["waveform"].to(self.device)
        padding_mask = data["padding_mask"].to(self.device)
        y = data["emotion"].to(self.device)
        batch_size = y.shape[0]
        if ith_layer_inference is not None:
            print(f"Starting layer inference... with layer: {ith_layer_inference}")
            pred, layer_distribution_logits = self.model.module.layer_inference(
                waveform, padding_mask, ith_layer_inference
            )
        else:
            pred, layer_distribution_logits = self.model(waveform, padding_mask)

        if self.model_output_rep == "elbo":
            best_layer = torch.argmax(layer_distribution_logits, dim=1)
            # Allocate next to the model outputs instead of forcing the default CUDA device.
            choosen_predictions = torch.empty(
                (batch_size, len(self.cfg.dataset.classes)), device=layer_distribution_logits.device
            )  # batch_size, n_classes
            for sample, layer in enumerate(best_layer):
                choosen_predictions[sample, :] = pred[layer][sample]
            # TOP 1 prediction followed by the TOP 3 / 5 / 12 / 24 predictions.
            # NOTE(review): k=24 requires at least 24 layer scores per sample.
            y_pred = torch.argmax(choosen_predictions, dim=1)
            step_predictions = [y_pred]
            for top_k in (3, 5, 12, 24):
                step_predictions.append(
                    self.collect_top_k_predictions(layer_distribution_logits, top_k, batch_size, pred)
                )
        else:
            y_pred = torch.argmax(pred[0], dim=1)
            step_predictions = [y_pred]

        if self.cfg.model.output_rep == "elbo":
            loss, kl_loss, ce_loss, layer_distribution_probas = self.elbo_loss_func(
                pred, layer_distribution_logits, y, gamma=self.kl_gamma, target_prior_dist=self.prior_distribution_name
            )
        else:
            loss = self.loss_func(pred[0], y)
            kl_loss, ce_loss = None, None
            if self.cfg.model.output_rep == "weighted_hiddens":
                layer_distribution_probas = F.softmax(layer_distribution_logits, dim=-1)
            else:
                layer_distribution_probas = None

        # Recorder i holds the i-th entry of topk_dict.
        for i, step_prediction in enumerate(step_predictions):
            self.predict_recoder[i].record(step_prediction)

        self.label_recoder.record(y)
        if self.model_output_rep in ["elbo", "weighted_hiddens"]:
            self.predict_layer_distribution.record(layer_distribution_probas.detach().cpu())
        accuracy = accuracy_score(y.cpu(), y_pred.cpu())
        self.loss_meter.update(loss.item())
        self.acc_meter[0].update(accuracy, batch_size)
        pbar_val_dic = OrderedDict()
        pbar_val_dic["acc"] = f"{self.acc_meter[0].avg:.5f}"
        pbar_val_dic["loss"] = f"{self.loss_meter.avg:.5f}"
        # Compare against None explicitly: a loss tensor equal to 0.0 is falsy but valid.
        if ce_loss is not None:
            self.ce_loss_meter.update(ce_loss.item())
            pbar_val_dic["loss_ce"] = f"{self.ce_loss_meter.avg:.5f}"
        if kl_loss is not None:
            self.kl_loss_meter.update(kl_loss.item())
            pbar_val_dic["loss_kl"] = f"{self.kl_loss_meter.avg:.5f}"
        return pbar_val_dic

    def train_epoch(self):
        """Train for one epoch and log the epoch-level metrics to wandb on rank 0.

        Predictions and labels recorded by ``_finetune_step`` are gathered from all
        processes, the loss meters are synchronised, and rank 0 logs balanced accuracy and
        macro F1 for every prediction recorder together with the losses and learning rate.
        """
        torch.cuda.empty_cache()
        self.dataloader_train.set_epoch(self.current_epoch)
        if self.local_rank == 0:
            print(f"-------- {self.cfg.workshop} --------")
        discrip_str = f"Epoch-{self.current_epoch}/{self.epoch}"
        pbar_train = tqdm(self.dataloader_train, disable=self.local_rank != 0, dynamic_ncols=True)
        pbar_train.set_description("Train" + discrip_str)
        self.reset_meters()
        self.reset_recoders()

        self.model.train()
        self.optimizer.zero_grad()
        for data in pbar_train:
            self.iteration += 1
            pbar_train_dic = self._finetune_step(data)
            pbar_train.set_postfix(pbar_train_dic)

        # The gathers and syncs below run on every rank.
        epoch_preds = [self.gather_distributed_data(recoder.data).cpu() for recoder in self.predict_recoder]
        epoch_labels = self.gather_distributed_data(self.label_recoder.data).cpu()
        self.loss_meter.sync_distributed()
        if self.model_output_rep == "elbo":
            # Sync KL/CE as evaluate() does, so the logged values are not rank-0-local.
            self.kl_loss_meter.sync_distributed()
            self.ce_loss_meter.sync_distributed()

        if self.local_rank != 0:
            return

        class_labels = self.cfg.dataset.classes[: self.cfg.dataset.num_classes]
        log_dict = {}
        for i, pred in enumerate(epoch_preds):
            # Only balanced accuracy and macro F1 are logged for the training epoch.
            ba_ac, _ma_recall, ma_f1, _f1pcl, _ma_pr, _cm, _acc, _wa_f1, _wa_pr = self.calculate_score(
                pred, epoch_labels, labels=class_labels
            )
            log_dict[f"train_epoch/balanced_accuracy/{i + 1}_layer"] = ba_ac
            log_dict[f"train_epoch/f1_macro_/{i + 1}_layer"] = ma_f1
        log_dict["train_epoch/total_loss"] = self.loss_meter.avg
        if self.model_output_rep == "elbo":
            log_dict["train_epoch/kl_loss"] = self.kl_loss_meter.avg
            log_dict["train_epoch/ce_loss"] = self.ce_loss_meter.avg
        log_dict["LR"] = self.optimizer.param_groups[0]["lr"]
        wandb.log(log_dict, step=self.iteration)

    def _get_plot_distribution(self, distribution: np.array, plot_prior: bool) -> Path:
        """Plot the mean and standard deviation of the predicted layer distribution and save it.

        Args:
            distribution: Array whose first axis indexes samples and whose remaining axis
                indexes layers; mean and std are taken over axis 0.
            plot_prior: Also plot ``self.target_distribution`` (softmaxed when the prior
                is "learnable").

        Returns:
            Path of the PNG written to ``self.log_plot_dir`` for the current epoch.
        """
        n_layers = self.cfg.model.encoder_layers
        x = list(range(1, n_layers + 1))
        posterior_mean = np.mean(distribution, axis=0)
        posterior_std = np.std(distribution, axis=0)

        prior_mean = None
        if plot_prior:
            target_distribution = self.target_distribution
            if self.cfg.model.prior_distribution == "learnable":
                target_distribution = F.softmax(target_distribution, dim=-1)
            # A batched prior repeats the same row, so the first row is representative.
            if len(target_distribution.shape) != 1:
                target_distribution = target_distribution[0, :]
            prior_mean = target_distribution.detach().tolist()

        plot_dir = Path(self.log_plot_dir)
        plot_dir.mkdir(parents=True, exist_ok=True)
        plot_filename = plot_dir / f"epoch_{self.current_epoch}.png"

        fig, ax = plt.subplots(figsize=(12, 4))
        try:
            ax.errorbar(x, posterior_mean, yerr=posterior_std, linestyle="None", marker="^")
            if prior_mean is not None:
                ax.errorbar(x, prior_mean, yerr=[0] * n_layers, linestyle="None", marker="*")
            for layer, mean_value in zip(x, posterior_mean):
                ax.annotate(
                    f"{mean_value:.2f}",
                    (layer, mean_value),
                    textcoords="offset points",
                    xytext=(0, 10),
                    ha="center",
                )
            ax.set_xticks(range(min(x), max(x) + 1, 1))
            fig.savefig(plot_filename)
        finally:
            # One figure is created per call; close it so figures do not accumulate.
            plt.close(fig)
        return plot_filename

    def evaluate(self, ith_layer_inference=None, save_model=True):
        """Run validation, log metrics on rank 0, update early stopping and optionally save.

        All ranks iterate the validation loader and take part in gathering predictions and
        synchronising the meters. Rank 0 computes the metrics, logs them to wandb, updates
        ``self.best_score`` and the early-stopping counter and, if requested, calls
        ``self.model_save`` for the target metric. The counter is finally broadcast from
        rank 0 to every rank.

        Args:
            ith_layer_inference: Forwarded to ``_finetune_val_step``.
            save_model: Whether rank 0 calls ``self.model_save`` after validation.

        Raises:
            AssertionError: if the validation dataloader is None.
        """
        if self.dataloader_val is None:
            # logger_val is None on non-zero ranks, so guard before logging.
            if self.logger_val is not None:
                self.logger_val.info("Validation dataloader is None")
            raise AssertionError("Validation dataloader is None")
        if self.local_rank == 0:
            print("-------- Validation --------")
        discrip_str = f"Epoch-{self.current_epoch}/{self.epoch}"
        pbar_val = tqdm(self.dataloader_val, disable=self.local_rank != 0, dynamic_ncols=True)
        pbar_val.set_description("Validation" + discrip_str)
        self.reset_meters()
        self.reset_recoders()
        self.model.eval()
        for data in pbar_val:
            pbar_val_dic = self._finetune_val_step(data, ith_layer_inference)
            pbar_val.set_postfix(pbar_val_dic)

        # The gathers and syncs below run on every rank.
        epoch_preds = [self.gather_distributed_data(recoder.data).cpu() for recoder in self.predict_recoder]
        if self.model_output_rep == "elbo":
            # Only the first len(topk_dict) recorders are filled during validation.
            epoch_preds = epoch_preds[: len(self.topk_dict)]
        epoch_labels = self.gather_distributed_data(self.label_recoder.data).cpu()
        self.loss_meter.sync_distributed()
        if self.model_output_rep == "elbo":
            self.kl_loss_meter.sync_distributed()
            self.ce_loss_meter.sync_distributed()
        if self.model_output_rep in ["elbo", "weighted_hiddens"]:
            epoch_layer_dist_preds = self.gather_distributed_data(self.predict_layer_distribution.data)

        log_dict = {}
        if self.local_rank == 0:
            epoch_loss = self.loss_meter.avg
            if self.model_output_rep == "elbo":
                kl_loss = self.kl_loss_meter.avg
                ce_loss = self.ce_loss_meter.avg

            class_labels = self.cfg.dataset.classes[: self.cfg.dataset.num_classes]
            balanced_accuracy, macro_recall, macro_f1, macro_precision = [], [], [], []
            for i, epoch_pred in enumerate(epoch_preds):
                ba_ac, ma_recall, ma_f1, _f1pcl, ma_pr, _cm, acc, wa_f1, wa_pr = self.calculate_score(
                    epoch_pred, epoch_labels, labels=class_labels
                )
                balanced_accuracy.append(ba_ac)
                macro_recall.append(ma_recall)
                macro_f1.append(ma_f1)
                macro_precision.append(ma_pr)
                top_k_name = self.topk_dict[i]
                log_dict[f"val/balanced_accuracy_{top_k_name}"] = ba_ac
                log_dict[f"val/accuracy_{top_k_name}"] = acc
                log_dict[f"val/macro_recall_{top_k_name}"] = ma_recall
                log_dict[f"val/macro_precision_{top_k_name}"] = ma_pr
                log_dict[f"val/weighted_precision_{top_k_name}"] = wa_pr
                log_dict[f"val/f1_macro_{top_k_name}"] = ma_f1
                log_dict[f"val/f1_weighted_{top_k_name}"] = wa_f1

            if self.model_output_rep in ["elbo", "weighted_hiddens"]:
                epoch_layer_dist_preds = np.stack(epoch_layer_dist_preds, axis=0)
                plot_path = self._get_plot_distribution(
                    epoch_layer_dist_preds, plot_prior=self.model_output_rep == "elbo"
                )
                log_dict[f"Validation predictions/epoch_{self.current_epoch}"] = wandb.Image(str(plot_path))

            is_best = []
            for i, acc in enumerate(balanced_accuracy):
                if i == self.target_index:
                    # No improvement on the target metric increments the counter; any
                    # improvement resets it.
                    if self.best_score[i] >= acc:
                        self.early_stopping += 1
                    else:
                        self.early_stopping = torch.zeros(1).to(self.device)
                    print(f"Early stopping is: {self.early_stopping} / 3")
                is_best.append(acc > self.best_score[i])
                self.best_score[i] = max(self.best_score[i], acc)
                log_dict[f"best/balanced_accuracy_top{self.topk_dict[i]}"] = self.best_score[i]

            log_dict["val/total_loss"] = epoch_loss
            if self.model_output_rep == "elbo":
                log_dict["val/ce_loss"] = ce_loss
                log_dict["val/kl_loss"] = kl_loss
            wandb.log(log_dict, step=self.iteration)
            self.logger_val.info(
                f"Testing epoch: {self.current_epoch}, "
                f"balanced_accuracy: {balanced_accuracy[0]:.5f}, "
                f"macro precision: {macro_precision[0]:.5f}, "
                f"macro recall: {macro_recall[0]:.5f}, "
                f"macro F1: {macro_f1[0]:.5f}, "
                f"loss: {epoch_loss:.5f}, "
            )
            if save_model:
                print("Started model saving...")
                # Only the target metric decides whether this checkpoint counts as best.
                if self.target_index < len(is_best):
                    self.model_save(is_best[self.target_index], idx=self.target_index)
                print("Ended model saving...")

        # Every rank needs the counter that rank 0 just updated.
        dist.broadcast(self.early_stopping, src=0)
        print(f"Broadcast early stopping value to rank {self.local_rank}: {self.early_stopping} / 3")

    @torch.no_grad()
    def test(self, use_val_loader=False, ith_layer_inference=None):
        """Score the best validation checkpoint and log the results.

        Loads ``model_best_val_top_{k}.pt`` (``k`` is the top-k label of
        ``self.target_index``) from ``cfg.ckpt_save_path`` into
        ``self.model.module``, removes the matching second-best checkpoint if
        present, and runs ``_finetune_val_step`` over the selected loader.
        Predictions and labels are passed through ``gather_distributed_data``;
        on rank 0 the scores are logged to wandb and the predictions are
        written to ``test_predictions_log_softmax.csv`` in ``cfg.workshop``.

        Args:
            use_val_loader: Iterate ``self.dataloader_val`` instead of
                ``self.dataloader_test``.
            ith_layer_inference: Forwarded unchanged to ``_finetune_val_step``.
        """
        if self.local_rank == 0:
            print("-------- Testing --------")
        k = self.topk_dict[self.target_index]
        best_ckpt_path = os.path.join(self.cfg.ckpt_save_path, f"model_best_val_top_{k}.pt")
        best_ckpt = torch.load(best_ckpt_path, map_location=self.device)
        self.model.module.load_state_dict(best_ckpt["model"])
        # NOTE(review): the stored "best_score" appears to be filled from balanced
        # accuracy during validation; the "WA" wording below may be stale — confirm.
        print(f"Loaded weights for the model with the validation WA score: {best_ckpt['best_score']}")
        self.model.eval()

        # Every rank executes this method, so a check-then-delete sequence can
        # race: another rank may remove the file between the two calls. Deleting
        # directly and tolerating "already gone" is safe on any number of ranks.
        second_best_path = os.path.join(self.cfg.ckpt_save_path, f"model_second_best_val_top_{k}.pt")
        try:
            os.remove(second_best_path)
        except FileNotFoundError:
            pass

        discrip_str = f"Epoch-{self.current_epoch}/{self.epoch}"
        pbar_test = tqdm(
            self.dataloader_val if use_val_loader else self.dataloader_test,
            disable=self.local_rank != 0,
            dynamic_ncols=True,
        )
        pbar_test.set_description("Testing" + discrip_str)
        self.reset_meters()
        self.reset_recoders()
        for data in pbar_test:
            pbar_test_dic = self._finetune_val_step(data, ith_layer_inference)
            pbar_test.set_postfix(pbar_test_dic)

        is_elbo = self.model_output_rep == "elbo"

        # The gather/sync calls below run on every rank and keep their original
        # order; only the reporting further down is restricted to rank 0.
        test_preds = [self.gather_distributed_data(recoder.data).cpu() for recoder in self.predict_recoder]
        if is_elbo:
            # Keep one prediction set per entry of topk_dict.
            test_preds = test_preds[: len(self.topk_dict.keys())]
            self.kl_loss_meter.sync_distributed()
            self.ce_loss_meter.sync_distributed()
        else:
            test_preds = [test_preds[0]]
        test_labels = self.gather_distributed_data(self.label_recoder.data).cpu()
        self.loss_meter.sync_distributed()
        # torch.distributed.barrier()
        log_dict = {}
        if self.local_rank == 0:
            test_loss = self.loss_meter.avg
            if is_elbo:
                kl_loss = self.kl_loss_meter.avg
                ce_loss = self.ce_loss_meter.avg

            class_names = self.cfg.dataset.classes[: self.cfg.dataset.num_classes]
            df_dict = {"label": test_labels}
            for i, layer_preds in enumerate(test_preds):
                # Per-class F1 and the confusion matrix are returned but not reported here.
                (
                    ba_ac,
                    ma_recall,
                    ma_f1,
                    _f1_per_class,
                    ma_pr,
                    _confusion_matrix,
                    acc,
                    wa_f1,
                    wa_pr,
                ) = self.calculate_score(layer_preds, test_labels, labels=class_names)
                suffix = self.topk_dict[i]
                log_dict[f"test/balanced_accuracy_{suffix}"] = ba_ac
                log_dict[f"test/accuracy_{suffix}"] = acc
                log_dict[f"test/macro_recall_{suffix}"] = ma_recall
                log_dict[f"test/macro_precision_{suffix}"] = ma_pr
                log_dict[f"test/weighted_precision_{suffix}"] = wa_pr
                log_dict[f"test/f1_macro_{suffix}"] = ma_f1
                log_dict[f"test/f1_weighted_{suffix}"] = wa_f1
                # One CSV column of predictions per scored output.
                df_dict[f"top{suffix}"] = layer_preds

            log_dict["test/total_loss"] = test_loss
            if is_elbo:
                log_dict["test/ce_loss"] = ce_loss
                log_dict["test/kl_loss"] = kl_loss
            wandb.log(log_dict, step=self.iteration)
            df = pd.DataFrame(df_dict).reset_index(drop=True)
            df.to_csv(os.path.join(self.cfg.workshop, "test_predictions_log_softmax.csv"), index=False)

    def model_save(self, is_best=False, idx=0):
        """Persist a checkpoint when the ``idx``-th tracked score is a new best.

        Does nothing unless ``is_best`` is truthy. Otherwise the current
        ``model_best_val_top_{k}.pt`` (if any) is renamed to
        ``model_second_best_val_top_{k}.pt`` and a fresh best checkpoint is
        written in its place, where ``k = self.topk_dict[idx]``.

        Args:
            is_best: Whether the latest validation score improved on the best.
            idx: Index into ``self.topk_dict`` / ``self.best_score``.
        """
        if not is_best:
            return

        # Snapshot all state first so a failure here leaves existing files untouched.
        checkpoint = dict(
            cfg=self.cfg,
            epoch=self.current_epoch,
            iteration=self.iteration,
            model=self.model.module.state_dict(),
            optimizer=self.optimizer.state_dict(),
            scheduler=self.scheduler.state_dict(),
            best_score=self.best_score[idx],
        )

        topk_label = self.topk_dict[idx]
        ckpt_dir = self.cfg.ckpt_save_path
        best_path = os.path.join(ckpt_dir, f"model_best_val_top_{topk_label}.pt")
        second_best_path = os.path.join(ckpt_dir, f"model_second_best_val_top_{topk_label}.pt")

        # Demote the previous best before overwriting it.
        previous_best = Path(best_path)
        if previous_best.exists():
            shutil.move(best_path, second_best_path)
        torch.save(checkpoint, best_path)

    @torch.no_grad()
    def collect_predictions(self):
        """Dump validation inputs, labels, layer weights and logits to a CSV file.

        Only rank 0 does any work. The model is put in eval mode and run over
        ``self.dataloader_val``; for each sample one row is recorded with the
        waveform, the label, the softmax over ``layer_distribution_logits``
        (``layer_weight_{i}``) and the logits of every output
        (``logit_layer_{i}_class_{j}``). The table is written to
        ``{cfg.workshop}/logits_and_weights.csv``.
        """
        if self.local_rank != 0:
            return

        self.model.eval()
        discrip_str = "Collecting eval predictions"
        pbar_val = tqdm(self.dataloader_val, dynamic_ncols=True)
        pbar_val.set_description(discrip_str)

        n_layers = self.cfg.model.encoder_layers
        n_classes = self.cfg.model.num_classes
        # "elbo" yields one set of logits per encoder layer, otherwise a single set.
        n_outputs = n_layers if self.cfg.model.output_rep == "elbo" else 1

        # Create the (column -> values) skeleton in the final column order.
        result_df = {"audio": [], "label": []}
        for i in range(n_layers):
            result_df[f"layer_weight_{i}"] = []
        for i in range(n_outputs):
            for j in range(n_classes):
                result_df[f"logit_layer_{i}_class_{j}"] = []

        for data in pbar_val:
            waveform = data["waveform"].to(self.device)
            padding_mask = data["padding_mask"].to(self.device)
            predicted_logits, layer_distribution_logits = self.model(waveform, padding_mask)

            # Move everything to the host once per batch. Reading scalars one by
            # one from device tensors would force a device sync per element.
            labels = data["emotion"].cpu()
            layer_probas = F.softmax(layer_distribution_logits, dim=1).cpu()
            logits_per_output = [predicted_logits[i].cpu() for i in range(n_outputs)]

            for b_idx in range(labels.shape[0]):
                result_df["audio"].append(data["waveform"][b_idx].squeeze(0).tolist())
                result_df["label"].append(int(labels[b_idx]))
                for i in range(n_layers):
                    result_df[f"layer_weight_{i}"].append(float(layer_probas[b_idx, i]))
                for i, output_logits in enumerate(logits_per_output):
                    for j in range(n_classes):
                        result_df[f"logit_layer_{i}_class_{j}"].append(float(output_logits[b_idx, j]))

        df = pd.DataFrame(result_df)
        df.to_csv(f"{self.cfg.workshop}/logits_and_weights.csv", index=False)

    def run(self):
        """Drive the full training loop, then optionally test, then clean up.

        After ``prepare_staff()``, each epoch trains, steps the scheduler and
        validates whenever the epoch index is a multiple of
        ``self.wandb_val_epoch_interval``. Training stops once
        ``self.epoch`` epochs are done or ``self.early_stopping`` exceeds 2.
        """
        self.prepare_staff()

        while self.current_epoch < self.epoch:
            self.train_epoch()
            self.scheduler.step()

            due_for_validation = self.current_epoch % self.wandb_val_epoch_interval == 0
            if due_for_validation:
                self.evaluate(save_model=self.cfg.train.save_model_val)

            self.current_epoch += 1

            patience_exhausted = bool(self.early_stopping > 2)
            if not patience_exhausted:
                continue

            # Line up all ranks before leaving the loop.
            dist.barrier()
            print(f"Early stopping (patience: {self.early_stopping})")
            break

        if self.cfg.dataset.have_test_set:
            print("Having test set!")
            # use_val_loader=False: score on self.dataloader_test, not the validation loader.
            self.test(use_val_loader=False)

        self.cleanup()

    def cleanup(self):
        """Close the train/val loggers, free cached CUDA memory and reset counters.

        Resets ``current_epoch`` and ``iteration`` to 0, and ``best_score`` to
        0 when ``self.mode == "_finetune"`` (``None`` otherwise).
        """
        open_loggers = [lg for lg in (self.logger_train, self.logger_val) if lg is not None]
        if open_loggers:
            # The file header only has `from ser.configs import ...`, which does
            # not bind the name `ser`; import the logger module explicitly so the
            # dotted lookup below resolves instead of raising NameError.
            import ser.wavlm.utils.logger

            for lg in open_loggers:
                ser.wavlm.utils.logger.close_logger(lg)
        torch.cuda.empty_cache()

        self.current_epoch = 0
        self.iteration = 0
        # NOTE(review): validation indexes best_score per top-k entry, while this
        # resets it to a scalar — presumably it is re-initialised before the next
        # run; confirm.
        self.best_score = 0 if self.mode == "_finetune" else None
