import argparse
import torch
from models import MODEL_REGISTRY  # Import model registry
from utils import *  # Import all utilities
from train_baseline import process_traj, get_data, test_rollout
import matplotlib.pyplot as plt

def custum_mse_2(pred, target):
    """Return the mean squared error between ``pred`` and ``target``.

    Both tensors are reshaped to two dimensions whose trailing size is
    ``target.shape[-1]``; the squared element-wise differences are then
    averaged over every element.

    Args:
        pred: Predicted tensor; must be reshapeable to ``(-1, target.shape[-1])``.
        target: Reference tensor.

    Returns:
        A scalar tensor holding the mean of the squared differences.
    """
    width = target.shape[-1]
    flat_pred = pred.reshape(-1, width)
    flat_target = target.reshape(-1, width)
    squared_error = (flat_pred - flat_target).pow(2)
    return squared_error.mean()

def custum_mse(pred, target, scale=125):
    """Return the batch-averaged L2 norm of ``pred - target`` divided by ``scale``.

    Despite its name this is not a mean *squared* error: for every entry
    along dimension 0 the Euclidean norm of the difference is taken over all
    remaining dimensions, the resulting norms are averaged, and the average
    is divided by ``scale``.

    Args:
        pred: Predicted tensor, broadcast-compatible with ``target``.
        target: Reference tensor; its rank determines which dimensions
            (``1 .. target.dim() - 1``) are reduced by the norm.
        scale: Divisor applied to the averaged norm. Defaults to 125, the
            value that used to be hard-coded here.
            NOTE(review): the origin of 125 is not documented in this file.

    Returns:
        A scalar tensor with the scaled, averaged norm.
    """
    reduce_dims = tuple(range(1, target.dim()))
    # torch.norm is deprecated; torch.linalg.vector_norm computes the same
    # vector 2-norm over the requested dimensions.
    per_sample_norm = torch.linalg.vector_norm(pred - target, ord=2, dim=reduce_dims)
    return per_sample_norm.mean() / scale



def main(config: Config) -> None:
    """Evaluate a saved model's rollout error for rollout lengths 1 to 32.

    Builds the model registered under ``config.model``, loads its weights,
    loads the validation trajectories from ``config.valid_path`` and, for
    every ``k`` in ``1..32``, calls ``test_rollout`` with ``num_steps=k`` and
    scores the result against the validation trajectories with
    ``custum_mse``. Index 0 along dimension 1 is excluded from the
    comparison (presumably the initial state -- confirm against
    ``test_rollout``).

    Prints ``config.name``, ``config.hidden_channels`` and the list of
    losses (one per ``k``).

    Args:
        config: Experiment configuration. Reads ``device``, ``model``,
            ``valid_path``, ``time_jump``, ``residual``, ``norm_res``,
            ``name`` and ``hidden_channels``, plus whichever fields match
            the model constructor's parameters.
    """
    import inspect  # local import so this block does not depend on file-level edits

    device = config.device

    # NOTE(review): placeholder -- must be set to the checkpoint file before
    # running; torch.load cannot open an empty path.
    path_weights = ""

    model_class = MODEL_REGISTRY[config.model]
    # Forward only the config fields that are real constructor parameters.
    # ``__code__.co_varnames`` also lists the constructor's *local* variables,
    # so a config field sharing a name with one of them would be passed as an
    # unexpected keyword argument; the signature contains parameters only.
    init_params = inspect.signature(model_class.__init__).parameters
    model_kwargs = {
        key: value for key, value in config.__dict__.items() if key in init_params
    }
    model = model_class(**model_kwargs).to(device)

    # map_location lets a checkpoint saved on another device (e.g. a GPU)
    # be loaded onto the device this run is configured for.
    model.load_state_dict(torch.load(path_weights, map_location=device))
    # Evaluation only: switch layers such as dropout / batch norm to
    # inference behaviour.
    model.eval()

    valid_data = get_data(config.valid_path)
    valid_traj = process_traj(valid_data, time_jump=config.time_jump).to(device)

    loss_fn = custum_mse
    loss_traj_k = []

    for k in range(1, 33):
        pred_traj = test_rollout(
            model,
            valid_traj,
            num_steps=k,
            residual=config.residual,
            norm_res=config.norm_res,
        )
        step_loss = loss_fn(pred_traj[:, 1:, :], valid_traj[:, 1:, :])
        loss_traj_k.append(step_loss.item())

    print(config.name)
    print(config.hidden_channels)
    print(loss_traj_k)



if __name__ == '__main__':
    # Command-line interface for the evaluation script.
    parser = argparse.ArgumentParser()
    # Path to the YAML configuration file. Defaults to the empty string,
    # which is the path that used to be hard-coded below.
    parser.add_argument('--config', type=str, default="",
                        help='Path to the YAML configuration file.')
    parser.add_argument('--name', type=str, help='Model name.')
    parser.add_argument('--run_id', type=str, help='Run id of the experiment.')

    # Parse command line arguments
    args = parser.parse_args()

    # Load configuration from the YAML file
    config = Config.from_yaml(args.config)

    # Override configuration values with command line arguments only when
    # they were actually provided, so an omitted flag no longer replaces the
    # YAML value with None. A field the config lacks entirely is still
    # created (as None), so later attribute reads keep working.
    for field_name in ("run_id", "name"):
        cli_value = getattr(args, field_name)
        if cli_value is not None or not hasattr(config, field_name):
            setattr(config, field_name, cli_value)

    # Run the evaluation with the configured settings
    main(config)