from abc import ABC, abstractmethod

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

from src.solvers.solver import Solver
from src.utils.strings import *


class LODLSurrogate(ABC, torch.nn.Module):
    """Abstract, trainable surrogate of a solver's cost around a fixed reference target ``y_true``.

    Typical use: :meth:`sample` draws Gaussian perturbations of ``y_true``, solves the problem for
    each perturbed target and records the cost of that solution evaluated under ``y_true``;
    :meth:`train_surrogate` then fits the module (through :meth:`forward`) to regress those costs
    with a mean-squared-error loss. Subclasses only need to implement :meth:`forward`.
    """

    def __init__(self, y_true: torch.Tensor):
        """Store the reference target.

        :param y_true: reference target; ``y_true.shape[0]`` is used as the dimensionality of the
            perturbations drawn in :meth:`sample` (presumably a 1-D vector — TODO confirm).
        """
        super().__init__()

        self._y_true = y_true
        self._dim = y_true.shape[0]

    def sample(self, perturbation_sigma: float, perturbation_samples: int, x_true: np.ndarray,
               solver: "Solver", problem_params: dict[str, np.ndarray]) -> tuple[torch.Tensor, torch.Tensor]:
        """Draw perturbed targets around ``y_true`` and the cost of the solver's decision for each.

        Each sample is ``y_true + perturbation_sigma * N(0, I)``. For each sample the solver is run
        on it, and the resulting solution is scored against the *true* target ``y_true`` via
        ``solver.compute_metrics(...)[TOTAL_COST]``.

        :param perturbation_sigma: standard deviation of the Gaussian noise.
        :param perturbation_samples: number of perturbed targets to draw (0 yields empty tensors).
        :param x_true: input features forwarded unchanged to ``solver.solve``.
        :param solver: solver used to compute decisions and their metrics. Its call counter is
            frozen for the duration of this method (and always unfrozen again, even on error).
        :param problem_params: extra problem parameters forwarded to the solver.
        :return: ``(y_samples, costs_samples)`` built with ``torch.Tensor`` (default float dtype);
            ``y_samples`` has shape ``(perturbation_samples, dim)``, ``costs_samples`` holds one
            entry per sample (assuming the TOTAL_COST metric is a scalar).
        """
        solver.freeze_calls_count()
        try:
            # Host-side NumPy copy of the target; detach/cpu keep this working for tensors that
            # track gradients or live on an accelerator. Computed once instead of per sample.
            y_true = self._y_true.detach().cpu().numpy()

            loc = np.zeros(self._dim)
            scale = np.ones(self._dim)
            size = (perturbation_samples, self._dim)
            noise = perturbation_sigma * np.random.normal(loc=loc, scale=scale, size=size)

            # The target is added exactly once (it used to be added twice, centring samples on
            # 2 * y_true). Broadcasting also handles perturbation_samples == 0 gracefully.
            y_samples = y_true + noise

            costs_samples = []
            for y_sample in y_samples:
                solution, _ = solver.solve(x_true, y_sample, problem_params)
                cost = solver.compute_metrics(y_true, solution, problem_params)[TOTAL_COST]
                costs_samples.append(cost)
        finally:
            # Always restore the solver's call counting, even if solve/compute_metrics raised.
            solver.unfreeze_calls_count()

        return torch.Tensor(y_samples), torch.Tensor(costs_samples)

    def train_surrogate(self, y_samples: torch.Tensor, costs_samples: torch.Tensor,
                        epochs: int = 15, batch_size: int = 128, lr: float = 1.0,
                        device: torch.device = torch.device("cpu"), verbose: bool = False) -> None:
        """Fit the surrogate so that ``self(y)`` regresses the sampled costs (MSE, Adam).

        The module is moved to ``device`` (so its parameters match the batches), put in training
        mode for the duration of the fit, and left in evaluation mode afterwards.

        :param y_samples: perturbed targets, one per row (e.g. the first output of :meth:`sample`).
        :param costs_samples: costs aligned with ``y_samples`` (e.g. the second output of :meth:`sample`).
        :param epochs: number of passes over the data.
        :param batch_size: mini-batch size; the data is reshuffled every epoch.
        :param lr: Adam learning rate.
        :param device: device on which the module and each batch are placed.
        :param verbose: if True, print the mean batch loss after each epoch.
        """
        # Parameters must live on the same device as the batches moved there below; previously only
        # the batches were moved, which crashed for any device other than the module's own.
        self.to(device)
        self.train(True)

        mse_loss = torch.nn.MSELoss()
        optimizer = torch.optim.Adam(self.parameters(), lr=lr)

        dataset = TensorDataset(y_samples, costs_samples)
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

        for _ in range(epochs):
            epoch_loss = 0.0

            for y, cost in loader:
                y = y.to(device)
                cost = cost.to(device)

                optimizer.zero_grad()
                # MSELoss already reduces to the batch mean, so no extra averaging is needed.
                loss = mse_loss(self(y), cost)
                loss.backward()
                optimizer.step()

                epoch_loss += loss.item()

            # Mean of the per-batch losses for this epoch.
            epoch_loss /= len(loader)

            if verbose:
                print("Epoch loss: {}".format(epoch_loss))

        self.eval()

    @abstractmethod
    def forward(self, y_hat: torch.Tensor) -> torch.Tensor:
        """Predict the cost for each row of ``y_hat``.

        The output is compared with a batch of sampled costs through ``torch.nn.MSELoss`` in
        :meth:`train_surrogate`, so it should have the same shape as that cost batch
        (presumably ``(batch,)`` — TODO confirm for each subclass).
        """
        pass
