import logging
from typing import Optional

import torch
import torch.nn.functional as F
from pytorch_lightning import LightningModule, seed_everything
from torch import nn
from torchmetrics import Accuracy, FBetaScore, MetricCollection

log = logging.getLogger(__name__)


class SVCModel(nn.Module):
    """Wrap a fitted estimator with a ``predict`` method as a torch module.

    ``forward`` runs ``model.predict`` on a NumPy copy of the input and returns
    the predictions one-hot encoded, on the same device as the input. No
    gradients flow through this module.

    Assumes ``model.predict`` returns integer class indices in ``[0, num_classes)``
    (e.g. a scikit-learn classifier trained on 0-based labels) -- TODO confirm
    for callers that use non-integer labels.

    Args:
        model: object exposing ``predict(ndarray) -> array-like of class indices``.
        num_classes: width of the one-hot output. ``None`` (the default, the
            original behaviour) lets ``F.one_hot`` infer ``max(prediction) + 1``
            per batch, which means the output width may differ between batches.
    """

    def __init__(self, model, num_classes: Optional[int] = None):
        super().__init__()
        self.model = model
        self.num_classes = num_classes

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # detach(): Tensor.numpy() raises on tensors that require grad.
        raw_preds = self.model.predict(x.detach().cpu().numpy())
        # F.one_hot only accepts int64 indices; predict may return int32 arrays.
        preds = torch.as_tensor(raw_preds).long()
        # -1 is F.one_hot's sentinel for "infer the number of classes from the data".
        num_classes = -1 if self.num_classes is None else self.num_classes
        return F.one_hot(preds, num_classes=num_classes).to(x.device)


class LambdaModule(nn.Module):
    """Adapt a plain callable (e.g. a lambda) into an ``nn.Module``.

    Useful for placing arbitrary functions inside module containers such as
    ``nn.Sequential``, which only accept ``nn.Module`` children. The callable is
    stored unchanged and applied to the input on every forward pass.
    """

    def __init__(self, lambda_func) -> None:
        super().__init__()
        self.lambda_func = lambda_func

    def forward(self, x: torch.Tensor):
        wrapped = self.lambda_func
        result = wrapped(x)
        return result


class Classifier(LightningModule):
    """Lightning classifier: an aggregation module followed by a classification head.

    The input is passed through ``aggregation_module`` and then through either a
    single ``nn.Linear`` (``deep=False``) or a Linear-act-Linear-act-Linear MLP
    (``deep=True``). ``forward`` returns log-probabilities (``log_softmax`` over
    dim 1). Train/val/test steps share :meth:`_step`, which computes the NLL loss
    and accuracy / F1 metrics logged as ``{split}_loss``, ``{split}_accuracy`` and
    ``{split}_f1``.

    Args:
        input_dim: size of the features fed to the head (presumably the output
            size of ``aggregation_module`` -- not checked here).
        num_classes: number of target classes.
        lr: Adam learning rate.
        deep: use the 3-layer MLP head instead of a single linear layer.
        seed: passed to ``seed_everything`` before the head layers are created.
        aggregation_module: module applied before the head; its ``encode``,
            ``decode`` and ``get_attention_weights`` are used by the matching
            methods of this class.
        bias: whether the head's linear layers have a bias term.
        x_feature: batch key holding the inputs.
        y_feature: batch key holding the integer class targets.
        first_activation: activation after the first MLP layer (``nn.Module`` or
            any callable; default ``nn.Tanh()``). Only used when ``deep``.
        second_activation: activation after the second MLP layer (``nn.Module``
            or any callable; default ``nn.ReLU()``). Only used when ``deep``.
        first_projection_dim: hidden width of the MLP (default ``input_dim``);
            the second hidden layer has half this width. Only used when ``deep``.
    """

    def __init__(
        self,
        input_dim: int,
        num_classes: int,
        lr: float,
        deep: bool,
        seed: int,
        aggregation_module: nn.Module,
        bias: bool = True,
        x_feature: str = "x",
        y_feature: str = "y",
        first_activation: Optional[nn.Module] = None,
        second_activation: Optional[nn.Module] = None,
        first_projection_dim: Optional[int] = None,
    ):
        super().__init__()

        # Warn only when the caller actually supplied deep-head options that will be ignored.
        if not deep and (
            first_activation is not None or second_activation is not None or first_projection_dim is not None
        ):
            log.warning("If deep is False, first_activation, second_activation and first_projection_dim are not used!")

        seed_everything(seed)

        self.relative_block = aggregation_module

        if deep:
            # Fresh activation instances per classifier (no module objects shared via default args).
            first_activation = self._as_module(nn.Tanh() if first_activation is None else first_activation)
            second_activation = self._as_module(nn.ReLU() if second_activation is None else second_activation)
            first_projection_dim = input_dim if first_projection_dim is None else first_projection_dim
            hidden_dim = first_projection_dim // 2

            # Layer order (and thus state_dict keys 0/2/4) matches the original head.
            self.class_proj = nn.Sequential(
                nn.Linear(input_dim, first_projection_dim, bias=bias),
                first_activation,
                nn.Linear(first_projection_dim, hidden_dim, bias=bias),
                second_activation,
                nn.Linear(hidden_dim, num_classes, bias=bias),
            )
        else:
            self.class_proj = nn.Sequential(nn.Linear(input_dim, num_classes, bias=bias))

        metrics = MetricCollection(
            {
                "accuracy": Accuracy(task="multiclass", num_classes=num_classes, top_k=1),
                "f1": FBetaScore(task="multiclass", num_classes=num_classes, top_k=1),
            }
        )
        # Per-split prefixes so train/val/test metric names don't collide in the logger.
        self.train_metrics = metrics.clone(prefix="train_")
        self.val_metrics = metrics.clone(prefix="val_")
        self.test_metrics = metrics.clone(prefix="test_")

        self.lr: float = lr

        self.x_feature: str = x_feature
        self.y_feature: str = y_feature

    @staticmethod
    def _as_module(fn) -> nn.Module:
        """Return ``fn`` as an ``nn.Module``, wrapping plain callables in ``LambdaModule``.

        Raises:
            TypeError: if ``fn`` is neither an ``nn.Module`` nor callable.
        """
        if isinstance(fn, nn.Module):
            return fn
        if callable(fn):
            return LambdaModule(fn)
        raise TypeError(f"Activation must be an nn.Module or a callable, got {type(fn).__name__}")

    def _head(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the classification head and return log-probabilities over dim 1."""
        return F.log_softmax(self.class_proj(x), dim=1)

    def get_attention_weights(self, encoded_x: torch.Tensor, attention_idx: int):
        """Delegate to ``aggregation_module.get_attention_weights``."""
        return self.relative_block.get_attention_weights(encoded_x, attention_idx=attention_idx)

    def encode(self, x):
        """Delegate to ``aggregation_module.encode``."""
        return self.relative_block.encode(x)

    def decode(self, x):
        """Run ``aggregation_module.decode`` then the head; returns log-probabilities."""
        return self._head(self.relative_block.decode(x))

    def forward(self, x):
        """Run the aggregation module then the head; returns log-probabilities."""
        return self._head(self.relative_block(x))

    def _step(self, batch, split: str):
        """Shared train/val/test step: NLL loss, predictions, and split metrics."""
        targets = batch[self.y_feature]
        log_probs = self(batch[self.x_feature])
        # forward() already returns log_softmax output, so NLL is the matching loss
        # (cross_entropy would redundantly apply log_softmax a second time).
        loss = F.nll_loss(log_probs, targets)
        preds = torch.argmax(log_probs, dim=1)

        metrics = getattr(self, f"{split}_metrics")
        # Calling the collection (forward) accumulates epoch state AND caches the batch
        # value that Lightning needs for step-level logging (the training_step default);
        # update() alone leaves no per-batch value to log.
        metrics(preds, targets)

        self.log(f"{split}_loss", loss, prog_bar=True)
        self.log_dict(metrics, prog_bar=True)

        return loss

    def training_step(self, batch, batch_idx):
        return self._step(batch=batch, split="train")

    def validation_step(self, batch, batch_idx):
        return self._step(batch=batch, split="val")

    def test_step(self, batch, batch_idx):
        return self._step(batch=batch, split="test")

    def configure_optimizers(self):
        """Adam over all parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)
