from abc import ABC, abstractmethod

import torch
import numpy as np
from matplotlib import pyplot as plt


class Baseline(torch.nn.Module, ABC):
    """Abstract module that maps a ``(prediction, target)`` pair to a value.

    Subclasses implement :meth:`forward`; :meth:`plot` visualises the result
    for a range of scalar predictions against a fixed target.
    """

    @abstractmethod
    def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute the baseline value for ``prediction`` given ``target``."""

    def _evaluate_curve(self, x: np.ndarray, target: float) -> np.ndarray:
        """Evaluate the baseline at every scalar in ``x`` against ``target``.

        Args:
            x: Scalar prediction values to evaluate, one call per entry.
            target: Scalar target, passed to ``forward`` as a one-element tensor.

        Returns:
            Array holding one baseline output per entry of ``x``.
        """
        # ``forward`` is declared to take tensors, so wrap the scalar target
        # instead of handing over a bare Python float.
        target_tensor = torch.tensor([float(target)])

        outputs = []
        # Only values are needed here, so skip autograd bookkeeping.
        with torch.no_grad():
            for x_value in x:
                value = self(torch.tensor([float(x_value)]), target_tensor)
                if isinstance(value, torch.Tensor):
                    # detach/cpu keep the conversion valid for tensors that
                    # require grad or live on another device.
                    outputs.append(value.detach().cpu().numpy())
                else:
                    # Some implementations return a plain Python number,
                    # which has no ``.numpy()`` method.
                    outputs.append(np.asarray(value))
        return np.array(outputs)

    def plot(self) -> None:
        """Plot the baseline for predictions in ``[0, 10]`` against target ``5.0``.

        Samples 100 evenly spaced prediction values and displays the resulting
        curve with matplotlib.
        """
        target = 5.0

        x = np.linspace(0, 10, 100)
        y = self._evaluate_curve(x, target)

        plt.plot(x, y)
        plt.show()


class NullBaseline(Baseline):
    """Baseline that ignores its inputs and always evaluates to zero."""

    def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> float:
        """Return the constant ``0.0`` regardless of ``prediction`` and ``target``.

        Note that the result is a plain Python float, not a tensor.
        """
        zero = 0.0
        return zero


class EuclideanBaseline(Baseline):
    """Baseline given by the Euclidean distance between prediction and target."""

    def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the square root of the summed squared differences.

        The sum runs over all elements of ``prediction - target``, so the
        result is a single value.
        """
        difference = prediction - target
        squared_total = torch.sum(torch.pow(difference, 2))
        return torch.sqrt(squared_total)


class QuadraticBaseline(Baseline):
    """Baseline given by the summed squared difference of prediction and target."""

    def forward(self, prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return the sum over all elements of ``(prediction - target) ** 2``."""
        residual = prediction - target
        return torch.sum(torch.square(residual))
