from dataclasses import dataclass
from typing import override

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm.rich import tqdm

from ..data import Dataset
from ..utils import get_device
from . import Predictor, PredictorConfig, register_predictor


@dataclass
class SimpleNetConfig(PredictorConfig):
    """Hyperparameters for the single-layer ``SimpleNet`` predictor."""

    # Number of input features; used as in_features of the Linear layer.
    input_dim: int = 10
    # Number of model outputs; used as out_features of the Linear layer.
    output_dim: int = 1
    # Number of full-batch SGD steps performed by SimpleNet.train.
    epochs: int = 100
    # Learning rate passed to torch.optim.SGD.
    lr: float = 0.01


class SimpleNet(Predictor):
    """Single-layer sigmoid network trained with full-batch SGD on an MSE loss.

    The model is ``Linear(input_dim, output_dim)`` followed by ``Sigmoid``, so
    every output is squashed into the range [0, 1].
    """

    def __init__(self, config: SimpleNetConfig):
        """Build the model, loss and optimizer on the device from ``get_device()``.

        Args:
            config: Layer sizes, number of training epochs and learning rate.
        """
        self.config = config
        self.device = get_device()

        self.model = nn.Sequential(
            nn.Linear(config.input_dim, config.output_dim), nn.Sigmoid()
        ).to(self.device)
        self.criterion = nn.MSELoss()
        self.optimizer = optim.SGD(self.model.parameters(), lr=config.lr)

    def train(self, X: np.ndarray, y: np.ndarray) -> None:
        """Run ``config.epochs`` full-batch SGD steps on ``(X, y)``.

        Args:
            X: Feature matrix, presumably shaped ``(n_samples, input_dim)``.
            y: Targets shaped like the model output,
                ``(n_samples, output_dim)``. When ``output_dim == 1`` a 1-D
                ``(n_samples,)`` target vector is also accepted.
        """
        # predict() leaves the model in eval mode; switch back explicitly so
        # fit -> predict -> fit cycles always train in training mode.
        self.model.train()

        X_tensor = torch.as_tensor(X, dtype=torch.float32, device=self.device)
        y_tensor = torch.as_tensor(y, dtype=torch.float32, device=self.device)
        if y_tensor.ndim == 1 and self.config.output_dim == 1:
            # The model emits (n, 1). A (n,) target would broadcast against it
            # to (n, n) inside MSELoss and silently optimise the wrong loss.
            y_tensor = y_tensor.unsqueeze(-1)

        # Fixed left padding for the progress-bar label (presumably to align
        # it with other console output).
        tab = " " * 17
        for _ in tqdm(range(self.config.epochs), desc=tab + "Training SimpleNet"):
            self.optimizer.zero_grad()
            outputs = self.model(X_tensor)
            loss = self.criterion(outputs, y_tensor)
            loss.backward()
            self.optimizer.step()

    @override
    def fit(self, dataset: "Dataset") -> None:
        """Fit SimpleNet to the ``(X, y)`` pair returned by ``dataset.load_predictor()``."""
        X, y = dataset.load_predictor()
        self.train(X, y)

    @override
    def predict(self, X: np.ndarray) -> np.ndarray:
        """Return the model's outputs for ``X`` as a float32 NumPy array.

        Switches the model to eval mode and disables gradient tracking. For a
        2-D ``X`` the result has shape ``(n_samples, output_dim)``; values have
        passed through a sigmoid and so lie in [0, 1].
        """
        self.model.eval()
        with torch.no_grad():
            X_tensor = torch.as_tensor(X, dtype=torch.float32, device=self.device)
            predictions = self.model(X_tensor)
        return predictions.cpu().numpy()


# Register this predictor under the key "simple_net", pairing the config class
# with its implementation (presumably so it can be looked up and built by name).
register_predictor("simple_net", SimpleNetConfig, SimpleNet)
