import os
import pandas as pd
import torch

from matplotlib import pyplot as plt


class Benchmark:
    """Reference velocity data used to score and visualize a network.

    The data is read from a ``;``-separated CSV whose columns (after
    whitespace stripping) are expected to include ``x`` and ``y`` (sample
    coordinates fed to ``net.velocity``) and ``u`` and ``v`` (reference
    velocity components compared against the network's output).
    """

    def __init__(
        self,
        folder_path="./benchmark/",
        experiment_csv="navier_stokes/navier_stokes.csv",
    ):
        """Load the benchmark CSV into ``self.df_benchmark``.

        Args:
            folder_path: Directory containing the benchmark data.
            experiment_csv: CSV path relative to ``folder_path``.
        """
        self.df_benchmark = pd.read_csv(
            os.path.join(folder_path, experiment_csv), sep=";"
        )
        # Header names may carry surrounding whitespace; strip it so the
        # columns are reachable as ``.x``, ``.y``, ``.u`` and ``.v``.
        self.df_benchmark.columns = self.df_benchmark.columns.str.strip()

    def _coordinates(self):
        """Return the ``x``/``y`` columns as NumPy arrays and float32 tensors.

        The tensors have ``requires_grad=True``. Presumably ``net.velocity``
        differentiates with respect to its inputs (TODO confirm against the
        network implementation).
        """
        npx = self.df_benchmark.x.to_numpy()
        npy = self.df_benchmark.y.to_numpy()
        x = torch.from_numpy(npx).float()
        y = torch.from_numpy(npy).float()
        x.requires_grad = True
        y.requires_grad = True
        return npx, npy, x, y

    def _predict(self, net):
        """Evaluate ``net.velocity`` on the benchmark points.

        Returns ``(npx, npy, ux, uy)`` with both predictions flattened to 1-D.
        """
        npx, npy, x, y = self._coordinates()
        ux, uy = net.velocity(x, y)
        # Flatten so that column-vector outputs of shape (N, 1) cannot
        # broadcast against the (N,) reference columns into an (N, N) result.
        return npx, npy, ux.reshape(-1), uy.reshape(-1)

    @staticmethod
    def _speed(ux, uy):
        """Return the Euclidean velocity magnitude sqrt(ux**2 + uy**2) as NumPy."""
        return torch.sqrt(ux.pow(2) + uy.pow(2)).detach().cpu().numpy()

    def evaluate_network(self, net):
        """Return the RMSE between the predicted and reference velocities.

        Squared errors of both components (``u`` and ``v``) over all
        benchmark points are averaged together before taking the root.

        Args:
            net: Object exposing ``velocity(x, y) -> (ux, uy)`` on tensors.

        Returns:
            The root-mean-square error as a Python float.
        """
        _, _, ux, uy = self._predict(net)
        target_ux = torch.from_numpy(self.df_benchmark.u.to_numpy()).float()
        target_uy = torch.from_numpy(self.df_benchmark.v.to_numpy()).float()
        errors = torch.stack([ux - target_ux, uy - target_uy], dim=1)
        return errors.pow(2).mean().sqrt().item()

    def plot_velocity(self, net):
        """Scatter-plot the predicted velocity magnitude at the benchmark points.

        Args:
            net: Object exposing ``velocity(x, y) -> (ux, uy)`` on tensors.

        Returns:
            The matplotlib ``Figure`` containing the scatter plot.
        """
        npx, npy, ux, uy = self._predict(net)
        # Colour by |u| = sqrt(ux^2 + uy^2), not by |ux + uy|.
        speed = self._speed(ux, uy)
        fig, ax = plt.subplots()
        ax.scatter(x=npx, y=npy, c=speed)
        return fig
