from torchmetrics import MaxMetric
from torchmetrics import Accuracy
from torchmetrics import ConfusionMatrix
import pytorch_lightning as pl
import torch
from torch import nn
from omegaconf import OmegaConf
from hydra.utils import instantiate
import numpy as np
import wandb
from pathlib import Path
import os
import logging
import wandb
import copy


# Module-level logger, named after this module (``__name__``).
log = logging.getLogger(name=__name__)

class N_ClassificationModel(pl.LightningModule):
    """LightningModule wrapping a classifier trained per client and per round.

    The wrapped network is either ``learner_model`` (training continues from it)
    or a new one instantiated from ``cfg.model`` via Hydra. Logged metric names
    embed the current client index (and, for most of them, ``cfg.round``).

    Args:
        cfg: Config object. Keys read here: ``learner_client``,
            ``visionClassification``, ``model`` (plus ``model.num_classes`` when
            ``visionClassification`` is truthy), ``round``, ``currentE``,
            ``track_round``; ``configure_optimizers`` reads ``optim.optim``.
        learner_model: Optional pre-built network to keep training. When
            ``None``, a new model is instantiated from ``cfg.model``.
    """

    def __init__(self, cfg, learner_model: nn.Module = None):
        super().__init__()
        self.save_hyperparameters(cfg)
        # Registered as a submodule, so its weights also appear in state_dict
        # under this name; kept as-is for checkpoint compatibility.
        self.learner_model = learner_model

        # Last validation batch's targets (debug aid, see validation_step).
        self.val_targets = []

        # `is not None` instead of truthiness: container modules such as an
        # empty nn.Sequential define __len__ and would evaluate as False.
        if learner_model is not None:
            self.model: nn.Module = learner_model
            self.model.requires_grad_(requires_grad=True)
            self.model.train()
            log.info("using the provided model as a starting point")
        else:
            self.model = instantiate(self.hparams.model)
            log.info("initializing a new model..")

        log.info(f"model type: {type(self.model)}")

        # Index of the client currently being trained; advanced by next_client().
        self.current_client_idx = cfg.learner_client

        self.loss = nn.CrossEntropyLoss()

        # Separate metric instances for train, val and test so each stage
        # reduces over its own epoch.
        self.train_acc = Accuracy()
        self.val_acc = Accuracy()
        self.test_acc = Accuracy()

        # Latest epoch-level results, kept as attributes (presumably read by the
        # caller after fit/test — confirm).
        self.model_test_acc = 0
        self.model_val_acc = 0
        self.per_class_test_acc = []
        self.per_class_val_acc = []

        # Per-class accuracy needs a confusion matrix, which needs num_classes;
        # cfg.model.num_classes is only read for vision configs. Otherwise the
        # attributes are None and the step/epoch hooks skip them instead of
        # raising AttributeError.
        if cfg.visionClassification:
            self.num_classes = cfg.model.num_classes
            log.info(f"self.num_classes = {self.num_classes}")  # TODO: remove
            # TODO: infer num_classes automatically instead of reading it from cfg.
            self.train_confusion_matrix = ConfusionMatrix(num_classes=self.num_classes)
            self.val_confusion_matrix = ConfusionMatrix(num_classes=self.num_classes)
            self.test_confusion_matrix = ConfusionMatrix(num_classes=self.num_classes)
        else:
            self.num_classes = None
            self.train_confusion_matrix = None
            self.val_confusion_matrix = None
            self.test_confusion_matrix = None

        # Best-so-far validation accuracy across epochs.
        self.val_acc_best = MaxMetric()
        self.round = cfg.round
        self.currentE = cfg.currentE
        self.track_round = cfg.track_round

    def next_client(self):
        """Advance to the next client index.

        Only the counter changes: the wrapped model is neither saved nor
        re-initialised here, so training continues from the current weights.
        """
        # TODO: model checkpointing per client (local file + W&B artifact) was
        # disabled here; re-add if per-client snapshots are needed.
        self.current_client_idx += 1

    def forward(self, x):
        return self.model(x)

    def step(self, batch):
        """Run the model on ``(x, y)`` and return ``(loss, argmax predictions, y)``."""
        x, y = batch
        logits = self(x)
        loss = self.loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        return loss, preds, y

    @staticmethod
    def _per_class_accuracy(confusion_matrix: torch.Tensor) -> np.ndarray:
        """Return per-class accuracy (recall) from a confusion matrix.

        Rows index the true class and columns the predicted class, so the
        accuracy of class ``i`` is ``cm[i, i] / cm[i, :].sum()``. A class with
        no samples gives ``nan`` (0 / 0).
        """
        per_class = torch.diagonal(confusion_matrix) / confusion_matrix.sum(dim=1)
        return per_class.cpu().detach().numpy()

    def _log_confusion_matrix_plot(self, key: str, outputs) -> None:
        """Log a W&B confusion-matrix plot built from one epoch's step outputs.

        ``outputs`` is the list of dicts returned by the step method; each
        must contain ``"targets"`` and ``"preds"`` tensors.
        """
        y_true = np.concatenate([output["targets"].cpu().numpy() for output in outputs]).ravel()
        preds = np.concatenate([output["preds"].cpu().numpy() for output in outputs]).ravel()
        self.logger.experiment.log(
            {
                key: wandb.plot.confusion_matrix(
                    probs=None,
                    y_true=y_true,
                    preds=preds,
                    class_names=None,
                )
            },
            commit=False,
        )

    def training_step(self, batch, batch_idx):
        loss, preds, targets = self.step(batch)

        acc = self.train_acc(preds, targets)
        self.log(f"train_client-{self.current_client_idx}-round{self.round}/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
        self.log(f"train_client-{self.current_client_idx}-round{self.round}/acc", acc, on_step=False, on_epoch=True, prog_bar=True)
        # "loss" must be returned for backpropagation; preds/targets are
        # returned for any epoch-end consumer.
        return {"loss": loss, "preds": preds, "targets": targets}

    def validation_step(self, batch, batch_idx):
        loss, preds, targets = self.step(batch)
        if self.val_confusion_matrix is not None:
            self.val_confusion_matrix(preds, targets)
        self.val_targets = targets  # to check in case fair vs. unfair validation #TODO: remove this

        acc = self.val_acc(preds, targets)
        log.info(f"val_acc: {acc}")
        self.log(f"val_client-{self.current_client_idx}/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
        self.log(f"val_client-{self.current_client_idx}/acc", acc, on_step=False, on_epoch=True, prog_bar=True)
        return {"loss": loss, "preds": preds, "targets": targets}

    def validation_epoch_end(self, outputs):
        """Aggregate validation metrics, log them and reset the val metrics."""
        self.acc = self.val_acc.compute()  # accuracy over the whole epoch
        self.val_acc_best.update(self.acc)
        self.model_val_acc = self.acc

        has_confusion_matrix = self.val_confusion_matrix is not None
        if has_confusion_matrix:
            self.per_class_val_acc = self._per_class_accuracy(self.val_confusion_matrix.compute())

        self.log(f"val_client-{self.current_client_idx}/acc_best", self.val_acc_best.compute(), on_epoch=True,
                 prog_bar=True)

        log.info(f"per_class_val_acc = {self.per_class_val_acc}")
        log.info(f">> val_acc = {self.acc}")
        log.info(f"val targets = {self.val_targets}") # last validation batch, to compare TODO: remove this

        self._log_confusion_matrix_plot(
            f"val_client-{self.current_client_idx}-round{self.round}/confusion_matrix", outputs
        )

        if has_confusion_matrix:
            # Keyed by the *current* client (like every other metric) rather
            # than the initial cfg.learner_client, which goes stale after
            # next_client().
            self.logger.experiment.summary[
                f"client-{self.current_client_idx}/val_per_class_acc"
            ] = self.per_class_val_acc

        self.val_acc.reset()
        if has_confusion_matrix:
            self.val_confusion_matrix.reset()

    def test_step(self, batch, batch_idx):
        loss, preds, targets = self.step(batch)
        if self.test_confusion_matrix is not None:
            self.test_confusion_matrix(preds, targets)

        acc = self.test_acc(preds, targets)
        self.log(f"test_client-{self.current_client_idx}-round{self.round}/loss", loss, on_step=False, on_epoch=True)
        self.log(f"test_client-{self.current_client_idx}-round{self.round}/acc", acc, on_step=False, on_epoch=True)
        self.log(f"client-{self.current_client_idx}_best-test-acc-round{self.round}", acc, on_step=False, on_epoch=True)
        return {"loss": loss, "preds": preds, "targets": targets}

    def test_epoch_end(self, outputs):
        """Aggregate test metrics, log them and reset the test metrics."""
        acc = self.test_acc.compute()
        self.model_test_acc = acc

        has_confusion_matrix = self.test_confusion_matrix is not None
        if has_confusion_matrix:
            self.per_class_test_acc = self._per_class_accuracy(self.test_confusion_matrix.compute())

        self._log_confusion_matrix_plot(
            f"test_client-{self.current_client_idx}-round{self.round}/confusion_matrix", outputs
        )

        if has_confusion_matrix:
            self.logger.experiment.summary[
                f"client-{self.current_client_idx}-round{self.round}/test_per_class_test_acc"
            ] = self.per_class_test_acc

        log.info(
            f"client-{self.current_client_idx}-round{self.round}/test_per_class_test_acc: {self.per_class_test_acc}")

        # Record the final test accuracy as a hyperparameter on the last logger.
        self.loggers[-1].log_hyperparams({"test_acc": acc.cpu().item()})
        self.test_acc.reset()
        if has_confusion_matrix:
            self.test_confusion_matrix.reset()

    def on_epoch_end(self):
        """Reset the accuracy metrics at the end of every epoch."""
        self.test_acc.reset()
        self.val_acc.reset()
        self.train_acc.reset()
        # TODO: when self.track_round is set, snapshot a copy of self.model at
        # the halfway local epoch (self.currentE / 2); previous attempt disabled.

    def configure_optimizers(self):
        """Instantiate the optimizer from ``hparams.optim.optim`` over the model's parameters."""
        optim = instantiate(config=self.hparams.optim.optim, params=self.model.parameters())
        return optim

