import torch
import torch.nn as nn
import torch.optim as optim
import scipy.io as sio
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import TensorDataset, DataLoader
import random
import torch.nn.init as init
import torch.nn.functional as F
import time
import torch.nn.init as init
import scipy.io
import os
# Sharpness parameter of the erf term used in the test-time correction
# (erf(sqrt(1 / (2 * eps)) * x)).
eps = torch.tensor(1e-05)

# NOTE(review): placeholder path; "~" is not expanded by loadmat — point this
# at the real location of data_erf.mat.
try:
    mat_data = sio.loadmat(r"~~\data_erf.mat")
except FileNotFoundError:
    print("오류: 'data_erf.mat' 파일을 찾을 수 없습니다. 파일 경로를 다시 확인해주세요.")
    # Non-zero status so shells/schedulers see the failure; bare exit()
    # reported success and only exists when the `site` module is loaded.
    raise SystemExit(1)

sol = mat_data['sol']
# NOTE(review): the names are cross-wired — the Python variable `func_f1`
# holds the .mat field 'func_f' and `func_f` holds 'func_f1'. Kept as-is to
# preserve behavior: X_* come from 'func_f', load_* (the model/residual
# input) come from 'func_f1'. Confirm this is intended.
func_f1 = mat_data['func_f']
func_f = mat_data['func_f1']
A = mat_data['A']

sol_tensor = torch.from_numpy(sol).type(torch.float32)
func_f_tensor = torch.from_numpy(func_f).type(torch.float32)
func_f1_tensor = torch.from_numpy(func_f1).type(torch.float32)
At = torch.from_numpy(A).type(torch.float32)

# Number of grid points per sample (columns of `sol`); the grid spans [0, 1]
# and is shaped (1, 1, s) so it can be repeated into a batch channel.
s = sol.shape[1]
x_grid = torch.linspace(0, 1, s).unsqueeze(0).unsqueeze(0)
num_time_steps = 2

random_seed = 12
np.random.seed(random_seed)     # drives the split shuffle below
torch.manual_seed(random_seed)  # makes model weight initialization reproducible too

n_total = sol_tensor.shape[0]
indices = np.arange(n_total)
np.random.shuffle(indices)

n_train = 800  # int(n_total * 0.8)
n_val = 100    # int(n_total * 0.1)
n_test = n_total - n_train - n_val
if n_test <= 0:
    # Fail fast instead of training for hundreds of epochs and then crashing
    # on an empty test set.
    raise ValueError(
        f"dataset has {n_total} samples; need more than {n_train + n_val} "
        f"for the fixed train/val/test split"
    )

train_indices = indices[:n_train]
val_indices = indices[n_train:n_train + n_val]
test_indices = indices[n_train + n_val:]

# Rows of `sol`: presumably the reference solutions (only saved, not trained on).
y_train = sol_tensor[train_indices]
y_val = sol_tensor[val_indices]
y_test = sol_tensor[test_indices]

X_train = func_f1_tensor[train_indices]
X_val = func_f1_tensor[val_indices]
X_test = func_f1_tensor[test_indices]

# Loads fed to the model and compared against A @ u in the residual loss.
load_train = func_f_tensor[train_indices]
load_val = func_f_tensor[val_indices]
load_test = func_f_tensor[test_indices]


class CNN1D_Model(nn.Module):
    """1-D CNN mapping a 2-channel signal of length ``seq_len`` to ``seq_len`` values.

    Architecture: five stages of Conv1d (kernel 7, stride 1, padding 3, so the
    length is preserved) -> BatchNorm1d -> GELU -> MaxPool1d(2), with
    16/32/64/128/256 channels, followed by an MLP head
    ``256 * (seq_len // 32) -> 128 -> 64 -> seq_len``.

    Args:
        seq_len: Length of the input signal and of the output vector. Defaults
            to the module-level grid size ``s`` so ``CNN1D_Model()`` keeps
            working unchanged. The five 2x poolings leave ``seq_len // 32``
            positions, so ``seq_len`` must be at least 32.

    Raises:
        ValueError: if ``seq_len`` is smaller than 32.
    """

    def __init__(self, seq_len=None):
        super().__init__()
        if seq_len is None:
            seq_len = s  # module-level grid size of the loaded dataset
        if seq_len < 32:
            raise ValueError(
                f"seq_len must be >= 32 (five 2x max-poolings), got {seq_len}"
            )
        self.seq_len = seq_len

        self.conv1 = nn.Conv1d(in_channels=2, out_channels=16, kernel_size=7, stride=1, padding=3)
        self.bn1 = nn.BatchNorm1d(16)
        self.gelu1 = nn.GELU()
        self.pool1 = nn.MaxPool1d(2)

        self.conv2 = nn.Conv1d(16, 32, kernel_size=7, stride=1, padding=3)
        self.bn2 = nn.BatchNorm1d(32)
        self.gelu2 = nn.GELU()
        self.pool2 = nn.MaxPool1d(2)

        self.conv3 = nn.Conv1d(32, 64, kernel_size=7, stride=1, padding=3)
        self.bn3 = nn.BatchNorm1d(64)
        self.gelu3 = nn.GELU()
        self.pool3 = nn.MaxPool1d(2)

        self.conv4 = nn.Conv1d(64, 128, kernel_size=7, stride=1, padding=3)
        self.bn4 = nn.BatchNorm1d(128)
        self.gelu4 = nn.GELU()
        self.pool4 = nn.MaxPool1d(2)

        self.conv5 = nn.Conv1d(128, 256, kernel_size=7, stride=1, padding=3)
        self.bn5 = nn.BatchNorm1d(256)
        self.gelu5 = nn.GELU()
        self.pool5 = nn.MaxPool1d(2)

        # Five floor-halvings leave exactly seq_len // 32 positions.
        self.fc1 = nn.Linear(256 * (seq_len // 32), 128)
        self.gelu6 = nn.GELU()
        self.fc2 = nn.Linear(128, 64)
        self.gelu7 = nn.GELU()
        self.fc3 = nn.Linear(64, seq_len)

        # NOTE(review): `proj` is never used in forward(). It is kept so the
        # state_dict keys (and previously saved checkpoints) stay compatible.
        self.proj = nn.Linear(2 * seq_len, seq_len)

        self._initialize_weights_he()

    def _initialize_weights_he(self):
        """Initialize Conv1d/Linear weights with He (kaiming_normal) init and zero biases.

        The ReLU gain is used although the activations are GELU; this is an
        approximation.
        """
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Linear)):
                init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
                if m.bias is not None:
                    init.zeros_(m.bias)

    def forward(self, x):
        """Map ``x`` of shape (batch, 2, seq_len) to a (batch, seq_len) tensor."""
        x = self.pool1(self.gelu1(self.bn1(self.conv1(x))))
        x = self.pool2(self.gelu2(self.bn2(self.conv2(x))))
        x = self.pool3(self.gelu3(self.bn3(self.conv3(x))))
        x = self.pool4(self.gelu4(self.bn4(self.conv4(x))))
        x = self.pool5(self.gelu5(self.bn5(self.conv5(x))))

        x = x.flatten(1)  # (batch, 256 * (seq_len // 32))

        x = self.gelu6(self.fc1(x))
        x = self.gelu7(self.fc2(x))
        return self.fc3(x)
model = CNN1D_Model()
optimizer = optim.LBFGS(model.parameters(), lr=.1, max_iter=100, history_size=100)

best_val_loss = float('inf')
best_model_path = 'best_pde_cnn_model_int.pth'
patience_counter = 0
patience_limit = 100
num_epochs = 730
saved_this_run = False  # set once a checkpoint is written during this run

loss_fn = nn.MSELoss()  # only needed by the disabled supervised-loss variants


def compute_residual_loss(load_true):
    """Return the scaled mean squared residual of ``A @ u - f`` for a batch of loads.

    The model receives the grid ``x_grid`` and the load as two input channels
    and predicts ``u`` directly; the loss penalizes ``A @ u - load``.

    Args:
        load_true: 2-D tensor of loads, one row per sample (e.g. ``load_train``
            or ``load_val``).

    Returns:
        Scalar tensor ``mean((u @ A^T - load_true) ** 2) * 1e6``.
    """
    batch_size = load_true.size(0)
    x_tensor = x_grid.repeat(batch_size, 1, 1)
    model_input = torch.cat([x_tensor, load_true.unsqueeze(1)], dim=1)
    # (batch, s). Deliberately not squeezed so a batch of one still works.
    predicted_u = model(model_input)
    # Row-wise A @ u as one matmul with A^T, instead of bmm over a
    # batch-sized copy of A (which allocated batch * s * s extra floats).
    result = predicted_u @ At.T
    residual = result - load_true
    # Supervised alternatives (disabled): MSE or mean relative L2 error of
    # predicted_u against y_train / y_val.
    return torch.mean(residual ** 2) * 1000000


def closure(is_training=True):
    """LBFGS closure: training residual loss, back-propagated.

    With ``is_training=False`` it returns the validation loss without calling
    ``backward``.
    """
    optimizer.zero_grad()
    load_true = load_train if is_training else load_val
    total_loss = compute_residual_loss(load_true)
    if is_training:
        total_loss.backward()
    return total_loss


for epoch in range(num_epochs):
    model.train()
    optimizer.step(closure)
    # Post-step training loss, still in train mode (BatchNorm running stats
    # update exactly as before) but without a redundant backward pass: the
    # next step's closure zeroes the gradients anyway.
    with torch.no_grad():
        train_loss = compute_residual_loss(load_train).item()

    model.eval()
    with torch.no_grad():
        val_loss = compute_residual_loss(load_val).item()

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), best_model_path)
        saved_this_run = True
        patience_counter = 0
        print(f'에포크 {epoch+1}: 검증 손실 개선. {val_loss:.6f}로 감소. 모델 저장.')
    else:
        patience_counter += 1

    print(f'에포크 [{epoch+1}/{num_epochs}], 훈련 손실: {train_loss:.6f}, 검증 손실: {val_loss:.6f}')

    if patience_counter >= patience_limit:
        break

if not saved_this_run:
    # A NaN validation loss never compares below inf, so no checkpoint is
    # written. Fail loudly rather than letting the evaluation step load a
    # stale file from an earlier run (or crash with a confusing error).
    raise RuntimeError(
        f'no checkpoint was saved to {best_model_path!r} in this run '
        f'(validation loss never improved)'
    )


model.load_state_dict(torch.load(best_model_path))
model.eval()

# Index of the test sample to plot; requires at least PLOT_ROW + 1 test samples.
PLOT_ROW = 77

with torch.no_grad():
    batch_size = load_test.size(0)
    x_tensor = x_grid.repeat(batch_size, 1, 1)
    model_input = torch.cat([x_tensor, load_test.unsqueeze(1)], dim=1)

    # perf_counter is the monotonic high-resolution clock meant for intervals.
    start_time = time.perf_counter()
    outputs_test = model(model_input)
    elapsed_time = time.perf_counter() - start_time

    print(f'코드 실행 시간: {elapsed_time:.4f}초')

    # Subtract last_output * erf(x / sqrt(2 * eps)). Because x_grid ends at 1,
    # this drives the last grid value to ~0.
    # NOTE(review): this correction is applied only here; the training
    # residual uses the raw model output as u (the matching term is commented
    # out there). Confirm the train/test mismatch is intended.
    x_line = x_grid.reshape(1, -1)  # (1, s)
    term2_test = outputs_test[:, -1:] * torch.erf(torch.sqrt(1 / (2 * eps)) * x_line)
    predicted_u_next = outputs_test - term2_test  # (batch, s); no squeeze needed

    row_to_plot = predicted_u_next[PLOT_ROW].cpu().numpy()
    fig = plt.figure(figsize=(10, 6))
    plt.plot(row_to_plot)
    plt.title(f'Predicted Solution for Test Sample {PLOT_ROW}')
    plt.xlabel('Column Index')
    plt.ylabel('Value')
    plt.grid(True)
    plt.tight_layout()
    plt.savefig('row_3_plot.png')  # file name kept for downstream compatibility
    plt.close(fig)

# Leading singleton axis kept: the saved array was previously built by stacking
# a one-element list, so its shape is (1, n_test, s).
all_predicted_u_next = predicted_u_next.cpu().numpy()[np.newaxis]
y_test_np = y_test.cpu().numpy()
file_name = 'prediction_and_ytest.mat'
scipy.io.savemat(file_name, {'predicted_u': all_predicted_u_next, 'y_test': y_test_np})

print(f'\n결과가 {file_name} 파일로 저장되었습니다.')