# %%
import matplotlib.pyplot as plt 
from scipy.optimize import curve_fit
# import rdt_numerical_solver as rdt
from tqdm import tqdm 
# import closure_models as rdt_cm

import numpy as np 
import torch
import torch.nn as nn

import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

import psutil
import os 
from torch.utils.data import Dataset, DataLoader

plt.style.use('seaborn-v0_8-poster')
import pickle 

# %%
# remove outliers 

# %%
### Load the data 

# NOTE: pickle.load can execute arbitrary code -- only load trusted data files.
with open('ICLR_sup_data.pkl', 'rb') as _fh:
    structure_tensors = pickle.load(_fh)

# Number of samples contributed by each case, in the same order in which the
# per-case arrays are concatenated below (derived from the data rather than
# assuming a fixed number of cases).
batch_size_arr = [case_R.shape[1] for case_R in structure_tensors['R_array']]


def _voigt_to_matrix(voigt):
    """Expand a (6, T) stack of symmetric 3x3 tensors into a (T, 3, 3) float array.

    Row layout of the stored data: rows 0, 1, 2 hold the diagonal entries
    (00, 11, 22); rows 3, 4, 5 hold the off-diagonal entries (01, 02, 12).
    """
    voigt = np.asarray(voigt, dtype=float)
    # index[i, j] = row of `voigt` that stores component (i, j)
    index = np.array([[0, 3, 4],
                      [3, 1, 5],
                      [4, 5, 2]])
    # voigt[index] has shape (3, 3, T); move the sample axis to the front.
    return voigt[index].transpose(2, 0, 1)


_R_array = np.concatenate(structure_tensors['R_array'], axis = -1) 
R_array = _voigt_to_matrix(_R_array)

_D_array = np.concatenate(structure_tensors['D_array'], axis = -1)
D_array = _voigt_to_matrix(_D_array)

# Higher-rank tensors: concatenate the cases and move the trailing sample axis to the front.
Q_array = np.concatenate(structure_tensors['Q_array'], axis = -1).transpose(3, 0, 1, 2)
Qs_array = np.concatenate(structure_tensors['Qs_array'], axis = -1).transpose(3, 0, 1, 2)
M_array = np.concatenate(structure_tensors['M_array'], axis = -1).transpose(4, 0, 1, 2, 3)
# q^2 = 2 * tke
q2_array = np.concatenate(structure_tensors['tke_array'], axis = -1)*2 

# %%
# Display the shapes of the assembled arrays as notebook cell output (no effect when run as a plain script).
R_array.shape, D_array.shape, Q_array.shape, Qs_array.shape, M_array.shape, q2_array.shape

# %%
from tensors_tools import *

# %%
# NOTE: allow_pickle=True lets np.load unpickle objects -- only load trusted files.
params = np.load('param_values_un.npy', allow_pickle=True)
dudy = mean_velocity_gradients_from_parameters(params)

# Repeat each case's mean velocity gradient once per sample of that case so
# the list lines up row-for-row with the concatenated tensor arrays.
dudy_repeat = [] 
for case, case_R in enumerate(structure_tensors['R_array']):
    repeat_index = case_R.shape[1]
    _dudy_repeat = np.repeat(dudy[case][None, :, :], repeat_index, axis = 0)
    dudy_repeat.extend(_dudy_repeat)

# Fail fast if the repetition does not cover every sample.
assert len(dudy_repeat) == M_array.shape[0], (
    f"dudy_repeat has {len(dudy_repeat)} rows but M_array has {M_array.shape[0]} samples"
)

print(np.array(dudy_repeat).shape)

# %%
# Reference rapid pressure-strain computed from the full M tensor; the model
# predictions below are compared against this array.
Pi_array = rapid_pressure_strain_rate(M_array, dudy_repeat)
Pi_array.shape

# %%
# Reload tensors_tools and repeat the star import so edits made to that module
# during an interactive session are picked up without restarting the kernel.
from importlib import reload 
import tensors_tools
reload(tensors_tools)
from tensors_tools import * 

# %%
# Harmonic projections of the structure tensors.
#
# Suffix convention used for the array names below:
#   h - harmonic projection (tensors_tools.harmonic_proj)
#   n - divided sample-wise by q2_array (q^2 = 2 * tke)

# Rank-2 tensor R: harmonic part, then its normalized version.
Rh_array = harmonic_proj(R_array, symmetric=True, order = 2)
Rhn_array = Rh_array / q2_array[:, None, None] 

# Rank-2 tensor D: harmonic part, then its normalized version.
Dh_array = harmonic_proj(D_array, symmetric=True, order = 2)
Dhn_array = Dh_array / q2_array[:, None, None] 

# Rank-3 tensor Qs is only normalized here (no harmonic projection applied).
Qsn_array = Qs_array / q2_array[:, None, None, None]

# Rank-4: M* is derived from M (tensors_tools.get_M_star_from_M), then
# projected onto its harmonic part and normalized.
Ms_array = get_M_star_from_M(M_array)
Msh_array = harmonic_proj(Ms_array, symmetric=True, order = 4)

Mshn_array = Msh_array / q2_array[:, None, None, None, None]

# %%
def harmonic_to_symmetric(Msh, trace_1, trace_2):
    """Rebuild a fully symmetric rank-4 tensor from its harmonic part and traces.

    Computes, sample by sample:

        Ms = Msh + 6/7 * sym(I (x) trace_1) - 3/35 * sym(I (x) I) * trace_2

    where ``sym`` is ``tensors_tools.symmetrize`` and ``I`` is the 3x3 identity.

    Args:
        Msh: batched rank-4 array indexed ``[t, i, j, k, l]``.
        trace_1: batched rank-2 array indexed ``[t, k, l]``.
        trace_2: per-sample scalars indexed ``[t]``.

    Returns:
        A new array (``Msh`` is copied, not modified). It has the same dtype as ``Msh``.
    """
    eye = np.eye(3)[None, :, :]

    # Correction built from the rank-2 trace.
    delta_trace = symmetrize(np.einsum('tij, tkl -> tijkl', eye, trace_1))
    # Correction built from the scalar trace (identity (x) identity).
    delta_delta = symmetrize(np.einsum('tij, tkl -> tijkl', eye, eye))

    rebuilt = Msh.copy()
    rebuilt += 6/7*delta_trace
    rebuilt -= 3/35*delta_delta*trace_2[:, None, None, None, None]
    return rebuilt

def harmonic_to_symmetric_torch(Msh, trace_1, trace_2):
    """Torch counterpart of the numpy ``harmonic_to_symmetric``.

    Computes, sample by sample:

        Ms = Msh + 6/7 * sym(I (x) trace_1) - 3/35 * sym(I (x) I) * trace_2

    where ``sym`` is ``tensors_tools.symmetrize_torch`` and ``I`` is the 3x3 identity.

    Args:
        Msh: batched rank-4 tensor indexed ``[t, i, j, k, l]``.
        trace_1: batched rank-2 tensor indexed ``[t, k, l]``.
        trace_2: per-sample scalars indexed ``[t]``.

    Returns:
        A new tensor (``Msh`` is cloned, not modified).
    """
    bs = Msh.shape[0]
    # Build the identity with the inputs' dtype/device. A default
    # ``torch.eye(3)`` is float32 on CPU, and einsum does not mix it with
    # float64 or CUDA operands.
    I = torch.eye(3, dtype=trace_1.dtype, device=trace_1.device).unsqueeze(0).expand(bs, -1, -1)
    Ms_reconst = Msh.clone() 
    Ms_reconst += 6/7*symmetrize_torch(torch.einsum('tij, tkl -> tijkl', I, trace_1)) 
    Ms_reconst -= 3/35*symmetrize_torch(torch.einsum('tij, tkl -> tijkl', I, I))*trace_2[:, None, None, None, None]
    return Ms_reconst 

# %%
# Verification: the harmonic decompositions and q^2 normalizations above must
# be exactly invertible.

# The trace of R equals q^2 (previously computed but never asserted).
assert np.allclose(np.einsum('tii->t', R_array), q2_array)
I = np.eye(3)[None, :, :]
R_reconst = Rh_array + 1/3*np.einsum('t, tij -> tij', q2_array, I)
assert np.allclose(R_array, R_reconst)

D_reconst = Dh_array + 1/3*np.einsum('t, tij -> tij', q2_array, I)
assert np.allclose(D_array, D_reconst)

assert np.allclose(Qs_array, Qsn_array * q2_array[:, None, None, None])

# Traces of M* used by harmonic_to_symmetric: the scalar full trace and the
# rank-2 partial trace over the first index pair.
trace_2 = 1/3*q2_array
trace_1 = 1/6*(R_array + D_array) 

assert np.allclose(np.einsum('tiipq-> tpq', Ms_array), trace_1) 

assert np.allclose(np.einsum('tiipp-> t', Ms_array), trace_2) 

Ms_reconst = harmonic_to_symmetric(Msh_array, trace_1, trace_2)

assert np.allclose(Ms_reconst, Ms_array)

assert np.allclose(Msh_array, Mshn_array*q2_array[:, None, None, None, None])

# %%
class structure_tensors_dataset_hn(Dataset):
    """Harmonic, q^2-normalized structure tensors as a PyTorch ``Dataset``.

    Each item is the tuple ``(Rhn, Dhn, Qsn, Mshn)`` for one sample. The data is
    split by a single row index: the training split holds rows
    ``[:training_lim]`` and the other split holds rows ``[training_lim:]``.

    Args:
        training_lim: split index; converted with ``int()``, so fractional
            values such as ``0.8 * N`` are truncated.
        training: if True, keep the leading rows; otherwise the trailing rows.
        Rhn, Dhn, Qsn, Mshn: optional source arrays. Each one that is ``None``
            falls back to the module-level array of the same name with an
            ``_array`` suffix (e.g. ``Rhn_array``).
    """

    def __init__(self, training_lim, training = True,
                 Rhn = None, Dhn = None, Qsn = None, Mshn = None):
        training_lim = int(training_lim)
        # Resolve each source lazily so the module globals are only needed
        # when no explicit array is supplied.
        sources = (
            Rhn_array if Rhn is None else Rhn,
            Dhn_array if Dhn is None else Dhn,
            Qsn_array if Qsn is None else Qsn,
            Mshn_array if Mshn is None else Mshn,
        )
        rows = slice(None, training_lim) if training else slice(training_lim, None)
        # torch.tensor copies the (sliced) data into new tensors.
        self.Rhn, self.Dhn, self.Qsn, self.Mshn = (torch.tensor(a[rows]) for a in sources)

    def __len__(self): 
        return self.Rhn.shape[0] 

    def __getitem__(self, index): 
        return self.Rhn[index], self.Dhn[index], self.Qsn[index], self.Mshn[index]


# %%
def get_max_error(error, batch_size_arr): 
    """Pick one error value per case from a concatenated per-sample error array.

    ``error`` stores all cases back to back, with case ``k`` occupying the next
    ``batch_size_arr[k]`` entries. For each case, this returns the entry at
    the case's last index, converted with ``.item()``.

    NOTE(review): despite the name, no maximum is taken -- this is the error at
    the final sample of each case. Confirm that this is intended.

    Returns:
        list with one Python scalar per entry of ``batch_size_arr``.
    """
    case_ends = np.cumsum(batch_size_arr)
    return [error[stop - 1].item() for stop in case_ends]

# %%
# Full (non-harmonic) M* normalized by q^2; its flattened form is the target of the linear fit below.
Msn_array = Ms_array / q2_array[:, None, None, None, None]  

# %%
# Least-squares fit of the three coefficients of the linear M* model
# (tensors_tools.M_star_model_b_y_linear) on the first `num` samples. The
# inputs are the normalized harmonic parts of R and D.
num = 300000
popt_linear, pcov = curve_fit(M_star_model_b_y_linear,
                           [Rhn_array[:num], Dhn_array[:num]],
                           Msn_array[:num].flatten(), 
                           p0 = np.zeros(3)) 

# %%
### The linear model 

# %%
# Evaluate the fitted linear model on every sample:
# 1. rescale the predicted M* by q^2;
# 2. reassemble M (tensors_tools.M_decomposition);
# 3. compute the rapid pressure-strain and compare it against Pi_array.
M_star_linear = M_star_model_b_y_linear([Rhn_array, Dhn_array], *popt_linear, flatten=False)*q2_array[:, None, None, None, None]
M_linear = M_decomposition(M_star_linear, Qs_array, R_array, D_array)
Pi_linear = rapid_pressure_strain_rate(M_linear, dudy_repeat)
Pi_error_linear = normalized_error_T(Pi_linear, Pi_array)

# One error value per case (see get_max_error).
Pi_error_linear_max = get_max_error(Pi_error_linear, batch_size_arr)

# Sorted ascending in place and plotted in descending order.
Pi_error_linear_max.sort()
plt.figure(figsize = (4,4))
plt.semilogy(Pi_error_linear_max[::-1], label = "linear model")
plt.grid()
plt.xlabel("Index")
plt.ylabel("Normalized Error")
plt.legend()
plt.show()

# %%
# Check whether the order-3 torch harmonic projection leaves Qs_array unchanged
# (the result is only displayed as notebook output, not asserted).
Qs_harm_test = harmonic_proj_torch(torch.tensor(Qs_array), symmetric=True, order=3)
torch.allclose(Qs_harm_test, torch.tensor(Qs_array))

# %%
# Cross-check the torch order-4 harmonic projection against the numpy
# Msh_array computed earlier (displayed, not asserted).
Ms_array_torch = torch.tensor(Ms_array)
Msh_array_torch = torch.tensor(Msh_array)

Msh_array_from_iclr_tools = harmonic_proj_torch(Ms_array_torch, symmetric=True, order = 4)

np.allclose(Msh_array_torch, Msh_array_from_iclr_tools)

# %%
# Torch copy of M_array. NOTE(review): not used by any cell visible in this file.
M_array_torch = torch.tensor(M_array)

# %%
# learning a harmonic model 

class turbulence_model(nn.Module): 
    """MLP that maps tensor invariants to the coefficients of a tensor-basis model.

    Forward pass:
    1. compute invariants of ``x = (Rhn, Dhn, Qsn)`` with
       ``degree_{2,3}_invariants_torch``;
    2. feed them through an MLP (``num_layers`` x [Linear -> ReLU ->
       Dropout(0.1)], then a final Linear);
    3. pass the resulting coefficients to
       ``degree_{2,3}_tensor_basis_model_torch`` to obtain the predicted
       normalized harmonic M*.

    Args:
        input_size: number of invariants; must match the invariant function
            selected by ``max_degree``.
        output_size: number of basis coefficients; must match the basis
            function selected by ``max_degree``.
        num_layers: number of hidden Linear/ReLU/Dropout groups (0 is allowed).
        layer_size: width of every hidden layer.
        max_degree: 2 or 3; selects the invariant and basis functions.
    """

    def __init__(self,
                 input_size = 12,
                 output_size = 15,
                 num_layers = 1,
                 layer_size = 10,
                 max_degree = 2): 
        super(turbulence_model, self).__init__()
        self.num_layers = num_layers
        self.layer_size = layer_size
        self.input_size = input_size 
        self.output_size = output_size 
        self.max_degree = max_degree 
        layers = []

        in_size = self.input_size 

        for _ in range(num_layers):
            layers.append(nn.Linear(in_size, layer_size))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(0.1))
            in_size = layer_size

        # Size the output layer from the width actually produced by the stack
        # (input_size when num_layers == 0), not from layer_size unconditionally.
        layers.append(nn.Linear(in_size, self.output_size))
        self.network = nn.Sequential(*layers)

    def evaluate_invariants(self, x, max_degree = 2):
        """Return the invariants of ``x = (Rhn, Dhn, Qsn)`` for ``max_degree``.

        Raises:
            ValueError: if ``max_degree`` is not 2 or 3.
        """
        Rhn, Dhn, Qsn = x 
        if max_degree == 2: 
            return degree_2_invariants_torch(Rhn, Dhn, Qsn)
        if max_degree == 3:
            return degree_3_invariants_torch(Rhn, Dhn, Qsn) 
        raise ValueError(f"max_degree must be 2 or 3, got {max_degree!r}")

    def evaluate_tensor_basis_model(self, x, coeffs, max_degree = 2):
        """Combine the basis tensors built from ``x`` using ``coeffs``.

        Raises:
            ValueError: if ``max_degree`` is not 2 or 3.
        """
        if max_degree == 2:
            return degree_2_tensor_basis_model_torch(x, coeffs)
        if max_degree == 3:
            return degree_3_tensor_basis_model_torch(x, coeffs)
        raise ValueError(f"max_degree must be 2 or 3, got {max_degree!r}")

    def forward(self, x):
        """Predict the normalized harmonic M* from ``x = [Rhn, Dhn, Qsn]``."""
        invariants = self.evaluate_invariants(x, max_degree = self.max_degree)

        coeffs = self.network(invariants) 

        Mshn_predicted = self.evaluate_tensor_basis_model(x, coeffs, max_degree = self.max_degree)

        return Mshn_predicted 

# %%
### The Degree -1 Model (no terms)

# %%
### Degree - 1 model 

# Baseline with the harmonic part of M* set to zero: only the trace
# contributions (built from R, D and q^2) enter the reconstruction.
Msh_pred = np.zeros_like(Msh_array) 

trace_1 = 1/6*(R_array + D_array) 
trace_2 = 1/3*q2_array

Ms_pred = harmonic_to_symmetric(Msh_pred, trace_1, trace_2)

M_pred = M_decomposition(Ms_pred, Qs_array, R_array, D_array)
Pi_pred = rapid_pressure_strain_rate(M_pred, dudy_repeat)

Pi_error = normalized_error_T(Pi_pred, Pi_array)

Pi_error_degree1_max = get_max_error(Pi_error, batch_size_arr)
Pi_error_degree1_max.sort() 

# Recompute the sorted per-case errors of the linear model so this cell does
# not depend on the state left by the linear-model cell.
Pi_error_linear_max = get_max_error(Pi_error_linear, batch_size_arr)
Pi_error_linear_max.sort() 

plt.figure(figsize = (4,4))    
plt.semilogy(Pi_error_degree1_max[::-1], label = "Degree 1 NL")
plt.semilogy(Pi_error_linear_max[::-1], '--', label = "Linear")
plt.grid() 
plt.legend() 
plt.ylabel(r"Normalized Error $\pi$")
plt.xlabel("RDT cases")    

# %%
# Tuning notes: try a lower batch size
# and wider hidden layers.

# learning a harmonic model 

class turbulence_model(nn.Module): 
    """MLP (LeakyReLU variant) that maps tensor invariants to tensor-basis coefficients.

    Forward pass:
    1. compute invariants of ``x = (Rhn, Dhn, Qsn)`` with
       ``degree_{2,3}_invariants_torch``;
    2. feed them through an MLP (``num_layers`` x [Linear ->
       LeakyReLU(0.01) -> Dropout(0.1)], then a final Linear);
    3. pass the coefficients to ``degree_{2,3}_tensor_basis_model_torch`` to
       obtain the predicted normalized harmonic M*.

    Args:
        input_size: number of invariants; must match the invariant function
            selected by ``max_degree``.
        output_size: number of basis coefficients; must match the basis
            function selected by ``max_degree``.
        num_layers: number of hidden Linear/LeakyReLU/Dropout groups (0 is allowed).
        layer_size: width of every hidden layer.
        max_degree: 2 or 3; selects the invariant and basis functions.
    """

    def __init__(self,
                 input_size = 12,
                 output_size = 15,
                 num_layers = 1,
                 layer_size = 10,
                 max_degree = 2): 
        super(turbulence_model, self).__init__()
        self.num_layers = num_layers
        self.layer_size = layer_size
        self.input_size = input_size 
        self.output_size = output_size 
        self.max_degree = max_degree 
        layers = []

        in_size = self.input_size 

        for _ in range(num_layers):
            layers.append(nn.Linear(in_size, layer_size))
            layers.append(nn.LeakyReLU(negative_slope=0.01))
            layers.append(nn.Dropout(0.1))
            in_size = layer_size

        # Size the output layer from the width actually produced by the stack
        # (input_size when num_layers == 0), not from layer_size unconditionally.
        layers.append(nn.Linear(in_size, self.output_size))
        self.network = nn.Sequential(*layers)

    def evaluate_invariants(self, x, max_degree = 2):
        """Return the invariants of ``x = (Rhn, Dhn, Qsn)`` for ``max_degree``.

        Raises:
            ValueError: if ``max_degree`` is not 2 or 3.
        """
        Rhn, Dhn, Qsn = x 
        if max_degree == 2: 
            return degree_2_invariants_torch(Rhn, Dhn, Qsn)
        if max_degree == 3:
            return degree_3_invariants_torch(Rhn, Dhn, Qsn) 
        raise ValueError(f"max_degree must be 2 or 3, got {max_degree!r}")

    def evaluate_tensor_basis_model(self, x, coeffs, max_degree = 2):
        """Combine the basis tensors built from ``x`` using ``coeffs``.

        Raises:
            ValueError: if ``max_degree`` is not 2 or 3.
        """
        if max_degree == 2:
            return degree_2_tensor_basis_model_torch(x, coeffs)
        if max_degree == 3:
            return degree_3_tensor_basis_model_torch(x, coeffs)
        raise ValueError(f"max_degree must be 2 or 3, got {max_degree!r}")

    def forward(self, x):
        """Predict the normalized harmonic M* from ``x = [Rhn, Dhn, Qsn]``."""
        invariants = self.evaluate_invariants(x, max_degree = self.max_degree)

        coeffs = self.network(invariants) 

        Mshn_predicted = self.evaluate_tensor_basis_model(x, coeffs, max_degree = self.max_degree)

        return Mshn_predicted 

# %%
### The Degree 3 Model

# %%
for n_layers in [2]:
    for l_width in [128]:
        for bs in [32]:

            print(f"n_layers = {n_layers}, l_width = {l_width}, bs = {bs}")

            # Contiguous 80/20 split along the concatenated sample axis: the
            # first 80% of rows are used for training, the rest for validation.
            split_idx = int(len(R_array)*0.8)
            turb_dataset_train_irrep = structure_tensors_dataset_hn(len(R_array)*0.8, training = True)
            turb_dataloader_train_irrep = DataLoader(dataset=turb_dataset_train_irrep,
                                                     batch_size = bs,
                                                     shuffle = True)

            Rhn_val = torch.tensor(Rhn_array[split_idx : ])
            Dhn_val = torch.tensor(Dhn_array[split_idx : ])
            Qsn_val = torch.tensor(Qsn_array[split_idx : ])
            Mshn_val = torch.tensor(Mshn_array[split_idx : ])

            total_steps = 100  # maximum number of epochs

            # Degree-3 model: input_size/output_size must match the degree-3
            # invariant count and basis-coefficient count.
            nn_model = turbulence_model(input_size = 11,
                                        output_size=27,
                                        max_degree=3, 
                                        num_layers=n_layers,
                                        layer_size=l_width)

            print(f"Number of parameters in the model: {sum(p.numel() for p in nn_model.parameters())}")

            loss_fn = nn.MSELoss(reduction='mean') 

            optimizer = torch.optim.AdamW(nn_model.parameters(), lr=1e-2, weight_decay=1e-4)

            # NOTE(review): the scheduler tracks the *training* loss, with a
            # patience (20) longer than the early-stopping patience (10).
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=20)

            train_loss_arr = [] 
            val_loss_arr = [] 

            verbose = 1
            verbose_freq = 10

            patience = 10  # early-stopping patience, in epochs

            best_val_loss = float('inf')
            best_state = None
            epochs_no_improve = 0

            for epoch in tqdm(range(total_steps)):
                # training loop
                nn_model.train()
                train_loss_per_epoch = 0.0
                n_batches = 0
                for Rhn, Dhn, Qsn, Mshn in turb_dataloader_train_irrep:
                    optimizer.zero_grad()
                    Mshn_pred = nn_model([Rhn.float(), Dhn.float(), Qsn.float()]).float()
                    # MSELoss(reduction='mean') already returns a scalar.
                    loss = loss_fn(Mshn_pred, Mshn.float())
                    loss.backward()
                    optimizer.step()
                    train_loss_per_epoch += loss.item()
                    n_batches += 1

                train_loss = train_loss_per_epoch / n_batches
                scheduler.step(train_loss)
                train_loss_arr.append(train_loss)

                # Validation pass: no autograd graph is needed.
                nn_model.eval()
                with torch.no_grad():
                    Mshn_pred_val = nn_model([Rhn_val.float(), Dhn_val.float(), Qsn_val.float()])
                    val_loss = loss_fn(Mshn_pred_val, Mshn_val.float()).item()
                val_loss_arr.append(val_loss)

                # Early stopping: keep a copy of the best weights seen so far.
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    epochs_no_improve = 0
                    best_state = {k: v.detach().clone() for k, v in nn_model.state_dict().items()}
                    torch.save(nn_model.state_dict(), "best_model.pt")
                else:
                    epochs_no_improve += 1
                    if epochs_no_improve >= patience:
                        print(f"Early stopping at epoch {epoch}")
                        break

                if verbose and epoch % verbose_freq == 0:
                    print(f'Epoch:  {epoch},\n Training loss: {train_loss_arr[-1]}')
                    print(f' Validation loss: {val_loss_arr[-1]}')

            print(f'Epoch:  {epoch},\n Training loss: {train_loss_arr[-1]}')
            print(f' Validation loss: {val_loss_arr[-1]}')

            # Evaluate with the best-validation weights, not the last epoch's.
            if best_state is not None:
                nn_model.load_state_dict(best_state)
            nn_model.eval()

            plt.figure(figsize = (5,5))
            plt.semilogy(train_loss_arr, label = "train")
            plt.semilogy(val_loss_arr, label = "val")
            plt.grid()
            # The loss is the MSE on the normalized harmonic M*, not on pi.
            plt.ylabel(r'MSE on normalized harmonic $M^*$')
            plt.xlabel('Epoch')
            plt.legend()
            plt.title(f"n_layers = {n_layers}, l_width = {l_width}, bs = {bs}")

            # Predict the normalized harmonic M* for every sample, then undo the normalization.
            with torch.no_grad():
                Mshn_pred = nn_model([torch.from_numpy(Rhn_array).float(),
                    torch.from_numpy(Dhn_array).float(),
                    torch.from_numpy(Qsn_array).float()])

            Mshn_pred = Mshn_pred.numpy()

            Msh_pred = Mshn_pred*q2_array[:, None, None, None, None]

            Ms_pred = harmonic_to_symmetric(Msh_pred, trace_1, trace_2)

            M_pred = M_decomposition(Ms_pred, Qs_array, R_array, D_array)
            Pi_pred = rapid_pressure_strain_rate(M_pred, dudy_repeat)
            Pi_error_degree_3 = normalized_error_T(Pi_pred, Pi_array)

            Pi_error_degree_3_max = get_max_error(Pi_error_degree_3, batch_size_arr)

            Pi_error_degree_3_max.sort() 
            plt.figure(figsize = (4,4))    
            plt.semilogy(Pi_error_degree_3_max[::-1], label = "Degree 3 NL")
            plt.semilogy(Pi_error_linear_max[::-1], '-', label = "Linear")
            plt.grid() 
            plt.legend() 
            plt.ylabel(r"Normalized Error $\pi$")
            plt.xlabel("RDT cases")    
            plt.title(f"n_layers = {n_layers}, l_width = {l_width}, bs = {bs}")

# %%
### The Degree 2 Model

# %%
for n_layers in [2]:
    for l_width in [128]:
        for bs in [32]:

            print(f"n_layers = {n_layers}, l_width = {l_width}, bs = {bs}")

            # Contiguous 80/20 split along the concatenated sample axis: the
            # first 80% of rows are used for training, the rest for validation.
            split_idx = int(len(R_array)*0.8)
            turb_dataset_train_irrep = structure_tensors_dataset_hn(len(R_array)*0.8, training = True)
            turb_dataloader_train_irrep = DataLoader(dataset=turb_dataset_train_irrep,
                                                     batch_size = bs,
                                                     shuffle = True)

            Rhn_val = torch.tensor(Rhn_array[split_idx : ])
            Dhn_val = torch.tensor(Dhn_array[split_idx : ])
            Qsn_val = torch.tensor(Qsn_array[split_idx : ])
            Mshn_val = torch.tensor(Mshn_array[split_idx : ])

            total_steps = 100  # maximum number of epochs

            # Degree-2 model: input_size/output_size must match the degree-2
            # invariant count and basis-coefficient count.
            nn_model = turbulence_model(input_size = 4,
                                        output_size=6,
                                        max_degree=2, 
                                        num_layers=n_layers,
                                        layer_size=l_width)

            print(f"Number of parameters in the model: {sum(p.numel() for p in nn_model.parameters())}")

            loss_fn = nn.MSELoss(reduction='mean') 

            optimizer = torch.optim.AdamW(nn_model.parameters(), lr=1e-2, weight_decay=1e-4)

            # NOTE(review): the scheduler tracks the *training* loss, with a
            # patience (20) longer than the early-stopping patience (10).
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=20)

            train_loss_arr = [] 
            val_loss_arr = [] 

            verbose = 1
            verbose_freq = 10

            patience = 10  # early-stopping patience, in epochs

            best_val_loss = float('inf')
            best_state = None
            epochs_no_improve = 0

            for epoch in tqdm(range(total_steps)):
                # training loop
                nn_model.train()
                train_loss_per_epoch = 0.0
                n_batches = 0
                for Rhn, Dhn, Qsn, Mshn in turb_dataloader_train_irrep:
                    optimizer.zero_grad()
                    Mshn_pred = nn_model([Rhn.float(), Dhn.float(), Qsn.float()]).float()
                    # MSELoss(reduction='mean') already returns a scalar.
                    loss = loss_fn(Mshn_pred, Mshn.float())
                    loss.backward()
                    optimizer.step()
                    train_loss_per_epoch += loss.item()
                    n_batches += 1

                train_loss = train_loss_per_epoch / n_batches
                scheduler.step(train_loss)
                train_loss_arr.append(train_loss)

                # Validation pass: no autograd graph is needed.
                nn_model.eval()
                with torch.no_grad():
                    Mshn_pred_val = nn_model([Rhn_val.float(), Dhn_val.float(), Qsn_val.float()])
                    val_loss = loss_fn(Mshn_pred_val, Mshn_val.float()).item()
                val_loss_arr.append(val_loss)

                # Early stopping: keep a copy of the best weights seen so far.
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    epochs_no_improve = 0
                    best_state = {k: v.detach().clone() for k, v in nn_model.state_dict().items()}
                    torch.save(nn_model.state_dict(), "best_model.pt")
                else:
                    epochs_no_improve += 1
                    if epochs_no_improve >= patience:
                        print(f"Early stopping at epoch {epoch}")
                        break

                if verbose and epoch % verbose_freq == 0:
                    print(f'Epoch:  {epoch},\n Training loss: {train_loss_arr[-1]}')
                    print(f' Validation loss: {val_loss_arr[-1]}')

            print(f'Epoch:  {epoch},\n Training loss: {train_loss_arr[-1]}')
            print(f' Validation loss: {val_loss_arr[-1]}')

            # Evaluate with the best-validation weights, not the last epoch's.
            if best_state is not None:
                nn_model.load_state_dict(best_state)
            nn_model.eval()

            plt.figure(figsize = (5,5))
            plt.semilogy(train_loss_arr, label = "train")
            plt.semilogy(val_loss_arr, label = "val")
            plt.grid()
            # The loss is the MSE on the normalized harmonic M*, not on pi.
            plt.ylabel(r'MSE on normalized harmonic $M^*$')
            plt.xlabel('Epoch')
            plt.legend()
            plt.title(f"n_layers = {n_layers}, l_width = {l_width}, bs = {bs}")

            # Predict the normalized harmonic M* for every sample, then undo the normalization.
            with torch.no_grad():
                Mshn_pred = nn_model([torch.from_numpy(Rhn_array).float(),
                    torch.from_numpy(Dhn_array).float(),
                    torch.from_numpy(Qsn_array).float()])

            Mshn_pred = Mshn_pred.numpy()

            Msh_pred = Mshn_pred*q2_array[:, None, None, None, None]

            Ms_pred = harmonic_to_symmetric(Msh_pred, trace_1, trace_2)

            M_pred = M_decomposition(Ms_pred, Qs_array, R_array, D_array)
            Pi_pred = rapid_pressure_strain_rate(M_pred, dudy_repeat)
            Pi_error_degree_2 = normalized_error_T(Pi_pred, Pi_array)

            Pi_error_degree_2_max = get_max_error(Pi_error_degree_2, batch_size_arr)

            Pi_error_degree_2_max.sort() 
            plt.figure(figsize = (4,4))    
            plt.semilogy(Pi_error_degree_2_max[::-1], label = "Degree 2 NL")
            plt.semilogy(Pi_error_linear_max[::-1], '-', label = "Linear")
            plt.grid() 
            plt.legend() 
            plt.ylabel(r"Normalized Error $\pi$")
            plt.xlabel("RDT cases")    
            plt.title(f"n_layers = {n_layers}, l_width = {l_width}, bs = {bs}")

# %%
# Classical closures evaluated with tensors_tools.rapid_term_GLM: "LRR" uses
# the constants below (c = 0.4); "IP" uses C2 = 3/5, C3 = 0.
c = 0.4 
C2 = (c + 8)/11 
C3 = (8*c -2)/11 

# 2 * (R / q^2 - I / 3).
# NOTE(review): confirm this scaling matches the convention rapid_term_GLM expects.
bij_array = (R_array/q2_array[:, None, None] - (1/3)*np.eye(3)[None, :, :])*2 

Pi_LRR = rapid_term_GLM(bij_array, np.array(dudy_repeat), q2_array, C2, C3)
Pi_LRR_error = normalized_error_T(Pi_LRR, Pi_array)
Pi_LRR_error_max = get_max_error(Pi_LRR_error, batch_size_arr)
Pi_LRR_error_max.sort()


Pi_IP = rapid_term_GLM(bij_array, np.array(dudy_repeat), q2_array, 3/5, 0)
Pi_IP_error = normalized_error_T(Pi_IP, Pi_array)
Pi_IP_error_max = get_max_error(Pi_IP_error, batch_size_arr)
Pi_IP_error_max.sort()

# Compare all models. Pi_error_degree_2_max comes from the max_degree=2
# training cell and Pi_error_degree_3_max from the max_degree=3 cell.
plt.figure(figsize = (4,4))
plt.semilogy(Pi_LRR_error_max[::-1], label = "LRR")
plt.semilogy(Pi_IP_error_max[::-1], label = "IP")
plt.semilogy(Pi_error_linear_max[::-1], '-', label = "Linear")
plt.semilogy(Pi_error_degree_2_max[::-1], '-', label = "deg 2 nl")
plt.semilogy(Pi_error_degree_3_max[::-1], '-', label = "deg 3 nl")
plt.grid() 
plt.ylabel(r"Normalized Error $\pi$")
plt.xlabel("RDT cases")    
plt.legend()







