import argparse
import os
import torch
import json

from models import MODEL_REGISTRY
from train import *
from utils import *

def load_model(model_path, config, device):
    """Build the model selected by ``config.model`` and load its trained weights.

    Constructor keyword arguments are taken from ``config`` attributes whose
    names match a parameter of the model class's ``__init__``.

    Args:
        model_path: Directory containing ``model.pt`` (a state dict) and,
            optionally, ``scaler.json``.
        config: Configuration object; ``config.model`` is the key looked up in
            ``MODEL_REGISTRY``.
        device: Device the model is moved to and the checkpoint is mapped onto.

    Returns:
        Tuple ``(model, scaler)``. ``scaler`` is a ``MaxScaler`` loaded from
        ``scaler.json`` when that file exists, otherwise ``None``.
    """
    import inspect

    model_class = MODEL_REGISTRY[config.model]

    # Forward only config values that are genuine constructor parameters.
    # ``__init__.__code__.co_varnames`` would also list __init__'s local
    # variables, and passing one of those as a keyword raises TypeError.
    signature = inspect.signature(model_class.__init__)
    init_params = {
        name
        for name, param in signature.parameters.items()
        if name != 'self'
        and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
    }
    model_kwargs = {k: v for k, v in config.__dict__.items() if k in init_params}
    model = model_class(**model_kwargs).to(device)

    # map_location lets a checkpoint saved on a GPU be loaded on whatever
    # device was requested (including 'cpu' on a machine without CUDA).
    state_dict = torch.load(os.path.join(model_path, 'model.pt'), map_location=device)
    model.load_state_dict(state_dict)

    # The scaler is optional: it is only loaded when the file is present.
    scaler = None
    scaler_path = os.path.join(model_path, 'scaler.json')
    if os.path.exists(scaler_path):
        scaler = MaxScaler()
        scaler.load(scaler_path)

    return model, scaler

def valid_onestep(dataloader, model, loss_fn):
    """Return the mean one-step loss of ``model`` over ``dataloader``.

    Each batch ``(x, y)`` is scored as ``loss_fn(model(x), y)`` with the model
    in eval mode and autograd disabled (``torch.inference_mode``). The result
    is the unweighted mean of the per-batch losses.

    NOTE(review): because batches are weighted equally, a smaller final batch
    counts as much as a full one — confirm this is the intended metric.

    Args:
        dataloader: Iterable yielding ``(x, y)`` pairs; it does not need to
            support ``len()``.
        model: Module called as ``model(x)``.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar tensor.

    Returns:
        The mean per-batch loss as a Python float, or ``float('nan')`` when
        the dataloader yields no batches.
    """
    model.eval()
    total_loss = 0.0
    n_batches = 0
    with torch.inference_mode():
        for x, y in dataloader:
            total_loss += loss_fn(model(x), y).item()
            n_batches += 1
    if n_batches == 0:
        # Avoid ZeroDivisionError on an empty loader; NaN signals "no data".
        return float('nan')
    return total_loss / n_batches

def valid_rollout(traj, model, loss_fn, residual, model_path, scaler=None,
                  *, residual_scale=0.3, time_per_step=0.2 * 16):
    """Autoregressively roll ``model`` out and score the predicted trajectories.

    Step 0 of every prediction is copied from ``traj``. Each later step is
    produced by feeding the model its own previous prediction, so errors
    accumulate over the rollout.

    Args:
        traj: Ground-truth trajectories indexed as ``traj[batch, step, feature]``
            (the slicing below requires a rank-3 tensor).
        model: Module mapping a ``(batch, 1, feature)`` slice to the next step.
        loss_fn: Callable ``loss_fn(pred, target)``; applied to steps 1..end.
        residual: If true, the model output is an increment that is scaled by
            ``residual_scale`` and added to the previous state. Otherwise it is
            taken as the next state directly.
        model_path: Unused; kept so existing callers keep working.
        scaler: Optional object whose ``inverse_transform`` is applied to every
            raw model output before it is used.
        residual_scale: Multiplier for the model increment in residual mode.
        time_per_step: Factor converting the mean threshold index into the
            reported unit. The default ``0.2 * 16`` is presumably time-step
            size times ``time_jump`` — TODO confirm.

    Returns:
        ``(rollout_corr_08, rollout_corr_09, loss)``. The first two are the mean
        result of ``find_corr_threshold_index`` at correlation thresholds
        0.8 and 0.9, multiplied by ``time_per_step``. ``loss`` is the rollout
        loss as a Python float.
    """
    model.eval()
    with torch.inference_mode():
        # zeros_like already allocates on traj's device.
        pred = torch.zeros_like(traj)
        pred[:, 0:1, :] = traj[:, 0:1, :]
        for i in range(1, traj.shape[1]):
            x = pred[:, i-1:i, :].clone()
            pred_step = model(x)
            if scaler is not None:
                pred_step = scaler.inverse_transform(pred_step)
            if residual:
                pred[:, i:i+1, :] = pred[:, i-1:i, :] + pred_step * residual_scale
            else:
                pred[:, i:i+1, :] = pred_step

        # Step 0 is a copy of the target, so it is excluded from the loss.
        loss = loss_fn(pred[:, 1:, :], traj[:, 1:, :]).item()

        # One correlation value per (trajectory, step) row, presumably taken
        # across the feature dimension by batched_corrcoef.
        pred_flat = pred.reshape(-1, pred.shape[-1])
        traj_flat = traj.reshape(-1, traj.shape[-1])
        corr = batched_corrcoef(pred_flat, traj_flat).reshape(traj.shape[0], traj.shape[1])

        # find_corr_threshold_index presumably returns, per trajectory, the step
        # at which correlation falls below the threshold — confirm in utils.
        rollout_corr_08 = find_corr_threshold_index(corr, 0.8).mean() * time_per_step
        rollout_corr_09 = find_corr_threshold_index(corr, 0.9).mean() * time_per_step

    return rollout_corr_08, rollout_corr_09, loss


def main(config: "Config", model_path: str):
    """Evaluate a trained model on ``config.test_path`` and print rollout metrics.

    Args:
        config: Loaded configuration object (built by ``Config.from_yaml`` in
            the script entry point). The fields read here are ``device``,
            ``test_path``, ``time_jump``, ``residual``, ``norm_res``, ``scale``
            and ``batch_size``; ``load_model`` reads more.
        model_path: Directory holding ``model.pt`` and, optionally,
            ``scaler.json``.

    Raises:
        FileNotFoundError: If ``config.scale`` is set but ``model_path`` has no
            saved scaler to apply.
    """
    device = f'cuda:{config.device}' if torch.cuda.is_available() else 'cpu'
    print(f'Using {device}.')

    valid_data = get_data(config.test_path)
    valid_data_input, valid_data_output = process_data(
        valid_data,
        time_jump=config.time_jump,
        residual=config.residual,
        norm_res=config.norm_res,
    )

    model, scaler = load_model(model_path, config, device)
    if config.scale and scaler is None:
        # Fail with a clear message instead of an AttributeError on None.
        raise FileNotFoundError(
            f"config.scale is set but no scaler.json was found in {model_path!r}"
        )
    if config.scale:
        valid_data_output = scaler.transform(valid_data_output)

    valid_data_input = valid_data_input.to(device)
    valid_data_output = valid_data_output.to(device)
    # NOTE(review): this one-step dataloader is built but not consumed below;
    # pass it to valid_onestep to also report a one-step loss.
    valid_dataloader = torch.utils.data.DataLoader(
        Dataset(valid_data_input, valid_data_output),
        batch_size=config.batch_size,
        shuffle=False,
    )

    valid_traj = process_traj(valid_data, time_jump=config.time_jump).to(device)

    loss_fn = torch.nn.MSELoss()

    # Only undo scaling in the rollout when targets were actually scaled, so a
    # stray scaler.json cannot distort an unscaled model's outputs.
    rollout_scaler = scaler if config.scale else None
    c_08, c_09, loss_t = valid_rollout(
        valid_traj, model, loss_fn, config.residual, model_path, rollout_scaler
    )
    print(f"Corr 08: {c_08}")
    print(f"Corr 09: {c_09}")
    print(f"Loss: {loss_t}")


if __name__ == '__main__':
    # Command-line entry point: load the YAML config, optionally override the
    # CUDA device index, then evaluate the model stored in --model_path.
    parser = argparse.ArgumentParser(
        description='Evaluate a trained model with an autoregressive rollout.'
    )
    parser.add_argument('--config', type=str, required=True,
                        help='Path to the YAML config file.')
    parser.add_argument('--device', type=int,
                        help='CUDA device index; overrides config.device when given.')
    parser.add_argument('--model_path', type=str, required=True,
                        help='Directory containing model.pt (and optionally scaler.json).')
    args = parser.parse_args()

    config = Config.from_yaml(args.config)
    if args.device is not None:
        config.device = args.device

    main(config, args.model_path)