import torch
from einops import repeat

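# NOTE: this import is disabled; the commented-out loaders below
# (load_model_from_config, load_best_model) reference SimpleS4Model.
# from models.sequence.ss.s4_simple.network import SimpleS4Model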
from models.spacetime import get_encoder, get_decoder
from models.optimizer import get_optimizer, get_scheduler


# def load_model_from_config(n_layers, config, s4_version='discrete', 
#                            c_bias=False, debug_shift=False,
#                            encoder_decoder=True, verbose=True):
#     config.model.layer.update({
#         'nHippos': config.model.layer['n_hippos']
#     })
    
#     model = SimpleS4Model(n_layers, s4_version, c_bias, **config.model.layer)  # debug_shift, 
    
#     if encoder_decoder:
#         model.initialize(
#             get_encoder(input_dim=1, n_hippos=config.model.layer['n_hippos'], 
#                         n_layers=1, activation='none',
#                         initialize_identity=True, no_grad=True), # New
#             get_decoder(output_dim=1, n_hippos=config.model.layer['n_hippos'], 
#                         n_layers=1, activation='none')
#         )
#     else:
#         model.initialize(
#             torch.nn.Identity(),
#             torch.nn.Identity()
#         )
    
#     try:
#         optimizer = get_optimizer(model, config.optimizer)
#         scheduler = get_scheduler(model, optimizer, config.scheduler)
#     except:
#         optimizer = None; scheduler = None
    
#     if verbose:
#         print(f'-----')
#         print(f'Arch:')
#         print(f'-----')
#         print(model)

#         print(f'---------------------')
#         print(f'Trainable parameters:')
#         print(f'---------------------')
#         for n, p in model.named_parameters():
#             if p.requires_grad:
#                 print(n)
                
#     return model, optimizer, scheduler



# def load_best_model(_config, args, s4_version='discrete', c_shift=None, c_bias=False,
#                     shift=True, encoder_decoder=True, split='val'):
#     best_model = SimpleS4Model(n_layers=_config.model['n_layers'],
#                                s4_version=s4_version,
#                                c_bias=c_bias,                             
#                                **_config.model.layer)
#     # HACK (temporary): initialize the learnable c_shift term on each layer's SSM
#     if c_shift is not None:
#         for ix in range(len(best_model.nn)):
#             best_model.nn[ix].ssm.initialize_learnable_shift_weights(
#                 c_shift_parameters=c_shift,
#                 c_train=c_bias
#             )

#     if encoder_decoder:
#         best_model.initialize(
#             get_encoder(input_dim=1, n_hippos=_config.model.layer['n_hippos'], 
#                         n_layers=1, activation='none',
#                         initialize_identity=True, no_grad=True), # New
#             get_decoder(output_dim=1, n_hippos=_config.model.layer['n_hippos'], 
#                         n_layers=1, activation='none')
#         )
#     else:
#         best_model.initialize(
#             torch.nn.Identity(),
#             torch.nn.Identity()
#         )
    
#     best_model_path = args.best_val_checkpoint_path if split == 'val' else args.best_train_checkpoint_path
#     print(f'-> Loading best model from {best_model_path}')
#     best_model_dict = torch.load(best_model_path)
    
#     best_model.load_state_dict(best_model_dict['state_dict'])

#     for k, v in best_model_dict.items():
#         if 'state_dict' not in k and 'weight' not in k:
#             print(f'{k}: {v}')
    
#     return best_model



# def initialize_shift_weights(model, args, dataset=None):
#     for lix in range(len(model.nn)):
#         if args.ground_truth_b:
#             b_weights = torch.zeros(args.d_state).float()
#             b_weights[0] = 1.
#             b_weights = repeat(b_weights, 'd -> c h d', 
#                                c=args.channels, h=args.num_shift_kernel)
#         else:
#             b_weights = None
        
#         if args.ground_truth_c and dataset is not None:
#             factor = 1 if args.replicate in [2, 6] else -1
#             c_weights = torch.from_numpy(dataset.process.ar_params[1:] * factor)
#             c_weights = repeat(c_weights, 'd -> c h d', 
#                                c=args.channels, h=args.num_shift_kernel)
#         elif args.c_bias:
#             # c_weights = None
#             c_weights = torch.randn(args.channels,
#                                     args.num_shift_kernel,
#                                     args.d_state + 1)
#         else:
#             c_weights = None   # c_shift = torch.randn(*_fp_shift)

#         if args.ground_truth_p:
#             p_weights = torch.zeros(args.d_state)
#             p_weights = repeat(p_weights, 'd -> c h d', 
#                                c=args.channels, h=args.num_shift_kernel)
#         else:
#             p_weights = None
#         model.nn[lix].ssm.initialize_learnable_shift_weights(b_shift_parameters=b_weights,
#                                                              b_train=args.learn_b,
#                                                              c_shift_parameters=c_weights, 
#                                                              c_train=args.learn_c,
#                                                              p_companion_parameters=p_weights,
#                                                              p_train=args.learn_p)
    