from abc import ABC
from typing import Optional, Dict

import torch
from omegaconf import DictConfig
from torch import nn
from torchmetrics.classification import MulticlassAccuracy

from avr.task.avr_module import AVRModule


class VasrModule(AVRModule, ABC):
    """AVR module that scores candidate answers and trains with cross-entropy.

    An MLP head maps each embedding produced by ``self.model`` to a single
    score; the scores are compared against the target ``y`` with
    ``nn.CrossEntropyLoss`` and tracked by a ``MulticlassAccuracy`` metric
    kept separately for the ``"tr"``, ``"val"`` and ``"test"`` splits.

    NOTE(review): ``self.model``, ``self.logm`` and ``self.logm_type`` are not
    defined in this class; presumably they come from ``AVRModule`` — confirm.
    """

    def __init__(self, cfg: DictConfig):
        """Build the per-split metrics, the prediction head and the loss.

        Args:
            cfg: Configuration; this constructor reads ``cfg.num_answers``
                and ``cfg.avr.model.embedding_size``.
        """
        super().__init__(cfg)

        def build_split_metrics() -> nn.ModuleDict:
            # A fresh metric object per call, so no state is shared
            # between the splits.
            target_acc = MulticlassAccuracy(num_classes=cfg.num_answers)
            return nn.ModuleDict({"acc": nn.ModuleDict({"target": target_acc})})

        self.metrics = nn.ModuleDict(
            {split: build_split_metrics() for split in ("tr", "val", "test")}
        )

        hidden_dim = cfg.avr.model.embedding_size
        self.target_pred_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, 1),
            # Merge the trailing singleton score dimension into the one
            # before it.
            # assumes embeddings are (batch, num_answers, embedding_size), so
            # the head yields (batch, num_answers) — TODO confirm
            nn.Flatten(-2, -1),
        )
        self.target_loss = nn.CrossEntropyLoss()

    def _step(self, split: str, batch, batch_idx: int) -> Dict[str, torch.Tensor]:
        """Run one forward pass, then log the loss and accuracy for ``split``.

        Args:
            split: Key into ``self.metrics``: ``"tr"``, ``"val"`` or ``"test"``.
            batch: A 3-tuple unpacked as ``(context, answers, y)``.
            batch_idx: Index of the batch; not used here.

        Returns:
            A dict with the ``"loss"`` and ``"acc"`` values for this batch.
        """
        context, answers, targets = batch
        num_samples = len(targets)

        # presumably `targets` holds the index of the correct answer per
        # sample — verify against the dataset.
        logits = self.target_pred_head(self.model(context, answers))

        loss = self.target_loss(logits, targets)
        self.logm(loss, "loss", split, batch_size=num_samples)

        accuracy_metric = self.metrics[split]["acc"]["target"]
        acc = accuracy_metric(logits, targets)
        self.logm_type(acc, "acc", split, "target", batch_size=num_samples)

        return {"loss": loss, "acc": acc}

    def training_step(self, batch, batch_idx) -> Dict[str, torch.Tensor]:
        """Delegate to ``_step`` with the ``"tr"`` split."""
        return self._step("tr", batch, batch_idx)

    def validation_step(
        self, batch, batch_idx: int, dataloader_idx: Optional[int] = None
    ) -> Dict[str, torch.Tensor]:
        """Delegate to ``_step`` with the ``"val"`` split.

        ``dataloader_idx`` is accepted but not used.
        """
        return self._step("val", batch, batch_idx)

    def test_step(
        self, batch, batch_idx: int, dataloader_idx: Optional[int] = None
    ) -> Dict[str, torch.Tensor]:
        """Delegate to ``_step`` with the ``"test"`` split.

        ``dataloader_idx`` is accepted but not used.
        """
        return self._step("test", batch, batch_idx)
