import os
import pandas as pd
import torch

from matplotlib import pyplot as plt


class Benchmark:
    """Reference solution for the Burgers benchmark, read from a CSV file.

    The CSV is expected to contain an ``x`` column plus one column per
    solution snapshot named ``<func><suffix>`` (e.g. ``u0``, ``u05``, ``u1``
    for u at t = 0, 0.05 and 0.1).  Networks are queried with tensors of
    shape ``(N, 2)`` whose columns are ``(x, t)``, one row per table row.
    """

    def __init__(
        self,
        folder_path: str = "./benchmark/",
        experiment_csv: str = "burgers/burgers.csv",
    ):
        """Load the benchmark table.

        Args:
            folder_path: Directory holding the benchmark data (a relative path
                resolves against the current working directory).
            experiment_csv: Path of the CSV file relative to ``folder_path``.
        """
        self.df_benchmark = pd.read_csv(os.path.join(folder_path, experiment_csv))

        # Distinct spatial grid points and how many there are. plot_solution
        # reshapes whole columns to this length, so the table must have
        # exactly one row per x value.
        self.eval_x = self.df_benchmark["x"].unique()
        self.len_x = len(self.eval_x)
        # Not used inside this class; presumably the final simulation time
        # (it equals the t=0.1 snapshot used below) -- TODO confirm with callers.
        self.k = 1e-1

    def __call__(self, func: str, val: str):
        """Return column ``f"{func}{val}"`` of the table as a NumPy array.

        Example: ``bench("u", "05")`` returns the ``u05`` column.

        Raises:
            KeyError: if the table has no such column.
        """
        return self.df_benchmark[f"{func}{val}"].values

    def _grid_inputs(self, t: float):
        """Build the network input for every table row at the fixed time ``t``.

        Returns:
            ``(x, inputs)`` where ``x`` is the 1-D float32 ``x`` column and
            ``inputs`` is the ``(N, 2)`` tensor of ``(x, t)`` rows. ``inputs``
            has ``requires_grad`` enabled in case the network differentiates
            with respect to its input.
        """
        # Flatten first so stacking always yields (N, 2), even when N == 1.
        x = torch.tensor(self.df_benchmark["x"].values, dtype=torch.float).reshape(-1)
        inputs = torch.stack([x, torch.full_like(x, t)], dim=1)
        inputs.requires_grad = True
        return x, inputs

    def plot_solution(self, net):
        """Plot ``net`` against the benchmark at t = 0, 0.05 and 0.1, then show it.

        Predictions are drawn as solid lines. The reference solution is drawn
        as dashed lines of the same colour.

        Args:
            net: Callable mapping an ``(N, 2)`` tensor of ``(x, t)`` rows to one
                prediction per row.
        """
        # (column suffix, time, legend label, colour) for each snapshot.
        snapshots = (
            ("0", 0.0, "t=0", "#92dce5"),
            ("05", 5e-2, "t=0.05", "#867bfa"),
            ("1", 1e-1, "t=0.1", "#ff986e"),
        )

        # Evaluate everything before opening a figure, so a failing network
        # does not leave an empty plot behind.
        curves = []
        for suffix, t, label, color in snapshots:
            x, inputs = self._grid_inputs(t)
            u_pred = net(inputs).reshape(self.len_x).detach().numpy()
            u_true = self("u", suffix).reshape(self.len_x)
            curves.append((x.numpy(), u_pred, u_true, label, color))

        plt.subplots(figsize=(10, 10))
        for x, u_pred, u_true, label, color in curves:
            plt.plot(x, u_pred, label=label, color=color)
            plt.plot(
                x, u_true, label=f"{label} (target)", color=color, linestyle="dashed"
            )
        plt.legend()
        plt.show()

    def evaluate_network(self, net) -> float:
        """Return the RMSE of ``net`` against the benchmark ``u1`` column at t = 0.1.

        Args:
            net: Callable mapping an ``(N, 2)`` tensor of ``(x, t)`` rows to one
                prediction per row.

        Returns:
            ``sqrt(mean((u_exact - u_pred) ** 2))`` over all table rows, as a
            Python float.
        """
        _, inputs = self._grid_inputs(1e-1)
        # reshape(-1) rather than squeeze(): squeeze() would also drop the
        # batch dimension when the table has a single row.
        preds = net(inputs).reshape(-1).detach()
        exact = torch.tensor(self("u", "1"))
        return (exact - preds).pow(2).mean().sqrt().item()
