from typing import Tuple
from black import List
import torch
from torch import Tensor
import pytorch_lightning as pl
import torch.nn.functional as F
import torchvision
from models.base_model import (
    ShapesBaseModel,
)
from typing import Optional, Dict, Any


class ShapesModule(ShapesBaseModel):
    """Model for training / evaluating shapes with an unmodified ResNet-50.

    The network keeps its original 1000-way classification head, so this
    module is only valid when the task has exactly 1000 classes.

    Args:
        learning_rate: stored as ``self.learning_rate``; not used inside this
            class (presumably read by the base class's optimizer setup --
            TODO confirm)
        optimizer: optimizer name, stored as ``self.optimizer``
        momentum: stored as ``self.momentum``
        weight_decay: stored as ``self.weight_decay``
        pretrained: use a pretrained model
        top_k: list of k values for computing top_k accuracy.
            Example: [1, 5] -> top_1 and top_5 accuracies are computed
        num_classes: number of classes to use (forwarded to the base class)
        datamodule: forwarded to the base class
    """

    def __init__(
        self,
        learning_rate: float = 1e-1,
        optimizer: str = "adam",
        momentum: float = 0.9,
        weight_decay: float = 1e-4,
        pretrained: bool = True,
        top_k: Tuple[int, ...] = (1, 10),
        num_classes: Optional[int] = None,
        datamodule: Optional[pl.LightningDataModule] = None,
    ):
        super().__init__(top_k=top_k, num_classes=num_classes, datamodule=datamodule)

        self.learning_rate = learning_rate
        self.optimizer = optimizer
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.pretrained = pretrained

        # `load_backbone` is overridden by subclasses, so it must run after
        # the attributes it reads (`pretrained`, `num_classes`) are set.
        self.backbone = self.load_backbone()
        self.model = self.backbone

    def load_backbone(self):
        """Build the ResNet-50 used as the full model (head included).

        Raises:
            AssertionError: if ``self.num_classes`` is not 1000, because the
                network is created with a fixed 1000-way output layer.
        """
        # The original check was inverted (`!= 1000`): it rejected the only
        # valid class count and accepted every invalid one.
        assert self.num_classes == 1000, "num classes for a ResNet must be 1000"
        resnet = torchvision.models.resnet50(
            pretrained=self.pretrained, num_classes=1000
        )
        return resnet

    def forward(self, x):
        """Return the output of ``self.model`` for the input batch ``x``."""
        return self.model(x)


class ShapesFineTuneModule(ShapesModule):
    """Fine-tunes all weights of the network beneath a newly created linear head."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Fresh head: maps the 2048-wide backbone features to class scores.
        head = torch.nn.Linear(2048, self.num_classes)
        self.linear_classifier = head

    def load_backbone(self):
        """Return a ResNet-50 with its last child module removed."""
        full_resnet = torchvision.models.resnet50(
            pretrained=self.pretrained, num_classes=1000
        )
        # Keep every child except the final one (the built-in classifier);
        # `self.linear_classifier` takes its place in `forward`.
        kept_layers = list(full_resnet.children())[:-1]
        return torch.nn.Sequential(*kept_layers)

    def forward(self, x):
        """Run the backbone, drop the two trailing axes, apply the head."""
        pooled = self.backbone(x)
        # Two squeezes remove the trailing size-1 axes left by the backbone.
        feats = pooled.squeeze(-1)
        feats = feats.squeeze(-1)
        return self.linear_classifier(feats)


class ShapesLinearModule(ShapesModule):
    """Trains only a new linear classifier; the backbone runs without gradients."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Fresh head: maps the 2048-wide backbone features to class scores.
        head = torch.nn.Linear(2048, self.num_classes)
        self.linear_classifier = head

    def load_backbone(self):
        """Return a ResNet-50 with its last child module removed."""
        full_resnet = torchvision.models.resnet50(
            pretrained=self.pretrained, num_classes=1000
        )
        # Keep every child except the final one (the built-in classifier).
        kept_layers = list(full_resnet.children())[:-1]
        return torch.nn.Sequential(*kept_layers)

    def on_train_epoch_start(self) -> None:
        """Switch the backbone to evaluation mode at the start of each epoch."""
        self.backbone.train(False)

    def forward(self, x):
        """Extract features without tracking gradients, then apply the head."""
        # Only the linear head participates in autograd.
        with torch.no_grad():
            pooled = self.backbone(x)
            feats = pooled.squeeze(-1).squeeze(-1)
        return self.linear_classifier(feats)


class ShapesLinearModulePredictions(ShapesLinearModule):
    """Linear-probe module that also records per-sample outputs.

    Each call to ``shared_step`` appends the batch's targets, model outputs,
    poses and image paths to ``self.results[stage]`` as plain Python lists.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # results[stage][name] -> list accumulated across batches
        self.results: Dict[str, Any] = dict()
        self.reset_results()

    def reset_results(self):
        """Give every stage in ``self.test_loader_names`` an empty store."""
        # One new dict per stage on purpose: `dict.fromkeys(names, dict())`
        # would bind the *same* dict object to every key, so the stages
        # would overwrite each other.
        for stage in self.test_loader_names:
            self.results[stage] = {}

    def shared_step(
        self, batch: Tuple[Tensor, Tensor, Dict[str, Any]], stage: str = "train"
    ):
        """Compute and log loss / top-k accuracy, then record the batch outputs.

        Args:
            batch: ``(x, y, attributes)``; ``attributes`` must provide
                ``"pose"`` (indexable, three tensors are read) and
                ``"image_path"``.
            stage: prefix for logged metric names and key into
                ``self.results``; it must already exist there (see
                ``reset_results``), otherwise a ``KeyError`` is raised.

        Returns:
            The cross-entropy loss of the batch.
        """
        x, y, attributes = batch
        y_hat = self(x)
        batch_size = x.shape[0]
        loss = F.cross_entropy(y_hat, y)
        self.log(f"{stage}_loss", loss, sync_dist=True, batch_size=batch_size)
        for k in self.top_k:
            accuracy_metric = getattr(self, f"{stage}_top_{k}_accuracy")
            accuracy_metric(F.softmax(y_hat, dim=-1), y)
            self.log(
                f"{stage}_top_{k}_accuracy",
                accuracy_metric,
                prog_bar=True,
                sync_dist=True,
                on_epoch=True,
                on_step=False,
                # loader names are used instead
                add_dataloader_idx=False,
                batch_size=batch_size,
            )

        # Put the three pose components side by side along a new axis 1
        # (same result as unsqueezing each at axis 1 and concatenating).
        pose = attributes["pose"]
        poses = torch.stack([pose[0], pose[1], pose[2]], dim=1)
        image_path = attributes["image_path"]

        self.save_step_results(y, y_hat, poses, image_path, stage)

        return loss

    def save_step_results(
        self, y: Tensor, y_hat: Tensor, poses: Tensor, image_path, stage: str
    ) -> None:
        """Append one batch of targets, outputs, poses and paths for ``stage``."""
        self._save_value(y, "y", stage)
        self._save_value(y_hat, "y_hat", stage)
        # NOTE: poses are stored under the key "fov"; kept as-is because
        # code reading `self.results` may rely on that key.
        self._save_value(poses, "fov", stage)
        self._save_img_path(image_path, "image_path", stage)

    def _save_img_path(self, value: List[str], name: str, stage: str):
        """Saving function specific to list of str"""
        # Always extend a list owned by `self.results`. Storing `value`
        # itself would alias the caller's list and later mutate it in place.
        self.results[stage].setdefault(name, []).extend(value)

    def _save_value(self, value: Tensor, name: str, stage: str):
        """Append the rows of ``value`` (as Python lists) to the stage store."""
        # Extend the stored list directly. Rebuilding a tensor from the whole
        # history on every step is quadratic and rounds previously stored
        # float64 values to float32.
        rows = value.detach().cpu().tolist()
        self.results[stage].setdefault(name, []).extend(rows)
