import abc
import math

import torch
import influence.model_fast
import influence.train_fast


class TorchArchitecture(object, metaclass=abc.ABCMeta):
    """Abstract interface for a trainable torch architecture.

    A concrete subclass describes how data should be prepared for it (padding,
    dtype, target encoding), how to build and train its model, and which named
    submodules are of interest for influence and representation analysis.
    """

    @property
    @abc.abstractmethod
    def pad_amount(self) -> int:
        """Padding amount associated with this architecture (applied by the caller)."""

    @property
    @abc.abstractmethod
    def model_dtype(self) -> torch.dtype:
        """Floating-point dtype the model is meant to run in."""

    @property
    @abc.abstractmethod
    def one_hot_targets(self) -> bool:
        """True if ``train`` expects one-hot targets, False for class-index targets."""

    @abc.abstractmethod
    def build_model(self, train_images: torch.Tensor, device: torch.device) -> torch.nn.Module:
        """Construct a fresh, untrained model placed on ``device``.

        ``train_images`` is available so implementations can derive sizes from it.
        """

    @abc.abstractmethod
    def train(
        self,
        train_images: torch.Tensor,
        train_targets: torch.Tensor,
        crop_size: int,
        net: torch.nn.Module,
        drop_last_batch: bool,  # usually False for influence, True for evals; may be ignored
        generator: torch.Generator = None,
    ) -> torch.nn.Module:
        """Train ``net`` on the given images/targets and return the trained module.

        ``generator`` is an optional RNG source; implementations decide how it is used.
        """

    @abc.abstractmethod
    def influence_modules(self) -> list[str]:
        """Dotted names of the submodules to use for influence computation."""

    @abc.abstractmethod
    def representation_module(self) -> str:
        """Dotted name of the submodule used as the model's representation."""


def lecun_trunc_normal_(tensor: torch.Tensor) -> None:
    """Fill ``tensor`` in place with JAX/Flax-style truncated LeCun-normal values.

    Mirrors ``jax.nn.initializers.lecun_normal()``, i.e.
    ``variance_scaling(1.0, "fan_in", "truncated_normal")``. Samples are drawn
    from a normal distribution truncated at +/-2 of its own standard deviations.
    That standard deviation is inflated by ``1 / 0.8796...`` so the *truncated*
    distribution ends up with variance ``1 / fan_in``. Without this correction
    the weights would be about 12% smaller than JAX's.

    ``fan_in`` follows the torch weight layout ``(out, in, *kernel)``. It is
    ``in * prod(kernel)``, which reduces to ``shape[1]`` for ``Linear`` weights.

    Raises:
        ValueError: if ``tensor`` has fewer than 2 dimensions (no fan-in defined).
    """
    if tensor.dim() < 2:
        raise ValueError(
            f"fan_in is undefined for a tensor with {tensor.dim()} dimension(s); need at least 2"
        )
    # Standard deviation of a standard normal truncated to (-2, 2); same constant JAX uses.
    truncated_std_of_unit_normal = 0.87962566103423978
    fan_in = tensor.shape[1] * math.prod(tensor.shape[2:])
    std = 1.0 / math.sqrt(fan_in) / truncated_std_of_unit_normal
    torch.nn.init.trunc_normal_(tensor, mean=0.0, std=std, a=-2 * std, b=2 * std)


class MLP(torch.nn.Module):
    """MLP that aims to mimic the JAX MLP architecture, including initialization.

    Architecture: flatten -> Linear(input_dim, width) -> ReLU -> Linear(width, num_classes).
    Both weight matrices are initialized with ``lecun_trunc_normal_`` and both
    biases with zeros (see ``reset_parameters``).

    Args:
        input_dim: Number of features in one flattened input example.
        width: Hidden-layer width.
        num_classes: Number of output logits. Defaults to 10, the previously
            hard-coded value, so existing callers are unaffected.
    """

    def __init__(self, input_dim: int, width: int, num_classes: int = 10):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_dim, width, bias=True)
        self.fc2 = torch.nn.Linear(width, num_classes, bias=True)

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialize weights (truncated LeCun normal) and zero the biases.

        Layers are processed in order fc1, fc2, so the global RNG is consumed
        in the same sequence as a layer-by-layer initialization.
        """
        for layer in (self.fc1, self.fc2):
            lecun_trunc_normal_(layer.weight)
            if layer.bias is not None:
                torch.nn.init.zeros_(layer.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return logits of shape ``(batch, num_classes)`` for ``x`` of shape ``(batch, ...)``."""
        # flatten(start_dim=1) keeps the batch dimension. Unlike reshape((batch, -1)),
        # it also works for an empty batch, where the -1 would be ambiguous.
        x = torch.flatten(x, start_dim=1)
        return self.fc2(torch.relu(self.fc1(x)))


class MLPTorchArchitecture(TorchArchitecture):
    """Mimics the JAX MLP architecture, including initialization and training.

    Builds an ``MLP`` and trains it with SGD + momentum on a mean cross-entropy
    loss over class-index targets (``one_hot_targets`` is False).

    Args:
        mlp_width: Hidden-layer width of the MLP.
        learning_rate: SGD learning rate.
        momentum: SGD momentum.
        num_epochs: Number of full passes over the training set.
        batch_size: Minibatch size.
    """

    def __init__(
        self,
        mlp_width: int,
        learning_rate: float,
        momentum: float,
        num_epochs: int,
        batch_size: int,
    ):
        self.mlp_width = mlp_width
        self.learning_rate = learning_rate
        self.momentum = momentum
        self.num_epochs = num_epochs
        self.batch_size = batch_size

    @property
    def pad_amount(self) -> int:
        """No padding is used for the MLP."""
        return 0

    @property
    def model_dtype(self) -> torch.dtype:
        return torch.float32

    @property
    def one_hot_targets(self) -> bool:
        """Targets are class indices, as required by ``cross_entropy`` in ``train``."""
        return False

    def build_model(self, train_images: torch.Tensor, device: torch.device) -> torch.nn.Module:
        """Build an ``MLP`` sized for ``train_images`` and move it to ``device``.

        The input dimension is the number of elements in one example, i.e. the
        product of every dimension after the batch dimension. For 4-D
        ``(N, C, H, W)`` input this is ``C * H * W``.
        """
        input_size = math.prod(train_images.shape[1:])

        return MLP(input_dim=input_size, width=self.mlp_width).to(device)

    @staticmethod
    def _cpu_generator(generator: torch.Generator = None) -> torch.Generator:
        """Return a CPU generator usable by ``DataLoader`` shuffling, derived from ``generator``.

        ``DataLoader``'s sampler draws its permutation on the CPU, so a non-CPU
        generator is converted by seeding a fresh CPU generator from it.
        ``None`` is passed through, in which case the global RNG is used.
        """
        if generator is None or generator.device.type == "cpu":
            return generator
        seed = torch.randint(2**62, (1,), generator=generator, device=generator.device).item()
        return torch.Generator().manual_seed(seed)

    def train(
        self,
        train_images: torch.Tensor,
        train_targets: torch.Tensor,
        crop_size: int,
        net: torch.nn.Module,
        drop_last_batch: bool,
        generator: torch.Generator = None,
    ) -> torch.nn.Module:
        """Train ``net`` in place with minibatch SGD and return it in eval mode.

        Args:
            train_images: Training inputs; the first dimension indexes examples.
            train_targets: Class-index targets aligned with ``train_images``.
            crop_size: Must equal ``train_images.shape[2]`` (cropping is unsupported).
            net: Model to train; typically the output of ``build_model``.
            drop_last_batch: Whether to drop a final incomplete batch each epoch.
            generator: Optional RNG controlling the per-epoch shuffle order. It
                makes shuffling reproducible; with ``None`` the global RNG is used.

        Returns:
            The same ``net`` instance, switched to eval mode.

        Raises:
            NotImplementedError: if ``crop_size`` differs from ``train_images.shape[2]``.
        """
        if crop_size != train_images.shape[2]:
            raise NotImplementedError("Crop size not supported")

        # Plain reshuffled minibatches over the whole dataset; whether the final
        # partial batch is kept is controlled by `drop_last_batch`.
        train_loader = torch.utils.data.DataLoader(
            torch.utils.data.TensorDataset(train_images, train_targets),
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=drop_last_batch,
            generator=self._cpu_generator(generator),
        )

        net.train()
        optimizer = torch.optim.SGD(net.parameters(), lr=self.learning_rate, momentum=self.momentum)

        for _ in range(self.num_epochs):
            for batch_images, batch_targets in train_loader:
                optimizer.zero_grad()
                batch_predictions = net(batch_images)
                loss = torch.nn.functional.cross_entropy(batch_predictions, batch_targets, reduction="mean")
                loss.backward()
                optimizer.step()

        net.eval()
        return net

    def influence_modules(self) -> list[str]:
        """Both linear layers are used for influence computation."""
        return ["fc1", "fc2"]

    def representation_module(self) -> str:
        """The output layer ``fc2`` is used as the representation."""
        return "fc2"


class HLBTorchArchitecture(TorchArchitecture):
    """Architecture backed by ``influence.model_fast`` (model) and ``influence.train_fast`` (training).

    Args:
        base_depth: Forwarded to ``influence.model_fast.make_net``.
        num_epochs: Forwarded to ``influence.train_fast.train``. It is annotated as
            float, so presumably fractional epoch counts are allowed; confirm in train_fast.
    """

    def __init__(self, base_depth: int, num_epochs: float):
        self.base_depth, self.num_epochs = base_depth, num_epochs

    @property
    def pad_amount(self) -> int:
        """Looked up in ``influence.model_fast.hyp`` each time it is accessed."""
        hyperparams = influence.model_fast.hyp
        return hyperparams["net"]["pad_amount"]

    @property
    def model_dtype(self) -> torch.dtype:
        return torch.float16

    @property
    def one_hot_targets(self) -> bool:
        return True

    def build_model(self, train_images: torch.Tensor, device: torch.device) -> torch.nn.Module:
        """Delegate model construction to ``influence.model_fast.make_net``."""
        return influence.model_fast.make_net(train_images, device, base_depth=self.base_depth)

    def train(
        self,
        train_images: torch.Tensor,
        train_targets: torch.Tensor,
        crop_size: int,
        net: torch.nn.Module,
        drop_last_batch: bool,  # ignored
        generator: torch.Generator = None,
    ) -> torch.nn.Module:
        """Delegate training to ``influence.train_fast.train``.

        ``drop_last_batch`` is accepted for interface compatibility but is not
        forwarded.
        """
        return influence.train_fast.train(
            train_images,
            train_targets,
            crop_size,
            net,
            num_epochs=self.num_epochs,
            generator=generator,
        )

    def influence_modules(self) -> list[str]:
        """Return the dotted names of the submodules used for influence computation.

        The list holds conv1, conv2, norm1 and norm2 of each of the three conv
        groups in the EMA network, in group order, followed by the final linear
        layer. The following are not included: ``initial_block.whiten``, the
        ``conv_group_N`` containers themselves, and the ``conv_group_N.activ``
        activations.
        """
        prefix = "net_ema.net_dict"
        groups = ("conv_group_1", "conv_group_2", "conv_group_3")
        layers = ("conv1", "conv2", "norm1", "norm2")
        names = [f"{prefix}.{group}.{layer}" for group in groups for layer in layers]
        names.append(f"{prefix}.linear")
        return names

    def representation_module(self) -> str:
        """Return the name of the module used for representation (second conv of group 2)."""
        return "net_ema.net_dict.conv_group_2.conv2"
