import torch
from utils.YParams import YParams
from models import fno
from torchinfo import summary
from pdb import set_trace as bp

# Experiment configuration to inspect. The other config names listed here were
# previously loaded from the same YAML file; they are kept (disabled) so one of
# them can be swapped into CONFIG_NAME.
#   'poisson-scale-k1_5'
#   'poisson-scale-k1_5-k2.5_7.5'
#   'poisson-scale-k1_5-k10_20'
#   'poisson-scale-k1_5-k2.5_7.5-demo_7-baseline'
#   'poisson-scale-k1_5-k2.5_7.5-demo_7'
CONFIG_FILE = './config/operators_poisson.yaml'
CONFIG_NAME = 'poisson-scale-k1_5-k10_20-spatial-demo_7'
params = YParams(CONFIG_FILE, CONFIG_NAME)
# Dump the attribute dictionary of the params object for a quick sanity check.
print(vars(params))

# Optional manual overrides, currently disabled:
# params.embed_cut = 16
# params.mode_cut = 4

def count_parameters(model):
    """Return the number of trainable parameters of ``model``, in millions.

    Only parameters whose ``requires_grad`` flag is set are counted. The
    element total is divided by 1,000,000 before being returned, so the
    result is always a float.
    """
    trainable_elements = 0
    for tensor in model.parameters():
        # Frozen parameters (requires_grad=False) are excluded from the count.
        if tensor.requires_grad:
            trainable_elements += tensor.numel()
    return trainable_elements / 1000000

# Build the FNO model from the loaded config and switch it to eval mode in
# one step; `model` is bound to whatever `.eval()` returns.
model = fno.fno(params).eval()

# Report the trainable-parameter count (count_parameters returns millions).
n_params = count_parameters(model)
print('number of model parameters: {} M'.format(n_params))

# partial = torch.load("expts/poisson-scale-k1_5/test/checkpoints/ckpt_best.tar")['model_state']#, map_location=lambda storage, loc: storage)
# partial = torch.load("expts/poisson-scale-k1_5-k10_20-demo_7/bs16_lr1.25e-4_sub100_attn.XY.FC1_aug_16_v2/checkpoints/ckpt_best.tar")['model_state']#, map_location=lambda storage, loc: storage)
# state = model.backbone.state_dict()
# # 1. filter out unnecessary keys
# pretrained_dict = {k: v for k, v in partial.items() if k in state and state[k].size() == partial[k].size()}
# # 2. overwrite entries in the existing state dict
# state.update(pretrained_dict)
# # 3. load the new state dict
# message = model.backbone.load_state_dict(state)

# message = model.load_state_dict(partial)

# print(message)

# x = torch.load("input.pth")
# _ = model(x)

# Input shapes used for earlier summaries, kept (disabled) for reference:
# summary(model, input_size=(1, 4, 128, 128))
# summary(model, input_size=(1, 39, 128, 128))

# summary(model, input_size=(1, 4, 64, 64)) # baseline
# summary(model, input_size=(1, 7*(4+1)+4, 64, 64))

# FNO feature input.
# NOTE(review): the channel count follows the same 7*(C+1)+C pattern as the
# disabled variant above, with C=64 instead of C=4 -- presumably 7 demos of
# (C+1) channels plus C query channels; confirm against the data loader.
in_channels = 7 * (64 + 1) + 64
summary(model, input_size=(1, in_channels, 64, 64))
