import torch
import argparse
from torch.utils.data import DataLoader
import numpy as np
import random
import os
seed = 30
from itertools import cycle

from model_multi import AllenNet
# Seed the PyTorch, Python and NumPy global random number generators with the
# shared `seed` value defined above.
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
# Stand-alone torch generator seeded with the same value
# (Generator.manual_seed returns the generator itself).
gen = torch.Generator().manual_seed(seed)


# Command-line options. The loss weights, layer sizes and learning rate are passed
# to the AllenNet constructors below; the other options are only stored in variables.
parser = argparse.ArgumentParser()
parser.add_argument('--bc_weight', default=1., type=float, help='Weight for the BC loss')
parser.add_argument('--sys_weight', default=1., type=float, help='Weight for the rho loss')
parser.add_argument('--pde_weight', default=1., type=float, help='Weight for the PDE loss')
parser.add_argument('--lr_init', default=1e-4, type=float, help='Starting learning rate')
parser.add_argument('--device', default='cuda', type=str, help='Device to use')
# Help text repaired: it used to read 'multieriment name'.
parser.add_argument('--name', default='multi', type=str, help='Experiment name')
parser.add_argument('--train_steps', default=100000, type=int, help='Number of training steps')
parser.add_argument('--epochs', default=100, type=int, help='Number of epochs')
# Help text repaired: the option is a free-form string (default 'Derivative'),
# not the -1/0/1 integer code the old text described.
parser.add_argument('--mode', default='Derivative', type=str,
                    help='Mode label, e.g. Derivative, Output, Sobolev, PINN or PINN+Output')
parser.add_argument('--layers', default=4, type=int, help='Number of layers in the network')
parser.add_argument('--units', default=50, type=int, help='Number of units per layer in the network')

args = parser.parse_args()

# Unpack the parsed options into module-level names used by the rest of the script.
bc_weight = args.bc_weight
sys_weight = args.sys_weight
pde_weight = args.pde_weight
lr_init = args.lr_init
device = args.device
name = args.name
train_steps = args.train_steps
epochs = args.epochs
layers = args.layers
units = args.units
mode = args.mode
title_mode = mode

# Make sure the experiment folder exists (exist_ok avoids a check-then-create race).
os.makedirs('multi', exist_ok=True)

# A single Tanh module instance is shared by all networks.
activation = torch.nn.Tanh()


def _load_trained_model(checkpoint_suffix):
    """Build an AllenNet from the CLI hyper-parameters and load saved weights.

    Args:
        checkpoint_suffix: Trailing part of the checkpoint file name; the state
            dict is read from ``multi/saved_models/allen_multi_net<suffix>``.

    Returns:
        The network on ``device`` with the weights loaded, in eval mode.
    """
    model = AllenNet(
        in_dim=4,
        bc_weight=bc_weight,
        pde_weight=pde_weight,
        sys_weight=sys_weight,
        hidden_units=[units for _ in range(layers)],
        lr_init=lr_init,
        device=device,
        activation=activation,
        last_activation=False,
    ).to(device)
    # map_location lets a checkpoint saved on another device (e.g. CUDA) be
    # loaded when the script runs with a different --device.
    state_dict = torch.load(
        f'multi/saved_models/allen_multi_net{checkpoint_suffix}',
        map_location=device,
    )
    model.load_state_dict(state_dict)
    model.eval()
    return model


# One network per checkpoint; construction order is the same as before.
model_derl = _load_trained_model('Derivative')
model_outl = _load_trained_model('Output')
model_sob = _load_trained_model('Sobolev')
model_pinn = _load_trained_model('PINN')
model_pinnout = _load_trained_model('PINN+Output')

# Output folders for the joint plots. exist_ok=True replaces the repeated
# "check, then create" pairs and is safe if a folder appears in between.
for _plot_dir in (
    'multi/plotsjoint',
    'multi/plotsjoint/val',
    'multi/plotsjoint/train',
    'multi/plotsjoint/val_nopinn',
    'multi/plotsjoint/train_nopinn',
):
    os.makedirs(_plot_dir, exist_ok=True)

    
# Build the evaluation grid from the saved 1-D coordinate axes. allen_pts is
# derived from the meshgrid below, so 'allen_pts.npy' does not need to be read
# (its contents were overwritten immediately in the earlier version).
allen_x = np.load('allen_x.npy')
allen_y = np.load('allen_y.npy')
x, y = np.meshgrid(allen_x, allen_y)
# One (x, y) row per grid node, flattened in row-major order of the meshgrid.
allen_pts = np.column_stack((x.reshape(-1), y.reshape(-1)))
# Not referenced anywhere else in the visible script; kept as a module-level name.
dx = 0.01
# Import the training and validation lambdas
lam_train = np.load('allen_lam_train_multi.npy')
lam_val = np.load('allen_lam_val_multi.npy')

num_val = len(lam_val)

from generate_multi import allen_cahn_true_multi, allen_cahn_forcing_multi
from plotting_utils import error_plotting, comparison_plotting

def evaluate_models_on_val_data(models, model_names, lam_list, allen_pts, device, x, y, lam_type='val'):
    """Score every model on a list of parameter vectors; write a report and plots.

    For each ``lam`` in ``lam_list`` the grid points are extended with the
    parameter values, each model predicts ``u`` (``model(x_in)``) and the forcing
    (``model.evaluate_forcing(x_in)``), and relative L2 errors are computed
    against ``allen_cahn_true_multi`` / ``allen_cahn_forcing_multi``.

    Outputs:
        * ``multi/plotsjoint/<lam_type>_results.txt``: average errors per model,
          followed by the per-lambda errors.
        * ``multi/plotsjoint/<lam_type>/``: absolute-error maps (clipped to 100)
          and difference-to-DERL comparison maps for ``u`` and the forcing.

    Changes relative to the previous implementation:
        * averages are divided by ``len(lam_list)``, not by the length of the
          module-level validation set;
        * forcing-error maps are saved under ``<lam_type>/`` instead of always
          under ``val/``;
        * each lambda is evaluated once (predictions are reused for the report
          and the plots), so the progress messages are printed once per lambda;
        * the output folder is created if it is missing.

    Args:
        models: Networks, aligned with ``model_names``.
        model_names: Display names; must contain ``'DERL'`` whenever
            ``lam_list`` is non-empty (baseline of the comparison plots).
        lam_list: Sequence of parameter vectors.
        allen_pts: Array with one grid point per row.
        device: Torch device for the network input.
        x, y: Meshgrid arrays; error fields are reshaped to ``x.shape``.
        lam_type: Tag used in the report file name and the plot folder.
    """
    plot_dir = f'multi/plotsjoint/{lam_type}'
    results_path = f'multi/plotsjoint/{lam_type}_results.txt'
    # Do not rely on the caller having created the folder.
    os.makedirs(plot_dir, exist_ok=True)

    baseline = 'DERL'
    cap = 100.
    # Average over the lambdas that are actually evaluated in this call.
    num_lam = len(lam_list)

    total_errors = {name: 0 for name in model_names}
    total_forcing_errors = {name: 0 for name in model_names}
    per_lambda_report = []

    for lam in lam_list:
        print(f'lambda: {lam}')
        # Append the parameter values as extra input columns for every grid point.
        lam_in = np.tile(lam, (allen_pts.shape[0], 1))
        x_in = np.column_stack((allen_pts, lam_in))
        x_in = torch.tensor(x_in, dtype=torch.float32).to(device)

        # Evaluate each model exactly once per lambda.
        u_preds = {}
        f_preds = {}
        for model, name in zip(models, model_names):
            u_preds[name] = model(x_in).reshape(-1).cpu().detach().numpy()
            f_preds[name] = model.evaluate_forcing(x_in).cpu().detach().numpy()

        u_true = allen_cahn_true_multi(allen_pts, lam)
        f_true = allen_cahn_forcing_multi(allen_pts, lam)
        print('Calculating error')

        u_norm = np.linalg.norm(u_true, 2)
        f_norm = np.linalg.norm(f_true, 2)

        per_lambda_report.append(f'lambda: {lam}\n')
        for name in model_names:
            error = np.linalg.norm(u_true - u_preds[name], 2) / u_norm
            forcing_error = np.linalg.norm(f_true - f_preds[name], 2) / f_norm
            total_errors[name] += error
            total_forcing_errors[name] += forcing_error
            per_lambda_report.append(f'Error {name}: {error}\n')
            per_lambda_report.append(f'Error {name} Forcing: {forcing_error}\n')
        per_lambda_report.append('\n')

        # Point-wise absolute errors on the grid, clipped so outliers do not
        # dominate the colour scale.
        error_plots = {
            name: np.clip(np.abs(u_true - u_preds[name]).reshape(x.shape), 0, cap)
            for name in model_names
        }
        forcing_error_plots = {
            name: np.clip(np.abs(f_true - f_preds[name]).reshape(x.shape), 0, cap)
            for name in model_names
        }

        error_plotting([error_plots[name] for name in model_names], x, y, model_names,
                       path=f'{plot_dir}/error_lambda_{lam}.pdf')
        error_plotting([forcing_error_plots[name] for name in model_names], x, y, model_names,
                       path=f'{plot_dir}/forcing_error_lambda_{lam}.pdf')

        # Comparison maps: error of every other model minus the DERL error.
        others = [name for name in model_names if name != baseline]
        comparison_plot = [error_plots[name] - error_plots[baseline] for name in others]
        comparison_plotting(comparison_plot, x, y, others,
                            path=f'{plot_dir}/comparison_error_lambda_{lam}.pdf')

        comparison_plot = [forcing_error_plots[name] - forcing_error_plots[baseline] for name in others]
        comparison_plotting(comparison_plot, x, y, others,
                            path=f'{plot_dir}/comparison_forcing_error_lambda_{lam}.pdf')

    avg_errors = {name: total_errors[name] / num_lam for name in model_names}
    avg_forcing_errors = {name: total_forcing_errors[name] / num_lam for name in model_names}

    # Same layout as before: averages first, then one section per lambda.
    with open(results_path, 'w') as f:
        for name in model_names:
            f.write(f'Average Error {name}: {avg_errors[name]}\n')
        f.write('\n')
        for name in model_names:
            f.write(f'Average Error {name} Forcing: {avg_forcing_errors[name]}\n')
        f.write('\n')
        f.writelines(per_lambda_report)

# Display name -> network, in the order used for reports and plots.
_named_models = {
    'DERL': model_derl,
    'OUTL': model_outl,
    'OUTL+PINN': model_pinnout,
    'SOB': model_sob,
    'PINN': model_pinn,
}
model_names = list(_named_models)
model_list = list(_named_models.values())

# Same collection without the PINN model.
model_names_nopinn = [_n for _n in model_names if _n != 'PINN']
model_list_nopinn = [_named_models[_n] for _n in model_names_nopinn]

# Evaluate on the validation and training lambdas: first with all models,
# then without PINN (tags: val, train, val_nopinn, train_nopinn).
for _models, _names, _suffix in (
    (model_list, model_names, ''),
    (model_list_nopinn, model_names_nopinn, '_nopinn'),
):
    for _lams, _split in ((lam_val, 'val'), (lam_train, 'train')):
        evaluate_models_on_val_data(
            _models, _names,
            _lams, allen_pts, device, x, y, _split + _suffix)

from generate_multi import allen_cahn_true_multi, allen_cahn_forcing_multi
from plotting_utils import error_plotting, comparison_plotting
num_lam = 2

# Test lambdas are read from disk. The random array that used to be assigned to
# test_lam first was overwritten by this load before ever being used.
test_lam = np.load('allen_lam_test_multi.npy')
for _test_dir in ('multi/plotsjoint/test', 'multi/plotsjoint/test_nopinn'):
    os.makedirs(_test_dir, exist_ok=True)

# Average and per-lambda errors over the test lambdas, with and without PINN.
evaluate_models_on_val_data(
    model_list, model_names,
    test_lam, allen_pts, device, x, y, 'test')

evaluate_models_on_val_data(
    model_list_nopinn, model_names_nopinn,
    test_lam, allen_pts, device, x, y, 'test_nopinn')

    
# Regular grid of (lambda_1, lambda_2) pairs with spacing dl on both axes.
# Each axis starts at dl, so 0 itself is not part of the grid.
dl = 0.05
lam1 = np.arange(dl, 1+dl, dl)
lam2 = np.arange(dl, 1+dl, dl)
# 'ij' indexing stacked on the last axis lists the pairs with lambda_1 varying
# slowest and lambda_2 fastest.
lam_grid = np.stack(np.meshgrid(lam1, lam2, indexing='ij'), axis=-1).reshape(-1, 2)

import pandas as pd
from scipy.interpolate import griddata
from matplotlib.colors import ListedColormap
# Create a dataframe to store the results

def compute_errors_for_lambda_grid(models, model_names, lam_grid, allen_pts, device, x, y, name):
    """Return relative L2 errors of every model for each lambda pair of a grid.

    If both ``multi/plotsjoint/results_u.csv`` and ``results_f.csv`` exist they
    are loaded and returned unchanged. Otherwise each ``lam`` in ``lam_grid`` is
    evaluated: the grid points are extended with ``lam``, each model predicts
    ``u`` and the forcing, and the errors relative to ``allen_cahn_true_multi``
    / ``allen_cahn_forcing_multi`` are stored. The tables are then written to
    the two CSV files.

    Changes relative to the previous implementation: the cache is used only when
    *both* files are present (the old check looked at the u-table alone), the
    ``name`` argument is no longer overwritten by a loop variable, and freshly
    computed tables get a unique 0..n-1 row index instead of an all-zero one.

    Args:
        models: Networks, aligned with ``model_names``.
        model_names: Column names for the per-model errors.
        lam_grid: Iterable of parameter pairs; ``lam[0]`` and ``lam[1]`` fill
            the ``lambda_1`` / ``lambda_2`` columns.
        allen_pts: Array with one grid point per row.
        device: Torch device for the network input.
        x, y, name: Accepted for interface compatibility; not used.

    Returns:
        Tuple ``(results_u, results_f)`` of DataFrames with columns
        ``lambda_1``, ``lambda_2`` and one column per model.
    """
    path_u = 'multi/plotsjoint/results_u.csv'
    path_f = 'multi/plotsjoint/results_f.csv'
    if os.path.exists(path_u) and os.path.exists(path_f):
        return pd.read_csv(path_u, index_col=0), pd.read_csv(path_f, index_col=0)

    columns = ['lambda_1', 'lambda_2'] + list(model_names)
    rows_u = []
    rows_f = []

    # For each couple of lambda_1, lambda_2, compute the model errors
    for lam in lam_grid:
        print(f'lambda: {lam}')
        # Add the lambda columns to the grid points
        lam_in = np.tile(lam, (allen_pts.shape[0], 1))
        x_in = np.column_stack((allen_pts, lam_in))
        x_in = torch.tensor(x_in, dtype=torch.float32).to(device)

        # Reference values for this lambda
        u_true = allen_cahn_true_multi(allen_pts, lam)
        f_true = allen_cahn_forcing_multi(allen_pts, lam)
        u_norm = np.linalg.norm(u_true, 2)
        f_norm = np.linalg.norm(f_true, 2)

        errors_u = []
        errors_f = []
        for model, _model_name in zip(models, model_names):
            u_pred = model(x_in).reshape(-1).cpu().detach().numpy()
            f_pred = model.evaluate_forcing(x_in).cpu().detach().numpy()
            errors_u.append(np.linalg.norm(u_true - u_pred, 2) / u_norm)
            errors_f.append(np.linalg.norm(f_true - f_pred, 2) / f_norm)

        rows_u.append([lam[0], lam[1]] + errors_u)
        rows_f.append([lam[0], lam[1]] + errors_f)

    # Build each table in one go rather than concatenating one-row frames.
    results_u = pd.DataFrame(rows_u, columns=columns)
    results_f = pd.DataFrame(rows_f, columns=columns)

    # Save the results so later runs can reuse them
    results_u.to_csv(path_u)
    results_f.to_csv(path_f)

    return results_u, results_f

# compute_errors_for_lambda_grid returns the cached CSV tables when they exist,
# so no separate existence check is needed here.
results_u, results_f = compute_errors_for_lambda_grid(
    model_list,
    model_names,
    lam_grid, allen_pts, device, x, y, name
)

# Per-model mean and standard deviation over all lambda pairs.
avg_errors_u = results_u[model_names].mean()
avg_errors_f = results_f[model_names].mean()
std_errors_u = results_u[model_names].std()
std_errors_f = results_f[model_names].std()

print("Average Errors for u:")
print(avg_errors_u)

print("Standard Deviation for u:")
print(std_errors_u)

print("Average Errors for f:")
print(avg_errors_f)

print("Standard Deviation for f:")
print(std_errors_f)


def _with_summary_rows(table, mean_row, std_row):
    """Return a copy of *table* with 'mean' and 'std' rows appended.

    The summary rows carry values only for the model columns, so their
    lambda_1 / lambda_2 entries are NaN. They are therefore kept out of
    results_u / results_f, which are later used as interpolation points.
    """
    summary = table.copy()
    summary.loc['mean'] = mean_row
    summary.loc['std'] = std_row
    return summary


print(_with_summary_rows(results_u, avg_errors_u, std_errors_u))
import matplotlib.pyplot as plt


# Determine the model with the lowest error for each square in the lambda_1, lambda_2 domain
def get_best_model(df, model_names):
    """Add a 'best_model' column naming the lowest-error model of each row.

    The frame is modified in place and also returned.
    """
    df['best_model'] = df[model_names].idxmin(axis=1)
    return df


results_u_nopinn = results_u.drop(columns=['PINN'])
results_f_nopinn = results_f.drop(columns=['PINN'])

# The exported tables keep the 'mean' and 'std' rows, as before.
get_best_model(
    _with_summary_rows(results_u, avg_errors_u, std_errors_u), model_names
).to_csv(f'multi/plotsjoint/results_u_best.csv')
get_best_model(
    _with_summary_rows(results_f, avg_errors_f, std_errors_f), model_names
).to_csv(f'multi/plotsjoint/results_f_best.csv')

# Frames used for the best-model maps: grid rows only, no summary rows.
df_u = get_best_model(results_u, model_names)
df_f = get_best_model(results_f, model_names)

df_u_nopinn = get_best_model(results_u_nopinn, [n for n in model_names if n != 'PINN'])
df_f_nopinn = get_best_model(results_f_nopinn, [n for n in model_names if n != 'PINN'])

# Integer code for each model (its position in model_names); this is the value
# that gets interpolated and colour-mapped.
model_to_one_hot = {name: i for i, name in enumerate(model_names)}

# 100 x 100 plotting grid spanning the sampled lambda range.
lambda_1, lambda_2 = np.meshgrid(
    np.linspace(df_u['lambda_1'].min(), df_u['lambda_1'].max(), 100),
    np.linspace(df_u['lambda_2'].min(), df_u['lambda_2'].max(), 100),
)


def _plot_best_model_map(table, title, out_path):
    """Draw the best-model map for *table* and save it to *out_path*.

    Adds a 'best_model_one_hot' column to *table*, spreads the codes over the
    plotting grid by nearest-neighbour interpolation and colours them with
    viridis. Returns the figure, the axes and the interpolated code array.
    """
    figure, axis = plt.subplots(figsize=(8, 8))
    table['best_model_one_hot'] = table['best_model'].map(model_to_one_hot)
    codes = griddata(
        (table['lambda_1'], table['lambda_2']),
        table['best_model_one_hot'],
        (lambda_1, lambda_2),
        method='nearest',
    )

    top_code = len(model_names) - 1
    axis.pcolormesh(lambda_1, lambda_2, codes, cmap='viridis', shading='auto', vmin=0, vmax=top_code)

    # Legend entries: empty square-marker lines in the viridis colour of each code.
    palette = plt.get_cmap('viridis')
    handles = [
        plt.plot([], [], marker="s", ms=10, ls="", mec=None,
                 color=palette(model_to_one_hot[label] / top_code),
                 label="{:s}".format(label))[0]
        for label in model_names
    ]
    plt.legend(handles=handles, loc='upper left', ncol=1)

    axis.set_xlabel(r'$\lambda_1$')
    axis.set_ylabel(r'$\lambda_2$')
    axis.set_title(title)

    plt.savefig(out_path)
    return figure, axis, codes


fig, ax, best_model_u = _plot_best_model_map(
    df_u, 'Best Model Regions for Error on u', f'multi/plotsjoint/best_model_u.pdf')
fig, ax, best_model_f = _plot_best_model_map(
    df_f, 'Best Model Regions for Error on f', f'multi/plotsjoint/best_model_f.pdf')


### NO PINN case
model_names_nopinn = [n for n in model_names if n != 'PINN']

# Integer code for each remaining model (its position in model_names_nopinn).
model_to_one_hot_nopinn = {name: i for i, name in enumerate(model_names_nopinn)}

# 100 x 100 plotting grid spanning the sampled lambda range.
lambda_1, lambda_2 = np.meshgrid(
    np.linspace(df_u_nopinn['lambda_1'].min(), df_u_nopinn['lambda_1'].max(), 100),
    np.linspace(df_u_nopinn['lambda_2'].min(), df_u_nopinn['lambda_2'].max(), 100),
)


def _plot_best_model_map_nopinn(table, title, out_path):
    """Draw the best-model map (PINN excluded) for *table*; save to *out_path*.

    Adds a 'best_model_one_hot' column to *table*, spreads the codes over the
    plotting grid by nearest-neighbour interpolation and colours them with
    viridis. Returns the figure, the axes and the interpolated code array.
    """
    figure, axis = plt.subplots(figsize=(8, 8))
    table['best_model_one_hot'] = table['best_model'].map(model_to_one_hot_nopinn)
    codes = griddata(
        (table['lambda_1'], table['lambda_2']),
        table['best_model_one_hot'],
        (lambda_1, lambda_2),
        method='nearest',
    )

    top_code = len(model_to_one_hot_nopinn) - 1
    axis.pcolormesh(lambda_1, lambda_2, codes, cmap='viridis', shading='auto', vmin=0, vmax=top_code)

    # Legend entries: empty square-marker lines in the viridis colour of each code.
    palette = plt.get_cmap('viridis')
    handles = [
        plt.plot([], [], marker="s", ms=10, ls="", mec=None,
                 color=palette(code / top_code),
                 label="{:s}".format(label))[0]
        for label, code in model_to_one_hot_nopinn.items()
    ]
    plt.legend(handles=handles, loc='upper left', ncol=1)

    axis.set_xlabel(r'$\lambda_1$')
    axis.set_ylabel(r'$\lambda_2$')
    axis.set_title(title)

    plt.savefig(out_path)
    return figure, axis, codes


fig, ax, best_model_u_nopinn = _plot_best_model_map_nopinn(
    df_u_nopinn, 'Best Model for Error on u', f'multi/plotsjoint/best_model_u_nopinn.pdf')
fig, ax, best_model_f_nopinn = _plot_best_model_map_nopinn(
    df_f_nopinn, 'Best Model for PDE residual', f'multi/plotsjoint/best_model_f_nopinn.pdf')

plt.rcParams.update({'font.size': 20})
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

# Create grid data for lambda_1 and lambda_2
lambda_1 = np.linspace(df_u['lambda_1'].min(), df_u['lambda_1'].max(), 100)
lambda_2 = np.linspace(df_u['lambda_2'].min(), df_u['lambda_2'].max(), 100)
lambda_1, lambda_2 = np.meshgrid(lambda_1, lambda_2)

# One fixed colour per model code. The same palette is used for both panels and
# for the legend; previously the panels were drawn with viridis while the legend
# showed these colours, so the legend did not match the plot.
colors = ['royalblue', 'tomato', 'orange', 'limegreen', 'yellow']
model_cmap = ListedColormap(colors)

# Interpolate best model data for u
df_u['best_model_one_hot'] = df_u['best_model'].map(model_to_one_hot)
best_model_u = griddata((df_u['lambda_1'], df_u['lambda_2']), df_u['best_model_one_hot'], (lambda_1, lambda_2), method='nearest')

# Plot the best model regions for u
c1 = ax1.pcolormesh(lambda_1, lambda_2, best_model_u, cmap=model_cmap, shading='auto', vmin=0, vmax=len(model_names)-1)
ax1.set_xlabel(r'$\lambda_1$')
ax1.set_ylabel(r'$\lambda_2$')
ax1.set_title('Best Model Regions for Error on u')

# Interpolate best model data for f
df_f['best_model_one_hot'] = df_f['best_model'].map(model_to_one_hot)
best_model_f = griddata((df_f['lambda_1'], df_f['lambda_2']), df_f['best_model_one_hot'], (lambda_1, lambda_2), method='nearest')

# Plot the best model regions for f
c2 = ax2.pcolormesh(lambda_1, lambda_2, best_model_f, cmap=model_cmap, shading='auto', vmin=0, vmax=len(model_names)-1)
# Axis symbols now match the left panel (this panel used \xi before).
ax2.set_xlabel(r'$\lambda_1$')
ax2.set_ylabel(r'$\lambda_2$')
ax2.set_title('Best Model Regions for Error on f')

# Create a legend with color and model name
patches = [plt.plot([], [], marker="s", ms=10, ls="", mec=None, color=colors[model_to_one_hot[name]],
            label="{:s}".format(name))[0] for name in model_names]
fig.legend(handles=patches, loc='upper center', ncol=len(model_names))

plt.savefig(f'multi/plotsjoint/best_model_combined.pdf')

plt.rcParams.update({'font.size': 20})
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

# 100 x 100 plotting grid spanning the sampled lambda range.
lambda_1, lambda_2 = np.meshgrid(
    np.linspace(df_u_nopinn['lambda_1'].min(), df_u_nopinn['lambda_1'].max(), 100),
    np.linspace(df_u_nopinn['lambda_2'].min(), df_u_nopinn['lambda_2'].max(), 100),
)

# One fixed colour per model code, shared by both panels and the legend.
colors = ['royalblue', 'tomato', 'orange', 'limegreen']


def _fill_nopinn_panel(axis, table):
    """Draw the best-model map (PINN excluded) for *table* on *axis*.

    Adds a 'best_model_one_hot' column to *table* and returns the interpolated
    code array together with the pcolormesh artist.
    """
    table['best_model_one_hot'] = table['best_model'].map(model_to_one_hot_nopinn)
    codes = griddata(
        (table['lambda_1'], table['lambda_2']),
        table['best_model_one_hot'],
        (lambda_1, lambda_2),
        method='nearest',
    )
    mesh = axis.pcolormesh(
        lambda_1, lambda_2, codes,
        cmap=ListedColormap(colors), shading='auto',
        vmin=0, vmax=len(model_to_one_hot_nopinn)-1,
    )
    axis.set_xlabel(r'$\xi_1$')
    axis.set_ylabel(r'$\xi_2$')
    return codes, mesh


# Left panel: error on u; right panel: error on f.
best_model_u_nopinn, c1 = _fill_nopinn_panel(ax1, df_u_nopinn)
best_model_f_nopinn, c2 = _fill_nopinn_panel(ax2, df_f_nopinn)

# Legend: one square marker per model in its panel colour.
cmap = plt.get_cmap('viridis')  # not used by this figure; module-level name retained
patches = [
    plt.plot([], [], marker="s", ms=10, ls="", mec=None,
             color=colors[code], label="{:s}".format(label))[0]
    for label, code in model_to_one_hot_nopinn.items()
]
fig.legend(handles=patches, loc='upper center', ncol=len(model_to_one_hot_nopinn), bbox_to_anchor=(0.5, 1.01))

plt.savefig(f'multi/plotsjoint/best_model_combined_nopinn.pdf')