from typing import Optional, Union
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
import torchmetrics

from ..inference_recording import to_inference_record

# Metric name identifiers. ``_create_metric`` builds the metrics named by
# CONFUSION_MATRIX_METRIC and CLASSWISE_ACCURACY_METRIC; it rejects the other
# names. NOTE(review): confirm where ACCURACY_METRIC / LOSS_METRIC are consumed.
ACCURACY_METRIC = "accuracy"
LOSS_METRIC = "loss"
CLASSWISE_ACCURACY_METRIC = "classwise_accuracy"
CONFUSION_MATRIX_METRIC = "confusion_matrix"

class SupervisedLearning(pl.LightningModule):
    """Lightning module that trains a classifier with cross-entropy loss.

    Each step calls ``model`` on ``batch.input``, converts the result with
    ``to_inference_record`` and compares its ``.output`` to ``batch.target``.

    Args:
        model: The network being trained.
        test_metric_names: Extra metrics to compute during testing, on top of
            the always-present accuracy (see ``_create_metric``). ``None``
            (the default) means no extra metrics.
        classes: Number of classes, or a list of class names. Required when
            ``test_metric_names`` requests any extra metric.
        learning_rate: Learning rate passed to the Adam optimizer.
    """

    def __init__(
        self,
        model: torch.nn.Module,
        test_metric_names: Optional[list[str]] = None,
        classes: Union[int, list[str], None] = None,
        learning_rate: float = 0.001,
    ) -> None:
        super().__init__()

        self.model = model
        self.learning_rate = learning_rate

        self.train_accuracy = torchmetrics.Accuracy()
        # Not updated by any step in this class (test accuracy lives in
        # ``self.test_metrics``); kept so existing references still work.
        self.test_accuracy = torchmetrics.Accuracy()

        test_metrics: dict[str, torchmetrics.Metric] = {
            ACCURACY_METRIC: torchmetrics.Accuracy()
        }
        for metric_name in test_metric_names or []:
            # Accuracy is always included and ``_create_metric`` does not
            # build plain accuracy, so skip names that are already present
            # instead of failing with "Unsupported metric".
            if metric_name in test_metrics:
                continue
            test_metrics[metric_name] = _create_metric(metric_name, classes)
        self.test_metrics = torchmetrics.MetricCollection(test_metrics)

    # TODO: annotate the batch type
    def training_step(self, batch) -> torch.Tensor:
        """Compute the training loss for one batch; log loss and accuracy."""
        output = to_inference_record(self.model(batch.input))
        loss = F.cross_entropy(output.output, batch.target)
        self.log("loss/train", loss, on_step=True, on_epoch=True)

        self.train_accuracy(output.output, batch.target)
        # Logging the Metric object itself (not a tensor) lets Lightning
        # compute and reset it at epoch end.
        self.log(
            "accuracy/train",
            self.train_accuracy,
            on_step=True,
            on_epoch=True,
        )
        return loss

    def validation_step(self, batch, batch_idx: int) -> torch.Tensor:
        """Compute and log the validation loss for one batch."""
        output = to_inference_record(self.model(batch.input))
        loss = F.cross_entropy(output.output, batch.target)
        self.log("loss/val", loss)
        return loss

    def test_step(self, batch, batch_idx: int) -> None:
        """Log the test loss and accumulate the test metric collection.

        Only the loss is logged. ``self.test_metrics`` may contain non-scalar
        metrics (e.g. a confusion matrix) that cannot be passed to
        ``self.log``, so they are only updated here.
        NOTE(review): nothing in this class computes them; presumably the
        caller reads ``self.test_metrics.compute()`` after testing - confirm.
        """
        output = to_inference_record(self.model(batch.input))
        loss = F.cross_entropy(output.output, batch.target)
        self.log("loss/test", loss, on_step=False, on_epoch=True)
        # ``update`` accumulates state without computing per-batch values
        # that would be discarded anyway.
        self.test_metrics.update(output.output, batch.target)

    def configure_optimizers(self):
        """Return an Adam optimizer over the wrapped model's parameters."""
        return torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)

def _create_metric(
    metric_name: str, classes: Union[int, list[str], None] = None,
) -> torchmetrics.Metric:
    """Build the torchmetrics object for one extra test metric.

    Args:
        metric_name: ``CONFUSION_MATRIX_METRIC`` or
            ``CLASSWISE_ACCURACY_METRIC``.
        classes: Number of classes, or a list of class names. When a list is
            given, the metric is wrapped in ``torchmetrics.ClasswiseWrapper``
            using those names as labels.

    Returns:
        The configured metric, possibly wrapped.

    Raises:
        ValueError: If ``classes`` is ``None`` (checked before the name) or
            ``metric_name`` is not supported.
    """
    if classes is None:
        raise ValueError("You need to specify the number of classes")

    if isinstance(classes, int):
        n_classes = classes
    else:
        n_classes = len(classes)

    # Ordered (name, factory) pairs; the first name equal to ``metric_name``
    # wins. Factories are lazy so only the requested metric is constructed.
    builders = (
        (
            CONFUSION_MATRIX_METRIC,
            lambda: torchmetrics.ConfusionMatrix(n_classes),
        ),
        (
            CLASSWISE_ACCURACY_METRIC,
            lambda: torchmetrics.Accuracy(
                average="none", num_classes=n_classes
            ),
        ),
    )
    build = next(
        (factory for name, factory in builders if metric_name == name),
        None,
    )
    if build is None:
        raise ValueError(f"Unsupported metric {metric_name}")

    built = build()
    if not isinstance(classes, list):
        return built
    return torchmetrics.ClasswiseWrapper(built, classes)
