import torch
import torch.nn as nn
import numpy as np
import joblib
from torch.autograd import grad
import pandas as pd
from scipy.interpolate import interp1d
import time
import pybamm
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from sklearn.model_selection import train_test_split
import seaborn as sns

class D_s_p(nn.Module):
    """Tanh MLP mapping (I, V, T) to a value in (0, 1).

    ``P2D_Physics_ID.get_loss`` rescales this output (by 5e-13) into the
    positive-electrode solid diffusivity.
    """

    def __init__(self, layers):
        super().__init__()
        self.net = nn.Sequential()
        last = len(layers) - 2  # index of the output Linear; it gets no Tanh
        for idx, (n_in, n_out) in enumerate(zip(layers[:-1], layers[1:])):
            self.net.add_module(f"layer_{idx}", nn.Linear(n_in, n_out))
            if idx != last:
                self.net.add_module(f"tanh_{idx}", nn.Tanh())

    def forward(self, I, V, T):
        features = torch.cat((I, V, T), dim=1)
        return self.net(features).sigmoid()
    
class D_s_n(nn.Module):
    """Tanh MLP mapping (I, V, T) to a value in (0, 1).

    ``P2D_Physics_ID.get_loss`` rescales this output (by 1e-13) into the
    negative-electrode solid diffusivity.
    """

    def __init__(self, layers):
        super().__init__()
        self.net = nn.Sequential()
        widths = list(layers)
        n_linear = len(widths) - 1
        for k in range(n_linear):
            self.net.add_module(f"layer_{k}", nn.Linear(widths[k], widths[k + 1]))
            is_output_layer = k == n_linear - 1
            if not is_output_layer:
                self.net.add_module(f"tanh_{k}", nn.Tanh())

    def forward(self, I, V, T):
        return self.net(torch.cat([I, V, T], 1)).sigmoid()
    
class D_e(nn.Module):
    """Tanh MLP mapping (I, V, T) to a value in (0, 1).

    ``P2D_Physics_ID.get_loss`` rescales this output (by 8e-11) into the
    electrolyte diffusivity.
    """

    def __init__(self, layers):
        super().__init__()
        self.net = nn.Sequential()
        pairs = list(zip(layers, layers[1:]))
        for position, (fan_in, fan_out) in enumerate(pairs):
            self.net.add_module(f"layer_{position}", nn.Linear(fan_in, fan_out))
            # Every Linear except the final one is followed by a Tanh.
            if position + 1 < len(pairs):
                self.net.add_module(f"tanh_{position}", nn.Tanh())

    def forward(self, I, V, T):
        stacked = torch.cat([I, V, T], dim=1)
        logits = self.net(stacked)
        return torch.sigmoid(logits)


class PINN_ParameterID(nn.Module):
    """Field network: (t, x_p, x_s, x_sep, r_p, r_s) -> ten P2D state variables.

    The network ends in a sigmoid, so every raw output lies in (0, 1). ``forward``
    then rescales three of them: c_s_pos to (30970.1, 47584.0), c_s_neg to
    (0, 31507) and phi_s_pos to (2.8, 3.8). The other seven are returned as-is.
    ``P2D_Physics_ID.get_loss`` passes the negative-electrode coordinates
    (x_n, r_n) as ``x_s`` / ``r_s``.
    """

    def __init__(self, layers):
        super().__init__()
        self.net = nn.Sequential()
        hidden_limit = len(layers) - 2
        for idx, (width_in, width_out) in enumerate(zip(layers[:-1], layers[1:])):
            self.net.add_module(f"layer_{idx}", nn.Linear(width_in, width_out))
            if idx < hidden_limit:
                self.net.add_module(f"tanh_{idx}", nn.Tanh())
        self.net.add_module("sigmoid", nn.Sigmoid())

    def forward(self, t, x_p, x_s, x_sep, r_p, r_s):
        raw = self.net(torch.cat([t, x_p, x_s, x_sep, r_p, r_s], dim=1))
        (c_s_pos, c_s_neg, c_e_pos, c_e_neg, c_e_sep,
         phi_s_pos, phi_s_neg, phi_e_pos, phi_e_neg, phi_e_sep) = (raw[:, k:k + 1] for k in range(10))

        # Affine maps from the sigmoid range onto the physical ranges.
        c_s_pos = c_s_pos * 16613.9 + 30970.1
        c_s_neg = c_s_neg * 31507
        phi_s_pos = phi_s_pos + 2.8

        return (c_s_pos, c_s_neg, c_e_pos, c_e_neg, c_e_sep,
                phi_s_pos, phi_s_neg, phi_e_pos, phi_e_neg, phi_e_sep)


class P2D_Physics_ID:
    """Physics-informed loss for identifying P2D battery-model parameters.

    Combines the field network ``pinn_model`` with three small networks.
    ``pinn_model`` predicts concentrations and potentials from time, position
    and particle radius. The small networks map (I, V, T) to the solid- and
    electrolyte-phase diffusivities being identified. ``get_loss`` builds the
    PDE residuals with autograd and adds an L1 misfit on the terminal voltage.
    """

    # OCP lookup tables. Column 0 is the abscissa the table is evaluated at
    # (c_s / c_s_max in get_loss); column 1 is the open-circuit potential.
    _OCP_POS_PATH = "/home/ma-user/work/PINNs/data/lico2_ocp_Experiment.csv"
    _OCP_NEG_PATH = "/home/ma-user/work/PINNs/data/graphite_ocp_Enertech_Ai2020_adjusted.csv"

    def __init__(self, pinn_model):
        self.pinn = pinn_model
        self.F = 96485
        self.R = 8.314
        self.t_plus = 0.363
        self.epsilon_e_pos = 1.0
        self.epsilon_e_neg = 1.0
        self.epsilon_e_sep = 1.0
        self.epsilon_s_pos = 0.297
        self.epsilon_s_neg = 0.471
        # Gamma_s is presumably the particle radius (m); a_s = 3 * eps_s / Gamma_s.
        self.Gamma_s = 13e-6
        self.a_s_pos = 3 * self.epsilon_s_pos / self.Gamma_s
        self.a_s_neg = 3 * self.epsilon_s_neg / self.Gamma_s

        # Region thicknesses.
        self.L_pos = 4.0e-5
        self.L_neg = 3.9e-5
        self.L_sep = 0.7e-5

        layers1 = [3, 8, 8, 8, 1]
        layers2 = [3, 8, 8, 8, 1]
        layers3 = [3, 8, 8, 8, 1]

        self.D_s_p_ = D_s_p(layers1)
        self.D_s_n_ = D_s_n(layers2)
        self.D_e_ = D_e(layers3)
        self.sigma_pos = torch.tensor(10, dtype=torch.float32)
        self.sigma_neg = torch.tensor(215, dtype=torch.float32)
        m_ref_pos = 1 * 2 * 10 ** (-11) * self.F

        # Reaction-rate constants, Arrhenius-shifted from 298.15 K to 300.15 K.
        E_r_pos = 5000
        arrhenius_pos = np.exp(E_r_pos / self.R * (1 / 298.15 - 1 / 300.15))
        self.kappa_pos = torch.tensor(m_ref_pos * arrhenius_pos)
        m_ref_neg = 1e-6
        E_r_neg = 35000
        arrhenius_neg = np.exp(E_r_neg / self.R * (1 / 298.15 - 1 / 300.15))
        self.kappa_neg = torch.tensor(m_ref_neg * arrhenius_neg)
        self.c_s_max_pos = torch.tensor(47854, dtype=torch.float32)
        self.c_s_max_neg = torch.tensor(31507, dtype=torch.float32)
        # NOTE(review): these tensors track gradients, but this class registers them with
        # no optimizer; whether they are actually trained depends on the caller.
        self.sigma_pos.requires_grad = True
        self.sigma_neg.requires_grad = True
        self.kappa_pos.requires_grad = True
        self.kappa_neg.requires_grad = True
        self.c_s_max_neg.requires_grad = True
        self.c_s_max_pos.requires_grad = True

        # (pos, neg) OCP interpolators, loaded on first use by _ocp_interpolators().
        self._ocp_cache = None

    @staticmethod
    def _grad(y, x):
        """Return dy/dx for pointwise outputs, keeping the graph for higher derivatives."""
        return grad(y, x, grad_outputs=torch.ones_like(y), create_graph=True)[0]

    @staticmethod
    def _electrolyte_conductivity(c_e):
        """Polynomial electrolyte conductivity in c = |c_e| / 1000."""
        c = c_e.abs() / 1000
        return 0.1297 * c ** 3 - 2.51 * c ** 1.5 + 3.329 * c

    def _electrolyte_potential_residual(self, c_e, phi_e, x, T, source):
        """Return the electrolyte charge-conservation residual for one region, divided by 1e5.

        The residual is
        kappa * d2phi_e/dx2 - d/dx(2RT/F * (1 - t+) * dln(c_e)/dx * kappa) + source.
        """
        kappa_e = self._electrolyte_conductivity(c_e)
        d2phi_e_dx2 = self._grad(self._grad(phi_e, x), x)
        flux = (2 * self.R * T) / self.F * (1 - self.t_plus) * self._grad(torch.log(c_e), x) * kappa_e
        residual = kappa_e * d2phi_e_dx2 - self._grad(flux, x) + source
        return residual / 100000

    @staticmethod
    def _load_ocp(path):
        """Build a cubic (extrapolating) interpolator from a two-column OCP CSV."""
        table = pd.read_csv(path)
        return interp1d(table.iloc[:, 0].values, table.iloc[:, 1].values, kind='cubic', fill_value="extrapolate")

    def _ocp_interpolators(self):
        """Return the (positive, negative) OCP interpolators, reading the CSVs once per instance."""
        cache = getattr(self, "_ocp_cache", None)
        if cache is None:
            cache = (self._load_ocp(self._OCP_POS_PATH), self._load_ocp(self._OCP_NEG_PATH))
            self._ocp_cache = cache
        return cache

    @staticmethod
    def _eval_ocp(interpolator, stoichiometry):
        """Evaluate a NumPy interpolator on a tensor; the result carries no gradient.

        The input is detached and moved to CPU for SciPy. The float32 result is
        moved back to the input's device.
        """
        values = interpolator(stoichiometry.detach().cpu().numpy())
        return torch.from_numpy(values).float().to(stoichiometry.device)

    def get_loss(self, t, x_p, x_n, x_sep, r_p, r_n, V_data, T, I):
        """Evaluate the physics-informed loss at the given collocation points.

        Args:
            t, x_p, x_n, x_sep, r_p, r_n: collocation coordinates, concatenated
                column-wise into the PINN input. They are assumed to be (N, 1)
                each, as built in the ``__main__`` block. ``requires_grad`` is
                enabled on them in place so the PDE derivatives can be taken.
            V_data: measured voltage, compared against phi_s_pos - phi_s_neg.
            T: temperature, used in the overpotential and electrolyte-potential terms.
            I: applied current. It sets the reaction fluxes and is also an input
                of the diffusivity networks.

        Returns:
            A 19-tuple with these entries, in order:
            - V_pred, loss_total, loss_pde, loss_data
            - D_s_pos, D_s_neg, D_e
            - the mean (signed) residuals: solid neg, solid pos, liquid pos/neg/sep,
              phi_e pos/neg/sep, phi_s pos/neg, Butler-Volmer pos/neg.
        """
        for coord in (t, x_p, x_n, x_sep, r_p, r_n):
            coord.requires_grad_(True)

        (c_s_pos, c_s_neg, c_e_pos, c_e_neg, c_e_sep,
         phi_s_pos, phi_s_neg, phi_e_pos, phi_e_neg, phi_e_sep) = self.pinn(t, x_p, x_n, x_sep, r_p, r_n)

        # Identified diffusivities: sigmoid outputs in (0, 1) scaled to physical magnitudes.
        D_s_pos = self.D_s_p_(I, V_data, T) * 5e-13
        D_s_neg = self.D_s_n_(I, V_data, T) * 1e-13
        D_e = self.D_e_(I, V_data, T) * 8e-11
        # The effective electrolyte diffusivities are derived from the *scaled* D_e.
        # This ensures the value reported to the caller is the one the electrolyte
        # residuals actually use. The 1.5 exponent suggests a Bruggeman correction
        # with volume fractions 0.357 / 0.444 (TODO confirm).
        D_e_eff_pos = D_e * 0.357 ** 1.5
        D_e_eff_neg = D_e * 0.444 ** 1.5
        D_e_eff_sep = D_e
        sigma_pos_eff = self.sigma_pos * self.epsilon_s_pos ** 1.5
        sigma_neg_eff = self.sigma_neg * self.epsilon_s_neg ** 1.5
        c_s_max_pos = self.c_s_max_pos
        c_s_max_neg = self.c_s_max_neg

        # Uniform reaction flux per region, derived from the applied current.
        j_pos = - I / (self.F * self.a_s_pos * self.L_pos)
        j_neg = - I / (self.F * self.a_s_neg * self.L_neg)
        j_sep = - I / (self.F * (self.a_s_pos + self.a_s_neg) * self.L_sep / 2)
        # The separator uses the mean of the two electrode surface areas (matching j_sep).
        a_s_sep = (self.a_s_pos + self.a_s_neg) / 2

        d = self._grad

        # Solid-phase diffusion in spherical coordinates: dc/dt = D / r^2 * d/dr(r^2 dc/dr).
        term_r_pos = d(r_p ** 2 * d(c_s_pos, r_p), r_p)
        residual_solid_pos = d(c_s_pos, t) - (D_s_pos / r_p ** 2) * term_r_pos
        term_r_neg = d(r_n ** 2 * d(c_s_neg, r_n), r_n)
        residual_solid_neg = d(c_s_neg, t) - (D_s_neg / r_n ** 2) * term_r_neg

        # Electrolyte mass balance in each region.
        residual_liquid_pos = (self.epsilon_e_pos * d(c_e_pos, t)
                               - D_e_eff_pos * d(d(c_e_pos, x_p), x_p)
                               - (1 - self.t_plus) * self.a_s_pos * j_pos)
        residual_liquid_neg = (self.epsilon_e_neg * d(c_e_neg, t)
                               - D_e_eff_neg * d(d(c_e_neg, x_n), x_n)
                               - (1 - self.t_plus) * self.a_s_neg * j_neg)
        residual_liquid_sep = (self.epsilon_e_sep * d(c_e_sep, t)
                               - D_e_eff_sep * d(d(c_e_sep, x_sep), x_sep)
                               - (1 - self.t_plus) * a_s_sep * j_sep)

        # Solid-phase charge conservation, divided by 1e5 to balance the loss terms.
        residual_phi_s_pos = (sigma_pos_eff * d(d(phi_s_pos, x_p), x_p) + j_pos * self.a_s_pos * self.F) / 100000
        residual_phi_s_neg = (sigma_neg_eff * d(d(phi_s_neg, x_n), x_n) + j_neg * self.a_s_neg * self.F) / 100000

        # Electrolyte charge conservation. The separator source uses the same
        # mean surface area as j_sep and the separator mass balance.
        residual_phi_e_pos = self._electrolyte_potential_residual(c_e_pos, phi_e_pos, x_p, T, self.a_s_pos * self.F * j_pos)
        residual_phi_e_neg = self._electrolyte_potential_residual(c_e_neg, phi_e_neg, x_n, T, self.a_s_neg * self.F * j_neg)
        residual_phi_e_sep = self._electrolyte_potential_residual(c_e_sep, phi_e_sep, x_sep, T, a_s_sep * self.F * j_sep)

        # Butler-Volmer kinetics: exchange current, overpotential, and OCP.
        i0_pos = self.kappa_pos * torch.sqrt(c_e_pos.abs() * (c_s_max_pos - c_s_pos)) * (c_s_pos.abs() ** 0.5)
        i0_neg = self.kappa_neg * (c_e_neg.abs() ** 0.5) * ((c_s_max_neg - c_s_neg) ** 0.5) * (c_s_neg.abs() ** 0.5)
        eta_pos = self.R * T / (0.5 * self.F) * torch.arcsinh(j_pos / i0_pos)
        eta_neg = self.R * T / (0.5 * self.F) * torch.arcsinh(j_neg / i0_neg)
        ocp_pos, ocp_neg = self._ocp_interpolators()
        # The OCP is evaluated in NumPy on detached values, so no gradient flows through it.
        delta_pos = eta_pos + self._eval_ocp(ocp_pos, c_s_pos / c_s_max_pos) + j_pos * 0.001
        delta_neg = eta_neg + self._eval_ocp(ocp_neg, c_s_neg / c_s_max_neg) + j_neg * 0.001
        residual_bv_pos = phi_s_pos - phi_e_pos - delta_pos
        residual_bv_neg = phi_s_neg - phi_e_neg - delta_neg

        V_pred = phi_s_pos - phi_s_neg

        loss_data = torch.mean((V_pred - V_data).abs())

        loss_pde = sum(
            torch.mean(residual.abs())
            for residual in (
                residual_solid_pos,
                residual_solid_neg,
                residual_liquid_pos,
                residual_liquid_neg,
                residual_liquid_sep,
                residual_phi_e_pos,
                residual_phi_e_neg,
                residual_phi_e_sep,
                residual_phi_s_pos,
                residual_phi_s_neg,
                residual_bv_pos,
                residual_bv_neg,
            )
        )

        loss_total = loss_data + loss_pde
        return (V_pred, loss_total, loss_pde, loss_data, D_s_pos, D_s_neg, D_e,
                torch.mean(residual_solid_neg), torch.mean(residual_solid_pos),
                torch.mean(residual_liquid_pos), torch.mean(residual_liquid_neg), torch.mean(residual_liquid_sep),
                torch.mean(residual_phi_e_pos), torch.mean(residual_phi_e_neg), torch.mean(residual_phi_e_sep),
                torch.mean(residual_phi_s_pos), torch.mean(residual_phi_s_neg),
                torch.mean(residual_bv_pos), torch.mean(residual_bv_neg))


def train_pinn(physics, epochs, optimizer, t_data, x_p_data, x_n_data, x_sep_data, r_p_data, r_n_data, V_data, T, I,
               scheduler=None):
    """Run ``epochs`` full-batch optimisation steps on ``physics.get_loss``.

    Args:
        physics: object whose ``get_loss(t, x_p, x_n, x_sep, r_p, r_n, V, T, I)``
            returns (V_pred, loss_total, loss_pde, loss_data, D_s_pos, D_s_neg, D_e, ...).
        epochs: number of optimisation steps; must be positive.
        optimizer: torch optimizer over the trainable parameters.
        t_data, x_p_data, x_n_data, x_sep_data, r_p_data, r_n_data, V_data, T, I:
            full-batch training tensors forwarded unchanged to ``get_loss``.
        scheduler: optional LR scheduler, stepped once per epoch after ``optimizer.step()``.

    Returns:
        A tuple of two items:
        - a list of per-epoch data-loss floats;
        - the final epoch's data loss as a detached tensor, so the last autograd
          graph is not kept alive by the caller.

    Raises:
        ValueError: if ``epochs`` is not positive (there would be no final loss to return).
    """
    if epochs <= 0:
        raise ValueError(f"epochs must be positive, got {epochs}")

    loss_data_list = []

    for epoch in range(epochs):
        optimizer.zero_grad()
        _, loss_total, _, loss_data, D_s_pos, D_s_neg, D_e, *_residual_means = physics.get_loss(
            t_data, x_p_data, x_n_data, x_sep_data, r_p_data, r_n_data, V_data, T, I
        )
        loss_total.backward()
        optimizer.step()
        loss_data_list.append(loss_data.item())
        # Read before the scheduler step so the printed LR is the one used this epoch.
        current_lr = optimizer.param_groups[0]['lr']
        if scheduler is not None:
            scheduler.step()

        if epoch % 20 == 0:
            print(f'Learning Rate {current_lr}')
            print(f"Epoch {epoch}, Loss_train: {loss_total.item():.4f}")
            print(f"D_s_pos={D_s_pos.mean()}, D_s_neg={D_s_neg.mean()}, D_e={D_e.mean()}")
            print(f"Loss_train_data={loss_data.mean()}")

    return loss_data_list, loss_data.detach().mean()

def test_pinn(physics, t_test_data, x_p_test_data, x_n_test_data, x_sep_test_data,
              r_p_test_data, r_n_test_data, V_test_data, T_test, I_test):
    """Evaluate ``physics.get_loss`` once on test inputs and print summary values.

    ``get_loss`` differentiates with autograd, so this cannot run under ``torch.no_grad()``.

    Returns:
        (list with the test data loss, data loss, V_pred, mean D_s_pos, mean D_s_neg,
        mean D_e). All tensors are detached so callers that store them across
        iterations do not retain the test pass's higher-order autograd graph.
    """
    V_pred, _, _, loss_data_test, D_s_pos_test, D_s_neg_test, D_e_test, *_ = physics.get_loss(
        t_test_data, x_p_test_data, x_n_test_data, x_sep_test_data,
        r_p_test_data, r_n_test_data, V_test_data, T_test, I_test
    )
    loss_data_test_list = [loss_data_test.item()]

    print(f"D_s_pos_test={D_s_pos_test.mean()}, D_s_neg_test={D_s_neg_test.mean()}, D_e_test={D_e_test.mean()}")
    print(f"Loss_test_data={loss_data_test.mean()}")
    print(f"V_pred={(V_pred).mean()}")
    print(80*"=")

    return (loss_data_test_list, loss_data_test.detach().mean(), V_pred.detach(),
            D_s_pos_test.detach().mean(), D_s_neg_test.detach().mean(), D_e_test.detach().mean())

if __name__ == "__main__":
    # Train the PINN on the experimental record, evaluate it on the interpolated
    # test trace, then plot V_pred and write summary statistics to disk.

    layers = [6, 16, 32, 16, 10]
    pinn = PINN_ParameterID(layers)
    physics = P2D_Physics_ID(pinn)

    data = pd.read_csv("/home/ma-user/work/PINNs/data/experiment_1CP_continuous.csv", usecols=[1,2,3,4,5,6,7,10,11])
    tensor_data = torch.tensor(data.values, dtype=torch.float32)
    # Drop every row that contains a NaN.
    tensor_data = tensor_data[~torch.isnan(tensor_data).any(dim=1)]

    # NOTE: only 10% of rows are used for training; the 90% split is unused below
    # (test inputs come from a separate interpolated file).
    train_data, test_data = train_test_split(tensor_data.numpy(), test_size=0.9, random_state=42)

    # Column positions refer to the selected usecols, in order.
    t_data = torch.tensor(train_data[:, 0]).reshape(-1, 1)
    x_n_data = torch.tensor(train_data[:, 2]).reshape(-1, 1)
    x_p_data = torch.tensor(train_data[:, 3]).reshape(-1, 1)
    x_sep_data = torch.tensor(train_data[:, 4]).reshape(-1, 1)
    r_p_data = torch.tensor(train_data[:, 5]).reshape(-1, 1)
    r_n_data = torch.tensor(train_data[:, 6]).reshape(-1, 1)
    # Voltage target = column 8 minus column 7 (presumably the two electrode potentials; confirm).
    V_data = torch.tensor(train_data[:, 8]).reshape(-1, 1) - torch.tensor(train_data[:, 7]).reshape(-1, 1)
    T = 300.15 * torch.ones_like(t_data)

    I = torch.tensor(train_data[:, 1]).reshape(-1, 1)

    # Test grid: n_test evenly spaced times on [0, 400]; spatial and radial inputs fixed at 0.
    n_test = 401
    t_test_data = torch.linspace(0, 400, steps=n_test).reshape(-1, 1)
    x_n_test_data = torch.zeros_like(t_test_data)
    x_p_test_data = torch.zeros_like(t_test_data)
    x_sep_test_data = torch.zeros_like(t_test_data)
    r_p_test_data = torch.zeros_like(t_test_data)
    r_n_test_data = torch.zeros_like(t_test_data)
    file_path1 = '/home/ma-user/work/PINNs/interpolated_results.csv'
    data = pd.read_csv(file_path1)

    # Take the same number of rows for voltage and current so both align with t_test_data.
    V_test_data_values = data['Interpolated Value'].head(n_test).values

    V_test_data = torch.tensor(V_test_data_values).reshape(-1, 1).to(t_test_data.dtype)
    T_test = 300.15 * torch.ones_like(t_test_data)
    I_test = data['Interpolated Value Current'].head(n_test).values
    I_test = torch.tensor(I_test).reshape(-1, 1).to(t_test_data.dtype)

    def lr_lambda(step):
        # NOTE(review): past step 20000 this returns 0.1 ** (step // 300), i.e. ~1e-66 at
        # step 20001. It was possibly meant to be (step - 20001) // 300.
        if step < 20001:
            return 1.0
        else:
            return 0.1 ** ((step // 300))

    # NOTE: physics.sigma_*, kappa_* and c_s_max_* require grad but are not optimized here.
    optimizer = torch.optim.Adam(
        list(pinn.parameters()) +
        list(physics.D_s_p_.parameters()) +
        list(physics.D_e_.parameters()) +
        list(physics.D_s_n_.parameters()),
        lr=5e-4
    )
    # NOTE(review): this scheduler is never stepped (it is not passed to train_pinn),
    # so lr_lambda currently has no effect.
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    start_time = time.time()

    num_iterations = 1

    loss_data_train_total = []
    loss_data_test_total = []
    D_s_p_list = []
    D_s_n_list = []
    D_e_list = []

    for i in range(num_iterations):

        print(f"Iteration {i + 1}/{num_iterations}")
        epochs = 20000

        loss_data_train, loss_last_train = train_pinn(physics, epochs, optimizer, t_data, x_p_data, x_n_data, x_sep_data, r_p_data, r_n_data, V_data, T, I)
        loss_data_test_list, loss_last_test, V_pred, D_s_pos_test, D_s_neg_test, D_e_test = test_pinn(physics, t_test_data, x_p_test_data, x_n_test_data, x_sep_test_data,
                                      r_p_test_data, r_n_test_data, V_test_data, T_test, I_test)

        loss_data_train_total.append(loss_last_train)
        loss_data_test_total.append(loss_last_test)
        D_s_p_list.append(D_s_pos_test)
        D_s_n_list.append(D_s_neg_test)
        D_e_list.append(D_e_test)
        x_values = np.linspace(0, 400, num=n_test)

        plt.figure(figsize=(10, 6))
        plt.plot(x_values, V_pred.detach().numpy(), label='V_pred', color='b')
        plt.title('Line Plot of V_pred')
        plt.xlabel('X-axis')
        plt.ylabel('V_pred Values')
        plt.grid()
        plt.legend()
        plt.savefig('V_pred_plot.pdf', format='pdf')
        plt.show()

    # float() works on 0-dim tensors whether or not they carry autograd history.
    train_tensor = torch.tensor([float(v) for v in loss_data_train_total])
    test_tensor = torch.tensor([float(v) for v in loss_data_test_total])
    D_s_p_tensor = torch.tensor([float(v) for v in D_s_p_list])
    D_s_n_tensor = torch.tensor([float(v) for v in D_s_n_list])
    D_e_tensor = torch.tensor([float(v) for v in D_e_list])

    # NOTE: with num_iterations == 1 the (unbiased) variances below are NaN.
    mean_loss_train = torch.mean(train_tensor)
    var_loss_train = torch.var(train_tensor)

    mean_loss_test = torch.mean(test_tensor)
    var_loss_test = torch.var(test_tensor)

    # Relative deviation of the identified diffusivities from fixed reference values.
    mean_D_s_p = torch.mean((D_s_p_tensor-1e-13)/1e-13)
    mean_D_s_n = torch.mean((D_s_n_tensor-3.9e-14)/3.9e-14)
    mean_D_e = torch.mean((D_e_tensor-7.5e-11)/7.5e-11)
    results_path = 'results_continuous_t.txt'
    with open(results_path, 'w') as file:
        file.write(f'mean_D_s_p: {mean_D_s_p.item()}\n')
        file.write(f'mean_D_s_n: {mean_D_s_n.item()}\n')
        file.write(f'mean_D_e: {mean_D_e.item()}\n')

        file.write(f'D_s_p_tensor: {D_s_p_tensor.tolist()}\n')
        file.write(f'D_s_n_tensor: {D_s_n_tensor.tolist()}\n')
        file.write(f'D_e_tensor: {D_e_tensor.tolist()}\n')
        file.write(f'test_tensor: {test_tensor.tolist()}\n')
        file.write(f'V_pred: {V_pred.detach().numpy().tolist()}\n')

    print(f"Results saved to {results_path}")

    print("Mean Training Loss:", mean_loss_train)
    print("Variance in Training Loss:", var_loss_train)
    print("Mean Testing Loss:", mean_loss_test)
    print("Variance in Testing Loss:", var_loss_test)
    print("mean_D_s_p:",mean_D_s_p)
    print("mean_D_s_n:",mean_D_s_n)
    print("mean_D_e:",mean_D_e)
    end_time = time.time()
    # Total wall-clock time in seconds.
    print(end_time-start_time)