import os
import sys
import torch
import torch.nn as nn
import math
import random
import time
from config import *
from system import *

# Parse run-time options via the project's config helper.
args = get_config()

# Pick the GPU only if CUDA is available AND the config asked for it;
# every other combination falls back to the CPU.
device = torch.device(
    'cuda' if torch.cuda.is_available() and args.device == 'cuda' else 'cpu'
)


# PINN module to generate the solution space.
class genPINN(nn.Module):
    """Fully connected tanh MLP mapping (x, t) samples to one scalar each.

    Layout: Linear(spatial_dim + temporal_dim -> nn_architecture[0]), then one
    Linear between each pair of consecutive widths in ``nn_architecture``
    (every hidden Linear is followed by tanh), and a final
    Linear(nn_architecture[-1] -> 1) with no activation.

    Args:
        nn_architecture: Sequence of hidden-layer widths; must be non-empty.
        spatial_dim: Number of spatial input features per sample.
        temporal_dim: Number of temporal input features per sample.

    Raises:
        ValueError: If ``nn_architecture`` is empty.
    """

    def __init__(self, nn_architecture, spatial_dim, temporal_dim):
        super().__init__()

        widths = list(nn_architecture)
        if not widths:
            raise ValueError("nn_architecture must contain at least one hidden width")

        self.start_layer = nn.Linear(spatial_dim + temporal_dim, widths[0])
        # One Linear for each pair of consecutive hidden widths.
        self.adaptive_layers = nn.ModuleList([
            nn.Linear(in_w, out_w) for in_w, out_w in zip(widths[:-1], widths[1:])
        ])
        self.end_layer = nn.Linear(widths[-1], 1)

        self.tanh = nn.Tanh()

    def forward(self, x, t):
        """Evaluate the network on paired spatial/temporal inputs.

        Args:
            x: Spatial inputs, shape (N, spatial_dim) or (N,) for a single feature.
            t: Temporal inputs, shape (N, temporal_dim) or (N,) for a single feature.

        Returns:
            Tensor of shape (N,) with one network output per sample.
        """
        # Inputs are cast with .float() (float32) before the Linear layers.
        x = x.float()
        t = t.float()
        # Accept 1-D inputs by treating them as a single feature column.
        if x.dim() == 1:
            x = x.unsqueeze(1)
        if t.dim() == 1:
            t = t.unsqueeze(1)

        features = torch.cat([x, t], dim=1)
        hidden = self.tanh(self.start_layer(features))
        for layer in self.adaptive_layers:
            hidden = self.tanh(layer(hidden))

        # The last Linear has out_features=1; drop that axis to return (N,).
        return self.end_layer(hidden).squeeze(dim=1)


# Build a genPINN whose depth and hidden widths are sampled at random
# (originally used for the 1D cdr equation experiments).
def load_random_PINN(xdim, tdim, *, min_depth=10, max_depth=15,
                     min_width=30, max_width=60, rng=None):
    """Create a genPINN with a randomly sampled hidden-layer architecture.

    The depth is drawn with ``randint(min_depth, max_depth)``. Each hidden
    width is drawn with ``randint(min_width, max_width)``. Both bounds are
    inclusive. The sampled architecture is printed to stdout.

    Args:
        xdim: Number of spatial input features.
        tdim: Number of temporal input features.
        min_depth: Inclusive lower bound on the number of hidden layers.
        max_depth: Inclusive upper bound on the number of hidden layers.
        min_width: Inclusive lower bound on each hidden layer's width.
        max_width: Inclusive upper bound on each hidden layer's width.
        rng: Optional object with a ``randint`` method (e.g. ``random.Random``)
            for reproducible sampling. Defaults to the module-level ``random``
            functions, which reproduces the original call sequence.

    Returns:
        A newly constructed genPINN.
    """
    sampler = random if rng is None else rng
    depth = sampler.randint(min_depth, max_depth)
    random_arch = [sampler.randint(min_width, max_width) for _ in range(depth)]
    net = genPINN(random_arch, xdim, tdim)
    print("=================================================")
    print("Architecture:", random_arch)
    print("=================================================")
    return net