import os
import pandas as pd
import torch

from matplotlib import pyplot as plt


class Benchmark:
    """Reference data from a benchmark CSV and helpers to compare a network to it.

    The CSV must provide the columns ``x``, ``y``, ``u0``, ``u05`` and ``u1``
    (the reference field ``u`` at t = 0, 0.5 and 1).  Fields are reshaped to
    ``(len_x, len_x)``, so the rows are assumed to form a square grid whose
    side equals the number of distinct ``x`` values.

    The CSV is read lazily on first use and cached in ``df_benchmark``; call
    :meth:`get_csv` explicitly to force a re-read (e.g. after changing
    ``folder_path`` / ``experiment_csv``).
    """

    # Defaults live on the class so that subclasses which override
    # ``__init__`` without calling ``super().__init__()`` keep working.
    folder_path = "./benchmark/"
    experiment_csv = "heat1/heat1.csv"
    df_benchmark = None
    len_x = None

    # (CSV column suffix, time value fed to the network) for each snapshot.
    _SNAPSHOTS = (("0", 0.0), ("05", 0.5), ("1", 1.0))

    def __init__(self, folder_path=None, experiment_csv=None):
        """Create a benchmark; no file is read until data is first needed.

        Args:
            folder_path: Directory containing the benchmark data.  ``None``
                keeps the default ``"./benchmark/"``.
            experiment_csv: CSV path relative to ``folder_path``.  ``None``
                keeps the default ``"heat1/heat1.csv"``.
        """
        # NOTE(review): presumably the diffusion coefficient of the heat
        # problem; it is not used anywhere in this class -- confirm.
        self.D = 5e-3
        if folder_path is not None:
            self.folder_path = folder_path
        if experiment_csv is not None:
            self.experiment_csv = experiment_csv

    def __call__(self, func, val):
        """Return the CSV column named ``f"{func}{val}"`` as a NumPy array.

        Args:
            func: Column prefix, e.g. ``"u"``.
            val: Column suffix, e.g. ``"0"``, ``"05"`` or ``"1"``.

        Raises:
            KeyError: If the CSV has no such column.
        """
        self._ensure_loaded()
        return self.df_benchmark[f"{func}{val}"].values

    def get_csv(self):
        """(Re)load the benchmark CSV into ``df_benchmark`` and update ``len_x``."""
        self.df_benchmark = pd.read_csv(
            os.path.join(self.folder_path, self.experiment_csv)
        )
        self.len_x = len(self.df_benchmark["x"].unique())

    def _ensure_loaded(self):
        """Load the CSV only if no data has been loaded yet."""
        if self.df_benchmark is None:
            self.get_csv()

    def _network_input(self, t):
        """Build the ``(N, 3)`` float tensor ``[x, y, t]`` for every CSV row.

        ``t`` is a scalar time that is broadcast to all rows.  The returned
        tensor has ``requires_grad`` enabled, as in the original code, so a
        network that differentiates with respect to its inputs keeps working.
        """
        x = torch.tensor(self.df_benchmark[["x"]].values, dtype=torch.float)
        y = torch.tensor(self.df_benchmark[["y"]].values, dtype=torch.float)
        # x and y are (N, 1); concatenating along dim=1 yields (N, 3) for any
        # N.  (stack + squeeze would also drop the batch axis when N == 1.)
        inputs = torch.cat([x, y, torch.full_like(x, t)], dim=1)
        return inputs.requires_grad_(True)

    def get_solutions(self, net=None):
        """Return the field at t = 0, 0.5 and 1 as three 2-D NumPy arrays.

        Args:
            net: Optional callable mapping an ``(N, 3)`` tensor of
                ``[x, y, t]`` rows to ``N`` values.  When ``None`` the
                reference columns ``u0``, ``u05`` and ``u1`` are returned.

        Returns:
            Tuple ``(u_zero, u_half, u_one)``, each of shape
            ``(len_x, len_x)``.
        """
        self._ensure_loaded()
        grid = (self.len_x, self.len_x)

        if net is None:
            fields = [
                self("u", suffix).reshape(grid) for suffix, _ in self._SNAPSHOTS
            ]
        else:
            fields = [
                net(self._network_input(t)).reshape(grid).detach().numpy()
                for _, t in self._SNAPSHOTS
            ]
        return tuple(fields)

    def plot_solution(self, net=None, error=False):
        """Plot the three snapshots side by side and return the figure.

        Args:
            net: Passed to :meth:`get_solutions`; ``None`` plots the
                reference data.
            error: When true, plot the squared difference between the
                selected solution and the reference data instead.

        Returns:
            The matplotlib figure containing the three panels.
        """
        fields = self.get_solutions(net=net)
        label = ""

        if error:
            reference = self.get_solutions(net=None)
            fields = tuple(
                (field - ref) ** 2 for field, ref in zip(fields, reference)
            )
            # Label kept as-is for compatibility; the plotted quantity is the
            # pointwise squared error (no mean is taken).
            label = "MSE "

        fig, axes = plt.subplots(1, 3)
        for axis, field, t_label in zip(axes, fields, ("0", "0.5", "1")):
            image = axis.imshow(field)
            axis.set_title(f"{label}u(t={t_label})")
            plt.colorbar(image, ax=axis, fraction=0.046, pad=0.04)

        fig.tight_layout()
        for axis in axes.flat:
            axis.set_axis_off()
        return fig

    def evaluate_network(self, net):
        """Return the root-mean-square error of ``net`` against ``u1`` at t = 1.

        Args:
            net: Callable mapping an ``(N, 3)`` tensor of ``[x, y, t]`` rows
                to ``N`` values.

        Returns:
            The RMSE over all CSV rows as a Python float.
        """
        self._ensure_loaded()
        exact = torch.tensor(self("u", "1"))
        # Reshape to the reference's shape so an (N, 1) output cannot silently
        # broadcast against the (N,) reference; a size mismatch raises.
        preds = net(self._network_input(1.0)).reshape(exact.shape)
        return (exact - preds).pow(2).mean().sqrt().item()
