
from libs import *
from libs.fno import FourierNeuralOperator

# Environment setup via the project helpers from `libs`, with a fixed seed value.
get_system()
get_seed(1127802)

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


# Passed to the dataset as `return_grad` and used to size the model's input channels.
ADD_GRAD_CHANNEL = True
# Number of input channels requested from the dataset (`channel=`).
N_CHANNEL = 1
# Dataset part indices (`part_idx=`) used for the validation split.
PARTS = [4, 5, 6]

# Validation split of the EIT data: parts PARTS, file_type 'h5', subsampled by 2,
# with gradient channels controlled by ADD_GRAD_CHANNEL (precomputed, not online).
valid_dataset = EITDataset(
    part_idx=PARTS,
    file_type='h5',
    channel=N_CHANNEL,
    return_grad=ADD_GRAD_CHANNEL,
    subsample=2,
    train_data=False,
    online_grad=False,
)

# Fixed order and keep the last partial batch so every sample is evaluated.
valid_loader = DataLoader(
    valid_dataset,
    batch_size=20,
    shuffle=False,
    drop_last=False,
    pin_memory=True,
)

# Pull a single batch and print tensor sizes as a quick shape sanity check.
sample = next(iter(valid_loader))
phi, gradphi, targets = (sample[key] for key in ('phi', 'gradphi', 'targets'))
print(phi.size(), gradphi.size(), targets.size())

# FNO hyperparameters; these must match the architecture stored in the checkpoint below.
config = dict(
    # 1 channel per input channel, plus 2 gradient channels each when ADD_GRAD_CHANNEL.
    in_dim=N_CHANNEL * (1 + 2 * ADD_GRAD_CHANNEL),
    n_hidden=48,  # dmodel of the input for spectral conv
    freq_dim=48,  # number of frequency features
    out_dim=1,
    modes=14,  # number of fourier modes
    num_spectral_layers=6,
    n_grid=101,
    dim_feedforward=None,
    spacial_dim=2,
    spacial_fc=True,
    return_freq=True,  # to be consistent with trainer
    normalizer=None,
    activation='silu',
    last_activation=False,
    # Follow the module-level switch so in_dim, the dataset and the model stay in sync.
    add_grad_channel=ADD_GRAD_CHANNEL,
    dropout=0.05,
    debug=False,
)

# Training leftovers: not used anywhere below in this evaluation script.
epochs = 50
lr = 1e-3

# Mesh size passed to the metrics below.
# NOTE(review): the evaluated grid is 101 points (subsample=2, n_grid=101);
# confirm whether h should be 1/101 here rather than the training-grid 1/201.
h = 1/201

model = FourierNeuralOperator(**config)
# Load onto the CPU first so a checkpoint saved on a GPU also loads on CPU-only hosts;
# the model is moved to `device` afterwards.
state_dict = torch.load(os.path.join(MODEL_PATH, 'fno2d-base.pt'), map_location='cpu')
model.load_state_dict(state_dict)
model.to(device)
# Inference only: disable dropout so validation metrics are deterministic.
model.eval()


# Metrics reported on the validation set, keyed by display name
# (insertion order: cross entropy, relative L2, dice).
metric_funcs = {}
metric_funcs["cross entropy"] = CrossEntropyLoss2d(regularizer=False, h=h)
metric_funcs["relative L2"] = L2Loss2d(regularizer=False, h=h)
metric_funcs["dice"] = SoftDiceLoss()

# Run one pass over the validation loader and print the aggregated result.
val_result = validate_epoch_eit(model, metric_funcs, valid_loader, device)
print(val_result)

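"""
FNO2d baseline: 1 input channel + explicit gradient channels.
Trained grid size: (201, 201).

11m model, 4 layers, spatial hidden = freq hidden = 48, 14 modes.
(NOTE: the config above uses num_spectral_layers=6; confirm which matches the checkpoint.)
Best validation cross entropy: 1.796e-01 at epoch 20.

Test cross entropy on (101, 101): 3.913e-01.
"""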

