import os
import sys

os.environ['CUDA_VISIBLE_DEVICES']='2'
import torch
import torch.nn as nn
import numpy as np
from typing import List
from tqdm import tqdm
import scipy.io as io
from sklearn.decomposition import PCA

class opnn(nn.Module):
    """Branch/trunk operator network trained with a ``|y| + 1``-weighted MSE.

    Layer construction, the forward pass and the loss are all delegated to the
    module-level helpers ``static_init``, ``static_forward`` and ``static_loss``.

    Args:
        branch1_dim: layer widths of the first branch net (input first).
        branch2_dim: layer widths of the second branch net (input first).
        trunk_dim: layer widths of the trunk net (input first).
    """

    def __init__(self, branch1_dim, branch2_dim, trunk_dim):
        super().__init__()
        static_init(self, branch1_dim, branch2_dim, trunk_dim)

    def forward(self, f, f_bc, x):
        """Return the network prediction for inputs ``f``, ``f_bc`` and query points ``x``."""
        return static_forward(self, f, f_bc, x)

    def loss(self, f, f_bc, x, y):
        """Return the scalar training loss of the prediction against target ``y``."""
        return static_loss(self, f, f_bc, x, y)

class DIMON_opnn(nn.Module):
    """Branch/trunk operator network trained with a plain (unweighted) MSE.

    Layer construction, the forward pass and the loss are all delegated to the
    module-level helpers ``DIMON_static_init``, ``DIMON_static_forward`` and
    ``DIMON_static_loss``.

    Args:
        branch1_dim: layer widths of the first branch net (input first).
        branch2_dim: layer widths of the second branch net (input first).
        trunk_dim: layer widths of the trunk net (input first).
    """

    def __init__(self, branch1_dim, branch2_dim, trunk_dim):
        super().__init__()
        DIMON_static_init(self, branch1_dim, branch2_dim, trunk_dim)

    def forward(self, f, f_bc, x):
        """Return the network prediction for inputs ``f``, ``f_bc`` and query points ``x``."""
        return DIMON_static_forward(self, f, f_bc, x)

    def loss(self, f, f_bc, x, y):
        """Return the scalar mean-squared-error loss against target ``y``."""
        return DIMON_static_loss(self, f, f_bc, x, y)

def static_forward(model, f, f_bc, x):
    """Combine the two branch nets with the trunk net.

    The outputs of ``model._branch1(f)`` and ``model._branch2(f_bc)`` are
    multiplied element-wise. Each resulting row is then dotted with every row
    of ``model._trunk(x)``.

    Returns:
        Tensor whose entry ``[i, k]`` is ``sum_j branch[i, j] * trunk[k, j]``.
    """
    coeffs = model._branch1(f)
    coeffs = coeffs * model._branch2(f_bc)
    basis = model._trunk(x)
    return torch.einsum("ij,kj->ik", coeffs, basis)

def static_loss(model, f, f_bc, x, y):
    """Return the mean squared error weighted element-wise by ``|y| + 1``.

    Targets with larger magnitude therefore contribute proportionally more to
    the loss.
    """
    residual = model.forward(f, f_bc, x) - y
    return ((1.0 + torch.abs(y)) * residual ** 2).mean()

def static_init(model, branch1_dim, branch2_dim, trunk_dim):
    """Attach the branch/trunk MLPs to ``model``.

    Each net is an ``nn.Sequential`` of ``nn.Sequential(Linear, Tanh)`` pairs,
    one pair per consecutive width in the given list. The Tanh is applied after
    the last layer too. ``model.z_dim`` is set to the trunk's output width.
    """
    def tanh_mlp(widths):
        # Nested Sequentials keep the original state_dict key layout.
        return nn.Sequential(*[
            nn.Sequential(nn.Linear(n_in, n_out), nn.Tanh())
            for n_in, n_out in zip(widths[:-1], widths[1:])
        ])

    model.z_dim = trunk_dim[-1]
    model._branch1 = tanh_mlp(branch1_dim)
    model._branch2 = tanh_mlp(branch2_dim)
    model._trunk = tanh_mlp(trunk_dim)

def DIMON_static_forward(model, f, f_bc, x):
    """Combine the two branch nets with the trunk net.

    The outputs of ``model._branch1(f)`` and ``model._branch2(f_bc)`` are
    multiplied element-wise. Each resulting row is then dotted with every row
    of ``model._trunk(x)``.

    Returns:
        Tensor whose entry ``[i, k]`` is ``sum_j branch[i, j] * trunk[k, j]``.
    """
    branch_features = torch.mul(model._branch1(f), model._branch2(f_bc))
    trunk_features = model._trunk(x)
    return torch.einsum("ij,kj->ik", branch_features, trunk_features)

def DIMON_static_loss(model, f, f_bc, x, y):
    """Return the unweighted mean squared error between the prediction and ``y``."""
    prediction = model.forward(f, f_bc, x)
    squared_error = (prediction - y) ** 2
    return squared_error.mean()

def DIMON_static_init(model, branch1_dim, branch2_dim, trunk_dim):
    """Attach the branch/trunk MLPs to ``model``.

    Each net is an ``nn.Sequential`` of ``nn.Sequential(Linear, Tanh)`` pairs,
    one pair per consecutive width in the given list. The Tanh is applied after
    the last layer too. ``model.z_dim`` is set to the trunk's output width.
    """
    def build_stack(widths):
        layers = []
        for idx in range(len(widths) - 1):
            layers.append(nn.Sequential(nn.Linear(widths[idx], widths[idx + 1]), nn.Tanh()))
        return nn.Sequential(*layers)

    model.z_dim = trunk_dim[-1]
    # Build order (branch1, branch2, trunk) fixes the RNG draws for weight init.
    model._branch1 = build_stack(branch1_dim)
    model._branch2 = build_stack(branch2_dim)
    model._trunk = build_stack(trunk_dim)

# Dataset files consumed by ``static_main``. Samples from all three are
# concatenated in this order before the train/test split is taken.
_LAPLACE_DATAFILES = (
    "Laplace/dataset/Laplace_data.mat",
    "Laplace/dataset/Laplace_data_supp.mat",
    "Laplace/dataset/Laplace_data_supp2000.mat",
)


def _load_laplace_data(datafiles=_LAPLACE_DATAFILES):
    """Load the Laplace ``.mat`` files and concatenate them sample-wise.

    The reference coordinates ``x_uni`` are read from the first file only.
    They are used to form the displacement ``dx = x_mesh - x`` for every file.
    Only every third column of ``u_bc`` is kept.

    Returns:
        Tuple ``(x, x_mesh, dx, u, u_bc)``. All but ``x`` are concatenated
        along axis 0 in the order of ``datafiles``.
    """
    x = None
    meshes, displacements, solutions, boundaries = [], [], [], []
    for path in datafiles:
        data = io.loadmat(path)
        if x is None:
            x = data["x_uni"]
        mesh = data["x_mesh_data"]
        meshes.append(mesh)
        displacements.append(mesh - x)
        solutions.append(data["u_data"])
        boundaries.append(data["u_bc"][:, ::3])
    return (
        x,
        np.concatenate(meshes, axis=0),
        np.concatenate(displacements, axis=0),
        np.concatenate(solutions, axis=0),
        np.concatenate(boundaries, axis=0),
    )


def _pod_coefficients(train_field, test_field, n_modes):
    """Project train/test fields onto the leading PCA modes fitted on the training set.

    Both fields are centered with the training mean, so no test statistics
    enter the basis.

    Returns:
        ``(train_coeffs, test_coeffs)``, each with ``n_modes`` columns.
    """
    train_mean = train_field.mean(axis=0)
    centered_train = train_field - train_mean
    pca = PCA(n_components=n_modes)
    pca.fit(centered_train)
    return pca.transform(centered_train), pca.transform(test_field - train_mean)


def static_main(model, epochs, device):
    """Train ``model`` on the Laplace data and evaluate it on the held-out samples.

    Inputs to the model are the 2 * PODMode PCA coefficients of the mesh
    displacement (x and y components), the subsampled boundary data and the
    reference coordinates ``x``. The first ``num_train`` samples are used for
    training and the last ``num_test`` for testing.

    Optimizer schedule: Adam(lr=1e-3), switched to Adam(lr=1e-4) at epoch 10000
    and Adam(lr=1e-5) at epoch 90000. The final 1000 epochs always use L-BFGS.

    Args:
        model: module exposing ``forward(f, f_bc, x)`` and ``loss(f, f_bc, x, y)``.
            Its first branch must accept 2 * PODMode (= 20) inputs.
        epochs: number of full-batch optimizer steps.
        device: torch device to train and evaluate on.

    Returns:
        ``(mean_abs_err, rel_l2_err, rel_l2_err_train, x_mesh, u_test_pred, u_test)``:
        per-sample test MAE, per-sample test relative L2 error, per-sample train
        relative L2 error, the full concatenated meshes, test predictions and
        test targets.
    """
    PODMode = 10
    num_train = 3300
    num_test = 200
    lbfgs_start = epochs - 1000

    x, x_mesh, dx, u, u_bc = _load_laplace_data()

    dx_train = dx[:num_train]
    dx_test = dx[-num_test:]
    coeff_x_train, coeff_x_test = _pod_coefficients(dx_train[:, :, 0], dx_test[:, :, 0], PODMode)
    coeff_y_train, coeff_y_test = _pod_coefficients(dx_train[:, :, 1], dx_test[:, :, 1], PODMode)
    f_train = np.concatenate((coeff_x_train, coeff_y_train), axis=1)
    f_test = np.concatenate((coeff_x_test, coeff_y_test), axis=1)

    u_train = u[:num_train]
    u_test = u[-num_test:]
    f_bc_train = u_bc[:num_train, :]
    f_bc_test = u_bc[-num_test:, :]

    def as_tensor(array):
        return torch.tensor(array, dtype=torch.float).to(device)

    f_test_tensor = as_tensor(f_test)
    f_bc_test_tensor = as_tensor(f_bc_test)
    u_train_tensor = as_tensor(u_train)
    f_train_tensor = as_tensor(f_train)
    f_bc_train_tensor = as_tensor(f_bc_train)
    x_tensor = as_tensor(x)

    model = model.to(device).float()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # Per-epoch loss record. For L-BFGS this holds the last closure evaluation.
    train_loss = np.zeros((epochs, 1))

    def train(epoch, f, f_bc, x, y):
        model.train()

        def closure():
            optimizer.zero_grad()
            loss = model.loss(f, f_bc, x, y)
            train_loss[epoch, 0] = loss.item()
            loss.backward()
            return loss

        # ``optimizer`` is read from the enclosing scope, so schedule switches apply.
        optimizer.step(closure)

    for epoch in tqdm(range(0, epochs), desc='Epoch: 0', unit='epoch', leave=True):
        if epoch == lbfgs_start:
            # L-BFGS takes priority so the final 1000 epochs are never replaced by an
            # Adam LR drop that happens to coincide with or follow this epoch.
            optimizer = torch.optim.LBFGS(model.parameters())
        elif epoch < lbfgs_start:
            if epoch == 10000:
                optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
            elif epoch == 90000:
                optimizer = torch.optim.Adam(model.parameters(), lr=0.00001)
        train(epoch, f_train_tensor, f_bc_train_tensor, x_tensor, u_train_tensor)

    # Inference only: skip building autograd graphs over the full train/test batches.
    model.eval()
    with torch.no_grad():
        u_test_pred = model(f_test_tensor, f_bc_test_tensor, x_tensor).cpu().numpy()
        u_train_pred = model(f_train_tensor, f_bc_train_tensor, x_tensor).cpu().numpy()

    rel_l2_err_train = np.linalg.norm(u_train_pred - u_train, axis=1) / np.linalg.norm(u_train, axis=1)
    mean_abs_err = np.abs(u_test_pred - u_test).mean(axis=1)
    rel_l2_err = np.linalg.norm(u_test - u_test_pred, axis=1) / np.linalg.norm(u_test, axis=1)

    return mean_abs_err, rel_l2_err, rel_l2_err_train, x_mesh, u_test_pred, u_test

# Both models share one architecture so the comparison isolates the loss choice.
# Branch 1 takes 2 x 10 POD coefficients; branch 2 the 68 subsampled boundary values.
BRANCH1_DIM = [20, 100, 100, 100]
BRANCH2_DIM = [68, 150, 150, 150, 100]
TRUNK_DIM = [2, 100, 100, 100]
EPOCHS = 50000
RESULT_DIR = '/data/zz/rag_paper/experiment_auto_write/result/dimon/angle/analysis_angle_047/draw_data'

# Create the output directory before training, so a missing path cannot make
# np.savetxt fail only after both long training runs have finished.
os.makedirs(RESULT_DIR, exist_ok=True)

mean_abs_err, rel_l2_err, rel_l2_err_train, x_mesh, u_pred, u_true = static_main(
    opnn(BRANCH1_DIM, BRANCH2_DIM, TRUNK_DIM), EPOCHS, 'cuda')
dimon_mean_abs_err, dimon_rel_l2_err, dimon_rel_l2_err_train, dimon_x_mesh, dimon_u_pred, dimon_u_true = static_main(
    DIMON_opnn(BRANCH1_DIM, BRANCH2_DIM, TRUNK_DIM), EPOCHS, 'cuda')

for result_name, result_values in (
    ('mean_abs_err.txt', mean_abs_err),
    ('rel_l2_err.txt', rel_l2_err),
    ('dimon_mean_abs_err.txt', dimon_mean_abs_err),
    ('dimon_rel_l2_err.txt', dimon_rel_l2_err),
):
    np.savetxt(os.path.join(RESULT_DIR, result_name), result_values)
