import random
import os
import numpy as np
import torch
from torchmetrics import Accuracy
import logging
import torch
import torchvision
import torch.nn.functional as F
import numpy as np
from collections import defaultdict
from src.pl_model.distillation import Distilltion
import wandb
from omegaconf import DictConfig, OmegaConf
from src.utils.load_models import load_models, load_model3, load_model_qktD
from hydra.utils import instantiate
import copy

log = logging.getLogger(__name__)


class MyDistillationMutlipleTeachers_QKT(Distilltion):

    def __init__(self, cfg):
        """Set up multi-teacher QKT distillation for one learner client.

        Loads the teacher models described by ``cfg``, deep-copies the learner client's own model as the
        student, picks the KL temperature and builds the per-teacher class weights (``alphas``) and the
        learner's own class weights (``learner_alpha``) according to the configured strategy:

        - ``alpha_data_free`` (and not centralized stage 1): mask-based alphas from
          ``get_alphas_data_free`` (``data_free_option == 1``) or ``get_alphas_data_free3`` (default);
        - ``qkt_unweighted_teachers``: all-ones weights (training_step does not apply them in this mode);
        - otherwise: accuracy-share alphas from ``get_alphas``, adjusted by ``adjust_alphas``.

        Parameters:
        - cfg: Hydra/OmegaConf config; also stored via ``save_hyperparameters``.
        """
        super().__init__(cfg=cfg)
        self.save_hyperparameters(cfg)
        self.cfg = cfg  # Save cfg to self
        self.num_classes = cfg.num_classes
        self.run_id = cfg.train_exp_id
        self.learner_client = self.hparams.learner_client
        self.teacher_client = self.hparams.teacher_client
        self.data_dists_vect = self.get_data_dists_vect(cfg)
        self.normalize_alpha = self.hparams.normalize_alpha
        self.noise_threshold = cfg.noise_threshold

        self.teacher_model = load_models(clients_ids=cfg.teacher_client, clients=cfg.clients,
                                         model=instantiate(cfg.model))
        # The student starts as an independent copy of the learner client's own model.
        self.learner_model = copy.deepcopy(self.teacher_model[self.learner_client])

        # A truthy stage1_T overrides the generic KL temperature.
        self.T = self.hparams.stage1_T if self.hparams.stage1_T else self.hparams.KL_temperature
        print(f"using T: {self.T}")

        self.qkt_unweighted_teachers = self.hparams.qkt_unweighted_teachers
        self.goal_class = self.hparams.goal_class
        print(f"self.goal_classes: {self.goal_class}")

        self.centralized_stage1 = self.hparams.centralized_qkt and not cfg.stage2
        if self.centralized_stage1:
            # Centralized stage 1 treats every class as a goal class.
            self.goal_class = list(range(self.num_classes))

        if self.hparams.alpha_data_free and not self.centralized_stage1:
            if self.hparams.data_free_option == 1:
                self.alphas = torch.stack(self.get_alphas_data_free(self.teacher_model))
                print(f"using the data_free option1")
            else:  # default
                self.alphas = torch.stack(self.get_alphas_data_free3(self.teacher_model))
            self.learner_alpha = [1] * self.num_classes
            # TODO: try learner_alpha = 2 for every class when cfg.personalized_qkt is set.

        elif self.qkt_unweighted_teachers:
            # One all-ones row per teacher: training_step zips alphas with teacher_model, so the number
            # of rows must match the number of teachers (one row per class would drop teachers
            # whenever there are more teachers than classes).
            self.alphas = [[1] * self.num_classes for _ in self.teacher_model]
            self.learner_alpha = [1] * self.num_classes

        else:
            self.alphas = torch.stack(self.get_alphas(self.hparams.teacher_client))
            # Whatever share of each class the teachers do not take is left to the learner itself.
            self.learner_alpha = 1 - torch.sum(self.alphas, axis=0)
            self.alphas = self.adjust_alphas(self.alphas)  # for the QKT

        self.alpha = self.alphas
        self.with_CE = self.hparams.with_CE
        print(f">> self.with_CE: {self.with_CE}")

        # Best-metric trackers used by validation_epoch_end / on_fit_end (all maximised).
        self.detailed_testing = self.hparams.detailed_testing
        self.best_query_class_acc_gain_epoch = 0
        self.least_forgetting_epoch = 0
        self.best_val_acc_epoch = 0
        self.best_simple_weighted_accuracy_epoch = 0
        self.best_uniform_accuracy_epoch = 0  # read by on_fit_end even if uniform accuracy never improved
        self.best_val_acc = float('-inf')
        self.best_query_class_acc_gain = float('-inf')
        self.least_forgetting = float('-inf')
        self.best_simple_weighted_accuracy = float('-inf')
        self.best_uniform_accuracy = float('-inf')

        if not self.qkt_unweighted_teachers:
            log.info(f"Alphas are:\n{self.alphas}")
            log.info(f"Learner Alpha is:\n{self.learner_alpha}")
        else:
            log.info(f"Using qkt_unweighted_teachers!")

        # Pre-transfer accuracy needs the trainer's datamodule, so it is measured in on_fit_start.

    def on_fit_start(self):
        """Move the class weights and all models to the module's device before training starts.

        Teachers are also put in eval mode. They sit in a plain Python list rather than a registered
        submodule, so Lightning's train()/eval() switching never reaches them. They serve only as
        frozen targets under torch.no_grad() in training_step, so their BatchNorm/dropout layers
        should run in inference mode.
        If ``cfg.measure_pre_transfer_acc`` is set, the learner's per-class test accuracy is measured
        here, before any transfer happens.
        """
        # as_tensor accepts both lists and existing tensors; torch.tensor(tensor) copy-constructs
        # with a UserWarning.
        self.learner_alpha = torch.as_tensor(self.learner_alpha, dtype=torch.float, device=self.device)
        self.alphas = [torch.as_tensor(a, dtype=torch.float, device=self.device) for a in self.alphas]
        self.teacher_model = [t.to(self.device).eval() for t in self.teacher_model]
        self.learner_model = self.learner_model.to(self.device)

        if self.cfg.measure_pre_transfer_acc:
            self.pre_transfer_acc = self.calculate_per_class_accuracy(self.learner_model)

    def calculate_per_class_accuracy(self, model):
        """
        Calculate per-class accuracy of ``model`` on the datamodule's test dataloader.

        The given model (not necessarily ``self.learner_model``) is switched to eval mode for the pass,
        and its previous train/eval mode is restored afterwards.

        Returns:
        - np.ndarray of length ``num_classes``: correct / total per class. A class with no test
          samples yields NaN (0 / 0).
        """
        was_training = model.training
        model.eval()
        correct = torch.zeros(self.num_classes, device=self.device)
        total = torch.zeros(self.num_classes, device=self.device)

        with torch.no_grad():
            for x, y in self.trainer.datamodule.test_dataloader():
                x = x.to(self.device)
                # Flatten labels: some datasets (e.g. medMNIST) yield shape (B, 1), which would
                # otherwise broadcast against the (B,) predictions into a (B, B) comparison.
                y = y.to(self.device).reshape(-1).long()
                predicted = model(x).argmax(dim=1)
                hits = (predicted == y).float()
                # Vectorised per-class tallies instead of a Python loop over every sample.
                total += torch.bincount(y, minlength=self.num_classes).float()
                correct += torch.bincount(y, weights=hits, minlength=self.num_classes)

        model.train(was_training)
        per_class_accuracy = correct / total
        return per_class_accuracy.cpu().numpy()

    def calculate_client_all_accuracies(self, per_class_acc):
        """
        Calculates various custom accuracy metrics for a specified client including measures before and after knowledge transfer.

        Parameters:
        - cfg: Configuration object containing client and training configurations including `learner_client` and `query_classes`.
        - per_class_acc (list): List of accuracies per class for the specified client after training.
        - data_dists_vectorized (dict): Dictionary mapping clients to their class distribution data.

        Returns:
        - tuple: Contains the calculated accuracies using different strategies, including the effect of forgetting.
        """

        learner_client = self.learner_client

        client_name = f'client_{learner_client}'
        data_dists_vectorized = self.data_dists_vect
        class_distribution = data_dists_vectorized[client_name]
        query_classes = self.goal_class

        # Initialize variables for each strategy
        total_uniform_acc = total_weighted_acc = total_query_class_acc = total_local_class_acc = 0
        count_uniform_classes = total_weight = total_local_weight = 0

        num_classes = len(per_class_acc)
        print(f"num_classes: {num_classes}")
        query_class_acc = [per_class_acc[i] for i in query_classes]

        if self.cfg.measure_pre_transfer_acc:
            pre_transfer_acc = self.pre_transfer_acc
        else:
            train_run_id = self.run_id
            api = wandb.Api()
            train_run = api.run(train_run_id)
            train_run_summary = train_run.summary._json_dict
            pre_transfer_acc = train_run_summary[f'client-{learner_client}/per_class_test_acc']

        query_class_acc_gain = [(per_class_acc[i] - pre_transfer_acc[i]) for i in query_classes]

        for cls_index in range(num_classes):
            if class_distribution[cls_index] > 0 or cls_index in query_classes:
                # Calculate uniform accuracy
                total_uniform_acc += per_class_acc[cls_index]
                count_uniform_classes += 1

                # Simple weighted accuracy (query classes get weight 1)
                weight = 1 if cls_index in query_classes else class_distribution[cls_index] / sum(class_distribution)
                total_weighted_acc += weight * per_class_acc[cls_index]
                total_weight += weight

                # Local classes accuracy excluding query classes
                if cls_index not in query_classes:
                    local_weight = class_distribution[cls_index] / sum(class_distribution)
                    total_local_class_acc += local_weight * per_class_acc[cls_index]
                    total_local_weight += local_weight

        # Compute accuracies for each strategy
        uniform_accuracy = total_uniform_acc / count_uniform_classes if count_uniform_classes > 0 else 0
        simple_weighted_accuracy = total_weighted_acc / total_weight if total_weight > 0 else 0
        query_classes_accuracy = sum(query_class_acc) / len(query_classes) if query_classes else 0
        query_classes_acc_gain = sum(query_class_acc_gain) / len(query_class_acc_gain) if query_classes else 0
        local_classes_accuracy = total_local_class_acc / total_local_weight if total_local_weight > 0 else 0
        forgetting = sum((per_class_acc[j] - pre_transfer_acc[j]) for j in range(len(per_class_acc)) if
                         (per_class_acc[j] - pre_transfer_acc[j]) < 0) / len(
            [accuracy for accuracy in pre_transfer_acc if accuracy > 0])

        print(f"uniform_accuracy: {uniform_accuracy}")
        print(f"simple_weighted_accuracy: {simple_weighted_accuracy}")
        print(f"query_classes_accuracy: {query_classes_accuracy}")
        print(f"query_classes_acc_gain: {query_classes_acc_gain}")
        print(f"local_classes_accuracy: {local_classes_accuracy}")
        print(f"forgetting: {forgetting}")

        return (
            uniform_accuracy, simple_weighted_accuracy, query_classes_accuracy, query_classes_acc_gain,
            local_classes_accuracy,
            forgetting)

    def get_alphas(self, teachers):
        tas = [self.get_per_class_accuracy(teacher) for teacher in teachers]
        la = torch.tensor(self.get_per_class_accuracy(self.learner_client), dtype=torch.float)
        alphas = [torch.nan_to_num(ta / (sum(tas) + la)) for ta in tas]
        return alphas

    def get_alphas_data_free(self, teacher_models):  # mask the binary_class_weights
        """
        Calculate per-teacher alphas restricted to the learner's classes and the goal classes.

        A class is relevant when the learner's per-class validation accuracy is non-zero or it is a
        goal class. Each teacher gets alpha 1 on relevant classes for which its binary class weight
        is 1, and 0 elsewhere. Non-zero goal-class entries are then multiplied by
        ``hparams.goal_class_boost``.

        Parameters:
        teacher_models (list): teacher models passed to ``identify_binary_teacher_weights``.

        Returns:
        list of tensors: one alpha tensor of length ``num_classes`` per teacher.
        """
        print(f">> using (get_alphas_data_free)")
        binary_teacher_weights = self.identify_binary_teacher_weights(teacher_models)
        # get_per_class_accuracy already returns a tensor; as_tensor avoids torch.tensor's copy warning.
        la = torch.as_tensor(self.get_per_class_accuracy(self.learner_client), dtype=torch.float)

        # Mask of relevant classes: learner classes and goal classes.
        mask = torch.zeros(self.num_classes)
        for i in range(self.num_classes):
            if la[i] != 0 or i in self.goal_class:
                mask[i] = 1

        print(f"mask: {mask}")

        alphas = []
        for teacher_idx, teacher_weights in enumerate(binary_teacher_weights):
            alpha = torch.zeros(self.num_classes)
            for i in range(self.num_classes):
                # Per-class trace (num_teachers * num_classes lines): DEBUG level with lazy formatting.
                log.debug("Teacher %s, Class %s, Binary Weight: %s, Mask: %s",
                          teacher_idx, i, teacher_weights[i], mask[i])
                if teacher_weights[i] == 1:
                    alpha[i] = mask[i]
                    log.debug("Setting alpha[%s] for teacher %s to %s", i, teacher_idx, mask[i])

            print(f"alpha: {alpha}")
            alphas.append(alpha)

        # NOTE: normalising by the number of teachers that hold each class (cfg.normalize_alpha) is
        # not applied in this variant.

        # Slightly increase the weight of goal classes a teacher actually covers.
        for alpha in alphas:
            for g in self.goal_class:
                if alpha[g] != 0:
                    alpha[g] *= self.hparams.goal_class_boost

        return alphas

    def get_alphas_data_free2(self, teacher_models):
        """
        Calculate per-teacher alphas, keeping only teachers that cover at least one goal class.

        A class is relevant when the learner's per-class validation accuracy is non-zero or it is a
        goal class. A teacher whose binary class weights contain any goal class gets alpha 1 on the
        relevant classes it covers (binary weight 1), with non-zero goal-class entries multiplied by
        ``hparams.goal_class_boost``. Every other teacher gets an all-zero alpha.

        Parameters:
        teacher_models (list): teacher models passed to ``identify_binary_teacher_weights``.

        Returns:
        alphas (list of tensors): one alpha tensor of length ``num_classes`` per teacher.
        """
        print(f">> using (get_alphas_data_free2)")
        binary_teacher_weights = self.identify_binary_teacher_weights(teacher_models)
        # get_per_class_accuracy already returns a tensor; as_tensor avoids torch.tensor's copy warning.
        la = torch.as_tensor(self.get_per_class_accuracy(self.learner_client), dtype=torch.float)

        # Mask of relevant classes: learner classes and goal classes.
        mask = torch.zeros(self.num_classes)
        for i in range(self.num_classes):
            if la[i] != 0 or i in self.goal_class:
                mask[i] = 1

        print(f"mask: {mask}")

        alphas = []
        for teacher_idx, teacher_weights in enumerate(binary_teacher_weights):
            alpha = torch.zeros(self.num_classes)
            has_goal_class = any(teacher_weights[i] == 1 for i in self.goal_class)

            if has_goal_class:
                for i in range(self.num_classes):
                    # Per-class trace: DEBUG level with lazy formatting to avoid INFO spam.
                    log.debug("Teacher %s, Class %s, Binary Weight: %s, Mask: %s",
                              teacher_idx, i, teacher_weights[i], mask[i])
                    if teacher_weights[i] == 1:
                        alpha[i] = mask[i]
                        log.debug("Setting alpha[%s] for teacher %s to %s", i, teacher_idx, mask[i])

                # Slightly increase the weight of the goal classes this teacher covers.
                for g in self.goal_class:
                    if alpha[g] != 0:
                        alpha[g] *= self.hparams.goal_class_boost
            print(f"alpha: {alpha}")
            alphas.append(alpha)

        return alphas

    def get_alphas_data_free3(self, teacher_models):
        """
        Calculate per-teacher alphas from a shared class mask (the default data-free option).

        Learner classes are those with a non-zero entry in ``la``. ``la`` is the learner's sample
        counts when ``cfg.use_number_of_samples`` is set, otherwise its per-class validation accuracy.
        The mask gives goal classes ``hparams.goal_class_boost``. Unless ``cfg.only_goal_classes`` is
        set, learner classes get 1, which overrides the boost for a goal class the learner already has.
        ``learner_client_mask`` is 1 on learner classes only.

        If any teacher's binary class weights contain a goal class:
        - those teachers and the teacher at index ``learner_client`` get the mask;
        - the ``learner_client`` teacher gets ``learner_client_mask`` instead when both
          ``only_goal_classes`` and ``copy_of_self_as_teacher`` are set;
        - all other teachers get zeros.
        Otherwise every teacher gets the mask and the ``learner_client`` teacher gets ``learner_client_mask``.

        Parameters:
        teacher_models (list): teacher models. NOTE(review): assumes list index == client id, since
            indices are compared with ``learner_client``.

        Returns:
        alphas (list of tensors): one tensor of length ``num_classes`` per teacher.
        """
        binary_teacher_weights = self.identify_binary_teacher_weights(teacher_models)

        if self.cfg.use_number_of_samples:
            # Use data_dists_vect to identify learner classes if configured
            la = torch.tensor(self.data_dists_vect[f'client_{self.learner_client}'])
        else:
            # get_per_class_accuracy already returns a tensor; as_tensor avoids torch.tensor's copy warning.
            la = torch.as_tensor(self.get_per_class_accuracy(self.learner_client), dtype=torch.float)

        only_goal = self.cfg.only_goal_classes
        mask = torch.zeros(self.num_classes)
        learner_client_mask = torch.zeros(self.num_classes)
        for i in range(self.num_classes):
            if i in self.goal_class:
                mask[i] = self.hparams.goal_class_boost
            if la[i] != 0:
                learner_client_mask[i] = 1
                if not only_goal:
                    mask[i] = 1

        # Which teachers were detected to have any goal class (computed once, reused below).
        has_goal_class = [
            any(teacher_weights[i] == 1 for i in self.goal_class)
            for teacher_weights in binary_teacher_weights
        ]

        if any(has_goal_class):
            alphas = []
            for teacher_idx, teacher_has_goal in enumerate(has_goal_class):
                if teacher_has_goal or teacher_idx == self.learner_client:
                    alpha = mask.clone()
                    if only_goal and self.hparams.copy_of_self_as_teacher and teacher_idx == self.learner_client:
                        alpha = learner_client_mask.clone()
                else:
                    alpha = torch.zeros(self.num_classes)
                alphas.append(alpha)
        else:
            # If no teacher has the goal class, include all teachers
            log.info("No teacher with goal class found. Adding all teachers.")
            alphas = [mask.clone() for _ in binary_teacher_weights]
            alphas[self.learner_client] = learner_client_mask.clone()

        used_teachers = [i for i, alpha in enumerate(alphas) if bool((alpha != 0).any())]
        log.info(f"Used teachers: {used_teachers}")

        return alphas

    def get_alphas_data_free4(self, teacher_models):  # alpha is the mask
        """
        Calculates alphas for each teacher model by considering only the classes present in the learner model or the goal classes.
        Classes not present in the learner model or the goal classes are set to zero.
        All teachers have the same alpha, which is the mask indicating relevant classes.
        Goal classes have their weights increased.

        Parameters:
        teachers (list): List of teacher models.

        Returns:
        alphas (list of tensors): List of alpha tensors, one for each teacher, representing the importance of each class.
        """

        print(f">> using (get_alphas_data_free4)")
        binary_teacher_weights = self.identify_binary_teacher_weights(teacher_models)
        la = torch.tensor(self.get_per_class_accuracy(self.learner_client), dtype=torch.float)

        alphas = []

        # Create a mask for the goal classes and learner classes
        mask = torch.zeros(self.num_classes)
        for i in range(self.num_classes):
            if la[i] != 0 or i in self.goal_class:
                mask[i] = 1

        print(f"mask: {mask}")

        # Adjust the mask values for goal classes to slightly increase their weight
        for g in self.goal_class:
            if mask[g] != 0:
                mask[g] *= self.hparams.goal_class_boost

        # Assign the same mask to all teachers if they have any goal class with binary weight 1
        for teacher_idx, teacher_weights in enumerate(binary_teacher_weights):
            # has_goal_class = any(teacher_weights[i] == 1 for i in self.goal_class)
            alpha = mask.clone()
            print(f"alpha for teacher {teacher_idx}: {alpha}")
            alphas.append(alpha)

        return alphas

    def adjust_alphas(self, alphas):
        z = torch.zeros_like(alphas)
        for i, alpha in enumerate(alphas):
            for g in self.goal_class:
                if (alpha[g] != 1 and alpha[g] != 0 and alpha[g] <= 0.8):
                    z[i][g] = alpha[g].clone().detach() + 0.2
                elif (alpha[g] == 1):
                    z[i][g] = 1
                else:
                    z[i][g] = alpha[g].clone().detach()
            for l in range(self.num_classes):
                if (self.learner_alpha[l] != 0 and alpha[l] != 0):
                    z[i][l] = alpha[l].clone().detach()
        return z

    def get_per_class_accuracy(self, client_id):
        """
        Return a client's per-class validation accuracy as a tensor.

        The value is the ``client-{client_id}/val_per_class_acc`` summary entry of the wandb run
        ``self.run_id``, fetched through the public wandb API on every call.
        """
        metric_key = f"client-{client_id}/val_per_class_acc"
        train_run = wandb.Api().run(self.run_id)
        return torch.tensor(train_run.summary[metric_key])

    def training_step(self, batch, batch_idx):
        """Run one distillation step: optional label loss plus temperature-scaled KL to each teacher.

        The label term is KL(one_hot(y) || softmax(student)), which equals cross-entropy for hard
        labels. It is weighted per class by ``learner_alpha`` unless teachers are unweighted or alphas
        are data-free. Each teacher contributes KL(softmax(teacher / T) || softmax(student / T)),
        weighted per class by its alpha unless ``qkt_unweighted_teachers`` is set. Both terms are
        summed over classes and divided by the batch size; the KL term is also multiplied by T^2.

        Returns:
            dict with ``loss``, ``preds`` (argmax predictions) and ``targets`` (1-D labels).
        """
        inputs, targets = batch
        if targets.dim() > 1:
            targets = targets.squeeze(1)  # medMNIST labels arrive as (B, 1); make them 1-D

        learner_logits = self(inputs)
        hard_targets = F.one_hot(targets, self.num_classes).to(torch.float)

        ce = 0
        if self.with_CE:
            ce_terms = F.kl_div(F.log_softmax(learner_logits, dim=1), hard_targets, reduction='none')
            # TODO: revisit the learner weighting for the data-free case.
            if not (self.qkt_unweighted_teachers or self.hparams.alpha_data_free):
                ce_terms = ce_terms * self.learner_alpha
            ce = ce_terms.sum() / ce_terms.size()[0]  # divide by batch size
            self.log(f"{self.exp_name}/learner-ce_loss", ce, on_step=True, on_epoch=False, prog_bar=True)

        divergences = 0
        for teacher_alpha, teacher in zip(self.alphas, self.teacher_model):
            with torch.no_grad():
                teacher_logits = teacher(inputs)

            kl_terms = F.kl_div(
                F.log_softmax(learner_logits / self.T, dim=1),
                F.softmax(teacher_logits / self.T, dim=1),
                reduction='none',
            )
            if not self.qkt_unweighted_teachers:
                kl_terms = kl_terms * teacher_alpha
            # batch mean, multiplied by T^2
            divergences += kl_terms.sum() / kl_terms.size()[0] * self.T * self.T

        self.log(f"{self.exp_name}/learner-kl_loss", divergences, on_step=True, on_epoch=False, prog_bar=True)

        preds = learner_logits.argmax(dim=1)
        self.log(
            f"{self.exp_name}/train_acc",
            self.train_acc(preds, targets), on_step=False, on_epoch=True, prog_bar=False
        )

        loss = ce + divergences if self.with_CE else divergences
        return {"loss": loss, "preds": preds, "targets": targets}

    def get_data_dists_vect(self, cfg):
        """Convert every client's ``train_data_distribution`` into a dense per-class count vector.

        Class keys in the distribution are string class indices; missing or falsy entries count as 0.
        The vectors, the aggregate per-class counts and the grand total are logged.

        Returns:
            dict: client key (as in ``cfg.clients``) -> np.ndarray of length ``self.num_classes``.
        """
        per_class_totals = defaultdict(int)
        vectors = {}
        for client_key, client_info in cfg.clients.items():
            distribution = client_info["train_data_distribution"]
            vectors[client_key] = np.array(
                [distribution.get(str(cls)) or 0 for cls in range(self.num_classes)]
            )
            for class_key, count in distribution.items():
                per_class_totals[class_key] += count

        grand_total = sum(per_class_totals.values())

        log.info(f"data_dists_vectorized: {vectors}")
        log.info(f"total_num_samples_per_class: {per_class_totals}")
        log.info(f"total_num_samples: {grand_total}")

        return vectors

    def validation_epoch_end(self, outputs):
        """Aggregate validation metrics and, with ``detailed_testing``, track best-metric checkpoints.

        Always:
        - updates and logs the best validation accuracy;
        - stores validation accuracy and per-class validation accuracy in the wandb summary. Per-class
          accuracy is the diagonal of the row-normalised confusion matrix.

        With ``detailed_testing`` (skipped during Lightning's sanity check):
        - computes the client metrics via ``calculate_client_all_accuracies``;
        - saves the learner's state_dict to ``trainer.default_root_dir`` whenever a tracked metric
          improves, one file per metric;
        - logs current and best values.
        """
        acc = self.val_acc.compute()  # get val accuracy from current epoch
        self.val_acc_best.update(acc)

        self.log(
            f"{self.exp_name}/best_val_acc",
            self.val_acc_best.compute(), on_epoch=True, prog_bar=True
        )

        confusion_matrix = self.val_confusion_matrix.compute()
        # Normalise each row (true class) by its own total. keepdim=True makes the division row-wise;
        # without it the row sums broadcast along the columns.
        confusion_matrix = confusion_matrix / confusion_matrix.sum(dim=1, keepdim=True)
        self.per_class_val_acc = np.diag(confusion_matrix.cpu().detach().numpy())
        self.private_val_acc = acc

        summary = self.logger.experiment.summary
        summary[f"{self.exp_name}/val_acc"] = acc
        summary[f"{self.exp_name}/val_per_class_acc"] = self.per_class_val_acc.tolist()

        # The sanity-check pass (by default only a couple of batches, before any training) must not
        # seed the best-metric trackers or write "best" checkpoints.
        if not self.detailed_testing or getattr(self.trainer, "sanity_checking", False):
            return

        (uniform_accuracy, simple_weighted_accuracy, query_classes_accuracy,
         query_classes_acc_gain, local_classes_accuracy, forgetting) = self.calculate_client_all_accuracies(
            self.per_class_val_acc)

        root_dir = self.trainer.default_root_dir
        # (tracker attribute, current value, checkpoint file); "<attr>_epoch" records the improving epoch.
        # Every metric is maximised; forgetting is <= 0, so larger means less forgetting.
        trackers = (
            ("best_val_acc", acc, "best_val_acc.ckpt"),
            ("best_query_class_acc_gain", query_classes_acc_gain, "val_best_query_class_acc_gain.ckpt"),
            ("least_forgetting", forgetting, "val_least_forgetting.ckpt"),
            ("best_simple_weighted_accuracy", simple_weighted_accuracy, "val_best_simple_weighted_accuracy.ckpt"),
            ("best_uniform_accuracy", uniform_accuracy, "val_best_uniform_accuracy.ckpt"),
        )
        learner_state = None  # CPU copy of the learner weights, built only if some metric improved
        for attr, value, filename in trackers:
            if value > getattr(self, attr):
                setattr(self, attr, value)
                setattr(self, f"{attr}_epoch", self.current_epoch)
                if learner_state is None:
                    learner_state = {k: v.detach().cpu().clone() for k, v in self.learner_model.state_dict().items()}
                # torch.save overwrites an existing file, so no explicit delete is needed.
                torch.save(learner_state, os.path.join(root_dir, filename))

        summary[f"{self.exp_name}/val_simple_weighted_accuracy"] = simple_weighted_accuracy
        summary[f"{self.exp_name}/val_query_classes_acc_gain"] = query_classes_acc_gain
        summary[f"{self.exp_name}/val_forgetting"] = forgetting
        summary[f"{self.exp_name}/val_uniform_accuracy"] = uniform_accuracy

        log.info(f"{self.exp_name}/val_simple_weighted_accuracy: {simple_weighted_accuracy}")
        log.info(f"{self.exp_name}/val_uniform_accuracy: {uniform_accuracy}")
        log.info(f"{self.exp_name}/val_query_classes_acc_gain: {query_classes_acc_gain}")
        log.info(f"{self.exp_name}/val_forgetting: {forgetting}")
        log.info(f"{self.exp_name}/val_acc: {acc}")
        log.info(f"{self.exp_name}/val_per_class_acc: {self.per_class_val_acc}")

        self.log(f"{self.exp_name}/val_simple_weighted_accuracy", simple_weighted_accuracy, on_epoch=True)
        self.log(f"{self.exp_name}/val_uniform_accuracy", uniform_accuracy, on_epoch=True)
        self.log(f"{self.exp_name}/val_query_classes_acc_gain", query_classes_acc_gain,
                 on_epoch=True)
        self.log(f"{self.exp_name}/val_forgetting", forgetting, on_epoch=True)

        self.log(f"{self.exp_name}/best_val_acc_m", self.best_val_acc, on_epoch=True)  # manual
        self.log(f"{self.exp_name}/val_best_simple_weighted_accuracy", self.best_simple_weighted_accuracy,
                 on_epoch=True)
        self.log(f"{self.exp_name}/val_best_uniform_accuracy", self.best_uniform_accuracy,
                 on_epoch=True)
        self.log(f"{self.exp_name}/val_best_query_class_acc_gain", self.best_query_class_acc_gain, on_epoch=True)
        self.log(f"{self.exp_name}/val_least_forgetting", self.least_forgetting, on_epoch=True)

    def on_fit_end(self):
        """With ``detailed_testing``, record best-metric values and epochs and save the final learner weights.

        The values go to the wandb summary; the learner's state_dict is saved to
        ``<trainer.default_root_dir>/latest.ckpt``.
        """
        if not self.detailed_testing:
            return

        summary = self.logger.experiment.summary
        # Summary key suffix -> value. Keys are unchanged for existing dashboards
        # (note "val_uniform_accuracy_epoch" has no "best_" infix).
        final_metrics = {
            "best_val_acc_epoch": self.best_val_acc_epoch,
            "val_best_query_class_acc_gain_epoch": self.best_query_class_acc_gain_epoch,
            "val_least_forgetting_epoch": self.least_forgetting_epoch,
            "val_best_simple_weighted_accuracy_epoch": self.best_simple_weighted_accuracy_epoch,
            # This attribute may only be assigned when uniform accuracy improves during validation,
            # which can never happen (e.g. no validation run). Default to 0 like the other epoch
            # counters instead of raising AttributeError.
            "val_uniform_accuracy_epoch": getattr(self, "best_uniform_accuracy_epoch", 0),
            "best_val_acc": self.best_val_acc,
            "val_best_query_class_acc_gain": self.best_query_class_acc_gain,
            "val_least_forgetting": self.least_forgetting,
            "val_best_simple_weighted_accuracy": self.best_simple_weighted_accuracy,
            "val_best_uniform_accuracy": self.best_uniform_accuracy,
        }
        for suffix, value in final_metrics.items():
            summary[f"{self.exp_name}/{suffix}"] = value

        # Save the latest model state at the end of training
        latest_ckpt_path = os.path.join(self.trainer.default_root_dir, "latest.ckpt")
        latest_state_dict = {k: v.detach().cpu().clone() for k, v in self.learner_model.state_dict().items()}
        torch.save(latest_state_dict, latest_ckpt_path)

    def on_test_end(self):
        """Reset the best-so-far validation trackers once testing finishes.

        The reset happens here rather than in ``on_fit_end``: per the original author's note,
        resetting on fit end caused issues because the next client re-initialises the neural
        network weights.
        """
        self.val_acc_best.reset()
        # Every tracker is maximised during validation, so -inf means "nothing seen yet".
        tracker_names = (
            "best_val_acc",
            "best_query_class_acc_gain",
            "least_forgetting",
            "best_simple_weighted_accuracy",
            "best_uniform_accuracy",
        )
        for tracker_name in tracker_names:
            setattr(self, tracker_name, float('-inf'))

    def on_save_checkpoint(self, checkpoint):
        """Store the fully resolved run config in ``checkpoint`` under the ``'cfg'`` key."""
        # Interpolations are resolved so the stored config is a plain, self-contained container.
        resolved_cfg = OmegaConf.to_container(self.cfg, resolve=True)
        checkpoint['cfg'] = resolved_cfg

    def on_load_checkpoint(self, checkpoint):
        """Rebuild ``self.cfg`` from the ``'cfg'`` entry that ``on_save_checkpoint`` writes.

        Raises:
            KeyError: if ``checkpoint`` has no ``'cfg'`` entry.
        """
        if 'cfg' not in checkpoint:
            raise KeyError("The checkpoint does not contain 'cfg' key.")
        self.cfg = OmegaConf.create(checkpoint['cfg'])
        # TODO: restore any other state this module needs from ``checkpoint`` here.

    @classmethod
    def load_from_checkpoint(cls, checkpoint_path, cfg=None, **kwargs):
        """Load the module, reconciling an explicit ``cfg`` with the config stored in the checkpoint.

        Config resolution:
          * checkpoint has ``'cfg'`` and ``cfg`` is given -> ``OmegaConf.merge(stored, cfg)``
            (values in ``cfg`` take precedence);
          * only the checkpoint has ``'cfg'`` -> the stored config is used;
          * only ``cfg`` is given -> it is used as-is;
          * neither -> no ``cfg`` kwarg is forwarded (see note below).

        Args:
            checkpoint_path: path to the checkpoint file.
            cfg: optional config overriding the stored one.
            **kwargs: forwarded unchanged to the parent ``load_from_checkpoint``.

        Returns:
            The instance built by the parent ``load_from_checkpoint``.
        """
        # Only the stored config is needed here; load onto CPU so no accelerator is required.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        has_stored_cfg = 'cfg' in checkpoint
        stored_cfg = checkpoint['cfg'] if has_stored_cfg else None
        # The parent re-reads the file itself, so release the (possibly large) loaded state now.
        # (The previous version wrote ``cfg`` back into this local dict, which had no effect.)
        del checkpoint

        if has_stored_cfg:
            stored = OmegaConf.create(stored_cfg)
            cfg = stored if cfg is None else OmegaConf.merge(stored, cfg)

        if cfg is None:
            # Forwarding cfg=None would always fail in __init__ (``cfg.num_classes``). Omitting
            # it presumably lets Lightning fall back to the hyper-parameters saved via
            # ``save_hyperparameters`` — NOTE(review): confirm against the Lightning version in use.
            return super().load_from_checkpoint(checkpoint_path, **kwargs)
        return super().load_from_checkpoint(checkpoint_path, cfg=cfg, **kwargs)

    # def calculate_client_all_accuracies(self, per_class_acc):
    #     """
    #     Calculates various custom accuracy metrics for a specified client including measures before and after knowledge transfer.
    #
    #     Parameters:
    #     - cfg: Configuration object containing client and training configurations including `learner_client` and `query_classes`.
    #     - per_class_acc (list): List of accuracies per class for the specified client after training.
    #     - data_dists_vectorized (dict): Dictionary mapping clients to their class distribution data.
    #
    #     Returns:
    #     - tuple: Contains the calculated accuracies using different strategies, including the effect of forgetting.
    #     """
    #
    #     learner_client = self.learner_client
    #
    #     client_name = f'client_{learner_client}'
    #     data_dists_vectorized = self.data_dists_vect
    #     class_distribution = data_dists_vectorized[client_name]
    #     query_classes = self.goal_class
    #
    #     # Initialize variables for each strategy
    #     total_uniform_acc = total_weighted_acc = total_query_class_acc = total_local_class_acc = 0
    #     count_uniform_classes = total_weight = total_local_weight = 0
    #
    #     num_classes = len(per_class_acc)
    #     print(f"num_classes: {num_classes}")
    #     query_class_acc = [per_class_acc[i] for i in query_classes]
    #
    #     # Fetch pre-knowledge transfer accuracies
    #     train_run_id = self.run_id
    #     api = wandb.Api()
    #     train_run = api.run(train_run_id)
    #     train_run_summary = train_run.summary._json_dict
    #     pre_transfer_acc = train_run_summary[f'client-{learner_client}/per_class_test_acc']
    #
    #     query_class_acc_gain = [(per_class_acc[i] - pre_transfer_acc[i]) for i in query_classes]
    #
    #     for cls_index in range(num_classes):
    #         if class_distribution[cls_index] > 0 or cls_index in query_classes:
    #             # Calculate uniform accuracy
    #             total_uniform_acc += per_class_acc[cls_index]
    #             count_uniform_classes += 1
    #
    #             # Simple weighted accuracy (query classes get weight 1)
    #             weight = 1 if cls_index in query_classes else class_distribution[cls_index] / sum(class_distribution)
    #             total_weighted_acc += weight * per_class_acc[cls_index]
    #             total_weight += weight
    #
    #             # Local classes accuracy excluding query classes
    #             if cls_index not in query_classes:
    #                 local_weight = class_distribution[cls_index] / sum(class_distribution)
    #                 total_local_class_acc += local_weight * per_class_acc[cls_index]
    #                 # print(f"local_weight:{local_weight}")
    #                 total_local_weight += local_weight
    #
    #     # Compute accuracies for each strategy
    #     uniform_accuracy = total_uniform_acc / count_uniform_classes if count_uniform_classes > 0 else 0
    #     simple_weighted_accuracy = total_weighted_acc / total_weight if total_weight > 0 else 0
    #     query_classes_accuracy = sum(query_class_acc) / len(query_classes) if query_classes else 0
    #     query_classes_acc_gain = sum(query_class_acc_gain) / len(query_class_acc_gain) if query_classes else 0
    #     local_classes_accuracy = total_local_class_acc / total_local_weight if total_local_weight > 0 else 0
    #     forgetting = sum((per_class_acc[j] - pre_transfer_acc[j]) for j in range(len(per_class_acc)) if
    #                      (per_class_acc[j] - pre_transfer_acc[j]) < 0) / len(
    #         [accuracy for accuracy in pre_transfer_acc if accuracy > 0])
    #
    #     print(f"uniform_accuracy: {uniform_accuracy}")
    #     print(f"simple_weighted_accuracy: {simple_weighted_accuracy}")
    #     print(f"query_classes_accuracy: {query_classes_accuracy}")
    #     print(f"query_classes_acc_gain: {query_classes_acc_gain}")
    #     print(f"local_classes_accuracy: {local_classes_accuracy}")
    #     print(f"forgetting: {forgetting}")
    #
    #     return (
    #         uniform_accuracy, simple_weighted_accuracy, query_classes_accuracy, query_classes_acc_gain,
    #         local_classes_accuracy,
    #         forgetting)

    # def identify_binary_teacher_weights(self, client_models, threshold= 0.05):
    #     """
    #     Identifies binary weights of each teacher for each class.
    #     A weight of 0 indicates underrepresented classes, and a weight of 1 indicates meaningful presence.
    #     """
    #     binary_teacher_weights = []
    #
    #     for client_id, model in enumerate(client_models):
    #         synthetic_data, synthetic_labels = self.create_anonymized_data_impressions(model)
    #         synthetic_data = synthetic_data.to(next(model.parameters()).device)
    #         output = model(synthetic_data)
    #         predicted_probs = F.softmax(output, dim=1)
    #
    #         teacher_weights = []
    #         avg_predicted_probs = []
    #         for cls in range(predicted_probs.size(1)):
    #             avg_prob = predicted_probs[:, cls].mean().item()
    #             avg_predicted_probs.append(avg_prob)
    #             if avg_prob >= threshold:
    #                 teacher_weights.append(1)
    #             else:
    #                 teacher_weights.append(0)
    #         binary_teacher_weights.append(teacher_weights)
    #
    #         print(f"client's {client_id} avg_predicted_probs: {avg_predicted_probs}")
    #
    #     # Logging the binary_teacher_weights for debugging
    #     for teacher_id, weights in enumerate(binary_teacher_weights):
    #         log.info(f"Teacher {teacher_id} binary weights: {weights}")
    #
    #     return binary_teacher_weights

    def identify_binary_teacher_weights(self, client_models, num_classes=10):
        """Compute a per-class 0/1 weight vector for every teacher model.

        Each model is probed with ``self.detect_stats`` using ``self.noise_threshold``.

        Args:
            client_models: iterable of teacher models; position ``i`` is reported as client ``i``.
            num_classes: unused; kept for backward compatibility with existing callers.

        Returns:
            list[list[int]]: one binary weight list per model, in input order.
        """
        binary_teachers_weights = []
        for client_idx, model in enumerate(client_models):
            print(f"\nAnalyzing client {client_idx}'s model")
            # Identify underrepresented classes
            print("Detected stats:")
            binary_teachers_weights.append(
                self.detect_stats(model, threshold=self.noise_threshold)
            )
        return binary_teachers_weights

    def detect_stats(self, model, threshold=0.01):
        """Flag classes that ``model`` rarely predicts on random-noise inputs.

        The model is run on the noise batch returned by ``create_anonymized_data_impressions``.
        For every output class, the softmax probability is averaged over the batch. Classes
        whose average is below ``threshold`` get weight 0 and are reported as underrepresented;
        all other classes get weight 1.

        Args:
            model: classifier whose output is softmaxed over dim 1 — presumably logits of shape
                (batch, num_classes); TODO confirm.
            threshold: minimum average probability for a class to keep weight 1.

        Returns:
            list[int]: one 0/1 weight per output class.
        """
        synthetic_data, _ = self.create_anonymized_data_impressions(model)
        synthetic_data = synthetic_data.to(next(model.parameters()).device)  # Move data to the same device as the model

        # This is a pure inference probe: skip autograd, and undo any buffer updates the
        # forward pass makes (e.g. BatchNorm running statistics when the model is in train
        # mode) so the noise batch does not leak into the model's state. Buffers are restored
        # only after the forward pass, so the computed output itself is unchanged.
        saved_buffers = {name: buf.detach().clone() for name, buf in model.named_buffers()}
        try:
            with torch.no_grad():
                output = model(synthetic_data)
        finally:
            with torch.no_grad():
                for name, buf in model.named_buffers():
                    original = saved_buffers.get(name)
                    if original is not None:
                        buf.copy_(original)
        predicted_probs = F.softmax(output, dim=1)

        avg_predicted_probs = [
            predicted_probs[:, cls_idx].mean().item() for cls_idx in range(predicted_probs.size(1))
        ]
        binary_teacher_weights = [0 if avg_prob < threshold else 1 for avg_prob in avg_predicted_probs]
        underrepresented_classes = [
            cls_idx for cls_idx, weight in enumerate(binary_teacher_weights) if weight == 0
        ]

        print(f"model's avg_predicted_probs: {avg_predicted_probs}")
        print(f"model's underrepresented_classes: {underrepresented_classes}")
        print(f"model's binary_teacher_weights: {binary_teacher_weights}")

        return binary_teacher_weights

    def create_anonymized_data_impressions(self, model, num_classes=10, num_samples=20,
                                           input_shape=(3, 224, 224)):
        """Build a batch of standard-normal noise inputs for probing a model.

        Args:
            model: unused; kept for backward compatibility with existing callers.
            num_classes: unused; kept for backward compatibility with existing callers.
            num_samples: number of noise samples to generate (default 20, as before).
            input_shape: per-sample shape; adjust to the model's expected input
                (default ``(3, 224, 224)``, as before).

        Returns:
            tuple[torch.Tensor, torch.Tensor]: a float tensor of shape
            ``(num_samples, *input_shape)`` on the default (CPU) device, and an empty
            ``long`` label tensor (no labels are generated).
        """
        sample_shape = tuple(input_shape)  # loop-invariant, computed once
        synthetic_labels = torch.tensor([], dtype=torch.long)
        if num_samples <= 0:
            # torch.stack would raise on an empty list; return an empty, correctly shaped batch.
            return torch.empty((0, *sample_shape)), synthetic_labels

        # One randn call per sample, matching the original RNG consumption so seeded runs
        # draw the same noise as before.
        synthetic_data = [torch.randn(sample_shape) for _ in range(num_samples)]
        return torch.stack(synthetic_data), synthetic_labels

# Marking this version!
# command example (With Data_free_alpha3 goal_class_boost1.5, after being edited to include the learner):
# python QKT/transfer_knowledge_QKT_detailed_2.py teacher_client=[0,1,2,3,4,5,6,7,8,9] learner_client=[0,1,2,3,4,5,6,7,8,9] optim.optim.lr=0.001 trainer.max_epochs=25 KL_temperature=1 KL_loss_strength=-1 logger.group=qkt_mask train_exp_id=nalballa25/qkt/673oypii logger.tags=[num_classes_to_select1,two_stages_KD_fc_exp,max25ep,train_673oypii,all_teachers,alpha_data_free3,goal_class_boost1.5,edited] validate_and_test=True qkt_multi_teachers=True detailed_testing=True two_stage_qkt=True two_stage_starting_point=qkt num_classes_to_select=1 optim.scheduler=False seed=2024 alpha_data_free=True goal_class_boost=1.5