import logging
import os
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from urllib.parse import urlparse

import pytorch_lightning as pl
import torch
import torch.utils.data
import torchmetrics
from torch import Tensor, nn, optim
from torch.optim import lr_scheduler as torch_lr_scheduler
from torchvision import transforms
from torchvision.models.feature_extraction import create_feature_extractor

from src import config

# Module-level logger; BaseModelWrapper uses it for weight download/loading messages.
logger = logging.getLogger(__name__)


class BaseModelWrapper(pl.LightningModule):
    """LightningModule wrapper around an arbitrary ``nn.Module``.

    Adds optional pretrained-weight download/loading, optional intermediate
    feature extraction (torchvision ``create_feature_extractor``), shared
    train/val/test steps logging loss and accuracy, and optimizer / LR-scheduler
    construction from a class plus keyword arguments.

    Args:
        model: Network to wrap; ``forward`` delegates to it.
        features_nodes: Node names handed to ``create_feature_extractor``. When
            ``None`` no feature extractor is built.
        input_dim: Shape of a single input sample (no batch dimension). Used for
            ``example_input_array`` and the TensorBoard graph in ``on_fit_start``.
        test_transforms: Stored on the instance; not used by this class.
        train_transforms: Stored on the instance; not used by this class.
        download: Whether a missing weight file may be downloaded from ``url``.
        url: Location of a state dict to load into ``model``.
        root: Directory under which weight files are cached.
        model_name: Sub-directory of ``root`` for this model's weights. When
            ``None`` the weights are cached directly in ``root``.
        criterion: Loss module. A fresh ``nn.CrossEntropyLoss()`` when ``None``.
        optimizer_cls: Optimizer class instantiated in ``configure_optimizers``.
        lr_scheduler_cls: LR scheduler class, or ``None`` for no scheduler.
        optimizer_kwargs: Keyword arguments for ``optimizer_cls``.
        lr_scheduler_kwargs: Keyword arguments for ``lr_scheduler_cls``.
    """

    def __init__(
        self,
        model: Optional[torch.nn.Module] = None,
        features_nodes: Optional[List[str]] = None,
        input_dim: Optional[Tuple] = None,
        test_transforms: Optional[transforms.Compose] = None,
        train_transforms: Optional[transforms.Compose] = None,
        download: bool = True,
        url: Optional[str] = None,
        root: str = config.CHECKPOINTS_DIR,
        model_name: Optional[str] = None,
        criterion: Optional[nn.Module] = None,
        optimizer_cls: Type[optim.Optimizer] = optim.SGD,
        lr_scheduler_cls: Optional[
            Type[torch_lr_scheduler._LRScheduler]
        ] = torch_lr_scheduler.CosineAnnealingLR,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
        lr_scheduler_kwargs: Optional[Dict[str, Any]] = None,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        # Build fresh objects per instance instead of sharing mutable defaults.
        # save_hyperparameters() below reads these locals from this frame, so it
        # records the normalized (non-None) values.
        if criterion is None:
            criterion = nn.CrossEntropyLoss()
        if optimizer_kwargs is None:
            optimizer_kwargs = {}
        if lr_scheduler_kwargs is None:
            lr_scheduler_kwargs = {}

        super().__init__(*args, **kwargs)
        self.model = model
        self.input_dim = input_dim
        self.features_nodes = features_nodes
        self.test_transforms = test_transforms
        self.train_transforms = train_transforms
        self.download = download
        self.url = url
        self.root = root
        self.model_name = model_name

        self.criterion = criterion
        self.optimizer_cls = optimizer_cls
        self.lr_scheduler_cls = lr_scheduler_cls

        self.example_input_array = (
            torch.randn(1, *self.input_dim).to(self.device, dtype=torch.float32)
            if self.input_dim is not None
            else None
        )

        self.save_hyperparameters(ignore=["model", "criterion"])
        self.__post_init__()

    def __post_init__(self):
        """Load pretrained weights, then build the feature extractor if requested."""
        self.setup()

        if self.features_nodes is not None:
            self.feature_extractor = create_feature_extractor(
                self.model, self.features_nodes
            )
        else:
            self.feature_extractor = None

    def setup(self, stage=None):
        """Download (if allowed) and load pretrained weights from ``url``.

        Lightning calls this hook again at the start of every fit/validate/
        test/predict run. Weights are loaded only on the first call, so later
        stages do not overwrite weights that were trained in between.
        """
        if getattr(self, "_pretrained_loaded", False):
            return
        self._load_pretrained_weights()
        self._pretrained_loaded = True

    def _load_pretrained_weights(self) -> None:
        """Resolve the cache path for ``url``, download if needed, and load it.

        Raises:
            FileNotFoundError: if ``url`` is set but the weight file is neither
                cached nor downloadable (``download=False``).
        """
        # Fall back to ``root`` when no model name is given, instead of failing
        # inside os.path.join(root, None).
        if self.model_name is None:
            model_dir = self.root
        else:
            model_dir = os.path.join(self.root, self.model_name)
        os.makedirs(model_dir, exist_ok=True)

        if self.url is None:
            return

        filename = os.path.basename(urlparse(self.url).path)
        # The same path is used for downloading and loading. Joining with the raw
        # (absolute) URL path would discard ``model_dir`` entirely.
        cached_file = os.path.join(model_dir, filename)

        if not os.path.exists(cached_file) and self.download:
            logger.info('Downloading: "%s" to %s', self.url, cached_file)
            torch.hub.download_url_to_file(self.url, cached_file)

        if not os.path.exists(cached_file):
            raise FileNotFoundError("Cached file not found: {}".format(cached_file))

        logger.info("loading weights from cached file: %s", cached_file)
        state_dict = torch.load(cached_file, map_location=self.device)
        self.model.load_state_dict(state_dict)

    def on_fit_start(self):
        """Log the model graph to the experiment if it supports ``add_graph``.

        Skipped when ``input_dim`` is unknown, when there is no logger, or when
        the logger's experiment has no ``add_graph`` (i.e. is not TensorBoard-like).
        """
        if self.input_dim is None:
            return
        experiment = getattr(self.logger, "experiment", None)
        add_graph = getattr(experiment, "add_graph", None)
        if add_graph is None:
            return
        prototype_array = torch.randn(1, *self.input_dim).to(self.device)
        add_graph(self.model, prototype_array)

    def _batch_extraction(self, batch):
        """Split ``batch`` into ``(inputs, targets)``; override for other layouts."""
        x, y = batch
        return x, y

    def _shared_step(self, batch, stage: str) -> Dict[str, Tensor]:
        """Compute and log loss/accuracy for ``batch`` under the ``stage/`` prefix.

        Returns:
            ``{"<stage>/loss": loss, "<stage>/acc": acc}``.
        """
        x, y = self._batch_extraction(batch)
        y_hat = self.forward(x)
        loss = self.criterion(y_hat, y)
        # NOTE(review): torchmetrics >= 0.11 requires a `task=` argument here;
        # confirm against the installed version.
        acc = torchmetrics.functional.accuracy(y_hat, y)
        results = {f"{stage}/loss": loss, f"{stage}/acc": acc}
        self.log_dict(
            results,
            prog_bar=True,
            on_epoch=True,
            on_step=False,
            rank_zero_only=True,
        )
        return results

    def training_step(self, batch, batch_idx):
        """Log train loss/accuracy and return the loss for backpropagation."""
        return self._shared_step(batch, "train")["train/loss"]

    def validation_step(self, batch, batch_idx):
        """Log and return validation loss/accuracy."""
        return self._shared_step(batch, "val")

    def test_step(self, batch, batch_idx):
        """Log and return test loss/accuracy."""
        return self._shared_step(batch, "test")

    def predict_step(self, batch, batch_idx):
        """Return the model outputs alongside the batch targets."""
        x, y = self._batch_extraction(batch)
        preds = self.forward(x)

        return {"logits": preds, "targets": y}

    def forward(self, x: Tensor, *args: Any, **kwargs: Any) -> Tensor:
        """Delegate to the wrapped model."""
        return self.model(x, *args, **kwargs)

    def extract_features(self, x: Tensor) -> Dict[str, Tensor]:
        """Return the activations of ``features_nodes`` for ``x``.

        Runs the extractor in eval mode without gradients. torchvision's
        extractor reuses the wrapped model's layers, so every submodule's
        previous train/eval flag is restored afterwards rather than leaving the
        model stuck in eval mode.

        Raises:
            AttributeError: if the wrapper was built without ``features_nodes``.
        """
        extractor = getattr(self, "feature_extractor", None)
        if extractor is None:
            raise AttributeError(
                "Feature extractor is not defined. Please define the `features_nodes` argument."
            )
        previous_modes = [(module, module.training) for module in extractor.modules()]
        extractor.eval()
        try:
            with torch.no_grad():
                return extractor(x)
        finally:
            for module, mode in previous_modes:
                module.training = mode

    def configure_optimizers(self):
        """Build the optimizer (and scheduler, if configured) over trainable params.

        Returns:
            ``([optimizer], [lr_scheduler])`` when ``lr_scheduler_cls`` is set,
            otherwise ``[optimizer]``.
        """
        optimizer_kwargs = self.hparams.get("optimizer_kwargs") or {}
        trainable_params = (p for p in self.parameters() if p.requires_grad)
        optimizer = self.optimizer_cls(trainable_params, **optimizer_kwargs)
        if self.lr_scheduler_cls is None:
            return [optimizer]
        lr_scheduler_kwargs = self.hparams.get("lr_scheduler_kwargs") or {}
        lr_scheduler = self.lr_scheduler_cls(optimizer, **lr_scheduler_kwargs)
        return [optimizer], [lr_scheduler]

    def freeze_bn(self):
        """Put every ``nn.BatchNorm2d`` in eval mode and stop tracking running stats."""
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
                m.track_running_stats = False
