from abc import abstractmethod

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

from learning.criterion import Criterion
from learning.optimizer import Optimizer


class BaseModel(nn.Module):
    """Shared training / inference scaffolding for the project's torch models.

    Subclasses implement :meth:`forward`. Training settings (the torch
    optimizer, batch size, epoch budget, tolerance, weight-init scheme and
    RNG) are read from ``optimizer_instance``; loss computation is delegated
    to ``criterion_instance``.
    """

    def __init__(
        self,
        optimizer_instance: Optimizer,
        criterion_instance: Criterion,
    ):
        super().__init__()
        self.optimizer_instance = optimizer_instance
        self.criterion_instance = criterion_instance
        # Hand this model to the criterion; what it reads from the model is
        # defined by Criterion.set_from_model.
        self.criterion_instance.set_from_model(self)

    def apply_weight_initialization(self):
        """Re-initialise every ``nn.Linear`` sub-layer with the configured scheme.

        The scheme name comes from ``optimizer_instance.get_weight_init()``.
        A falsy value leaves PyTorch's default initialisation untouched.
        Recognised names:

        * ``"zeros"`` -- all-zero weights
        * ``"xavier"`` -- ``nn.init.xavier_uniform_``
        * ``"uniform"`` -- ``nn.init.uniform_`` with its defaults (U(0, 1))
        * ``"normal"`` -- N(0, 0.01**2)

        For any truthy scheme, existing biases are zeroed. An unrecognised
        name therefore zeroes biases but keeps the default weights.

        If ``optimizer_instance.random_state`` is truthy, torch's global RNG
        is first seeded with ``random_state.randint(0, 10000)``.
        """
        weight_init = self.optimizer_instance.get_weight_init()
        if not weight_init:
            return

        random_state = self.optimizer_instance.random_state
        if random_state:
            # Derive the torch seed from the project RNG so that weight
            # initialisation follows the configured random state.
            torch.manual_seed(random_state.randint(0, 10000))

        for layer in self.modules():
            if not isinstance(layer, nn.Linear):
                continue
            # nn.Linear(..., bias=False) stores bias as None; zeros_(None) fails.
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)
            if weight_init == "zeros":
                nn.init.zeros_(layer.weight)
            elif weight_init == "xavier":
                nn.init.xavier_uniform_(layer.weight)
            elif weight_init == "uniform":
                nn.init.uniform_(layer.weight)
            elif weight_init == "normal":
                nn.init.normal_(layer.weight, mean=0, std=0.01)

    def _batch_loss(self, batch_X, batch_y):
        """Run ``forward`` on ``batch_X`` and return the criterion loss vs ``batch_y``.

        This is used by both the LBFGS closure path and the plain-step path
        so that the two paths cannot diverge in how targets are shaped.
        """
        y_pred = self.forward(batch_X)
        if self.criterion_instance.criterion_name == "CrossEntropy":
            # Multiclass: logits [N, C], targets [N] (Long class indices).
            return self.criterion_instance.compute_loss(
                y_pred, batch_y.long().reshape(-1)
            )
        # Other criteria compare flattened predictions with float targets.
        return self.criterion_instance.compute_loss(
            y_pred.reshape(-1), batch_y.reshape(-1).float()
        )

    def fit(self, X_train, y_train, verbose=0):
        """Train the model in place on ``(X_train, y_train)``.

        Args:
            X_train: Input tensor whose first dimension indexes samples.
            y_train: Target tensor aligned with ``X_train`` along dim 0.
            verbose: If > 0, print the epoch-average loss about ten times
                over the epoch budget (every epoch when the budget is < 10).

        Training runs for ``optimizer_instance.epochs`` epochs when that is
        not None, otherwise for up to ``optimizer_instance.max_epochs``. In
        both cases it stops early once the epoch-average loss changes by less
        than ``optimizer_instance.tol``. When ``optimizer_instance.batch_size``
        is set, mini-batches are shuffled each epoch; otherwise a single
        full batch is used.

        Raises:
            ValueError: If ``X_train`` contains no samples.
        """
        if len(X_train) == 0:
            raise ValueError("fit() requires at least one training sample")

        dataset = TensorDataset(X_train, y_train)
        batch_size = self.optimizer_instance.batch_size
        if batch_size is not None:
            dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        else:
            # Full-batch training: one unshuffled batch per epoch.
            dataloader = DataLoader(dataset, batch_size=len(X_train), shuffle=False)

        total_epochs = (
            self.optimizer_instance.epochs
            if self.optimizer_instance.epochs is not None
            else self.optimizer_instance.max_epochs
        )
        # max(1, ...) avoids a modulo-by-zero for budgets shorter than 10 epochs.
        log_every = max(1, total_epochs // 10)

        optimizer = self.optimizer_instance.optimizer
        # LBFGS may re-evaluate the loss several times per step, so it must
        # be driven through a closure.
        uses_closure = isinstance(optimizer, optim.LBFGS)

        previous_loss = float("inf")
        for epoch in range(total_epochs):
            self.train()
            running_loss = 0.0

            for batch_X, batch_y in dataloader:
                if uses_closure:

                    def closure():
                        optimizer.zero_grad()
                        closure_loss = self._batch_loss(batch_X, batch_y)
                        closure_loss.backward()
                        return closure_loss

                    loss = optimizer.step(closure)
                else:
                    optimizer.zero_grad()
                    loss = self._batch_loss(batch_X, batch_y)
                    loss.backward()
                    optimizer.step()

                running_loss += loss.item()

            avg_loss = running_loss / len(dataloader)

            if verbose > 0 and epoch % log_every == 0:
                print(f"Epoch [{epoch + 1}/{total_epochs}], Loss: {avg_loss:.4f}")

            if abs(previous_loss - avg_loss) < self.optimizer_instance.tol:
                break

            previous_loss = avg_loss

    def predict(self, X_test):
        """Return ``forward(X_test)`` in eval mode, without gradients.

        Singleton dimensions are squeezed out of the result. No activation is
        applied here; the output is whatever ``forward`` produces.
        """
        self.eval()
        with torch.no_grad():
            y_prob = self.forward(X_test).squeeze()
            return y_prob

    @abstractmethod
    def forward(self, x):
        """Compute the model output for ``x``; subclasses must override.

        ``nn.Module`` does not use ``ABCMeta``, so ``@abstractmethod`` alone
        does not stop instantiation. Raising here surfaces a missing
        override immediately instead of silently returning None.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement forward()"
        )


class BaseClassifier(BaseModel):
    """Base class for classification models.

    Marks instances with ``is_classifier = True`` and stores ``num_class``
    as given (presumably the number of target classes).
    """

    def __init__(
        self,
        optimizer_instance: Optimizer,
        criterion_instance: Criterion,
        num_class: int,
    ) -> None:
        super().__init__(optimizer_instance, criterion_instance)
        # Classifier-specific attributes, set after the shared BaseModel setup.
        self.num_class = num_class
        self.is_classifier = True


class BaseRegressor(BaseModel):
    """Base class for regression models; marks instances with ``is_classifier = False``."""

    def __init__(
        self,
        optimizer_instance: Optimizer,
        criterion_instance: Criterion,
    ) -> None:
        super().__init__(optimizer_instance, criterion_instance)
        # Regressor flag, set after the shared BaseModel setup.
        self.is_classifier = False