import os
import yaml
import torch
import argparse
import numpy as np
from pathlib import Path
from src.model import create_mamba_model
from src.eval import test_model
from src.data.dataload import load_and_split_data
from src.data.datagen import *


def parse_args():
    """Build the testing CLI and parse ``sys.argv``.

    Options:
        --config (str, required): YAML configuration file to read.
        --checkpoint_dir (str, required): Directory holding the model checkpoints.
        --test_gpu (int, default 0): Index of the GPU to evaluate on.
        --data_path (str, default None): Directory of a pre-existing dataset;
            when omitted, test data is generated or loaded from cache.

    Returns:
        argparse.Namespace: Namespace with attributes ``config``,
        ``checkpoint_dir``, ``test_gpu`` and ``data_path``.
    """
    option_specs = (
        ('--config', dict(type=str, required=True, help='Path to the config file')),
        ('--checkpoint_dir', dict(type=str, required=True, help='Directory containing model checkpoints')),
        ('--test_gpu', dict(type=int, default=0, help='GPU ID to use for testing')),
        ('--data_path', dict(type=str, default=None, help='Path to existing dataset (optional)')),
    )
    cli = argparse.ArgumentParser(description='Test the model')
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()

def load_config(config_path):
    """Read a YAML configuration file and return its parsed contents.

    The file is parsed with ``yaml.safe_load``, so only plain YAML types are
    constructed (no arbitrary Python objects).

    Args:
        config_path (str): Location of the YAML file.

    Returns:
        dict: The parsed configuration. This is whatever ``yaml.safe_load``
        yields; it is expected to be a mapping.
    """
    with open(config_path, 'r') as config_file:
        return yaml.safe_load(config_file)

def load_existing_data(data_path):
    """Load ``data.npy`` and ``timestamps.npy`` from a directory as tensors.

    Both arrays are read with ``np.load`` and wrapped with
    ``torch.from_numpy``, so each tensor shares memory with its NumPy array.
    The loaded shapes are printed for quick inspection.

    Args:
        data_path (str): Directory expected to contain ``data.npy`` and
            ``timestamps.npy``.

    Returns:
        tuple: ``(data, timestamps)`` as ``torch.Tensor`` objects.

    Raises:
        FileNotFoundError: If either of the two files is missing.
    """
    print(f"Loading data from {data_path}...")
    data_file, timestamps_file = (
        os.path.join(data_path, file_name) for file_name in ('data.npy', 'timestamps.npy')
    )
    if any(not os.path.exists(path) for path in (data_file, timestamps_file)):
        raise FileNotFoundError(f"Required data files (data.npy, timestamps.npy) not found in {data_path}")

    # Load data first, then timestamps (same order as the returned tuple).
    data, timestamps = (torch.from_numpy(np.load(path)) for path in (data_file, timestamps_file))

    print("Successfully loaded data:")
    for label, tensor in (("data", data), ("timestamps", timestamps)):
        print(f"- {label} shape: {tensor.shape}")
    return data, timestamps

def generate_or_load_data(config, data_path=None):
    """Generates new test data or loads existing data based on configuration.

    Resolution order:
      1. If ``data_path`` is given, delegate to ``load_existing_data(data_path)``.
      2. Otherwise, if
         ``<config["data"]["save_dir"]>/<model_name>_<data_type>/test_data.pt``
         exists, return its ``'data'`` and ``'timestamps'`` entries.
      3. Otherwise, generate ``test_size`` sequences of length
         ``input_len + output_len`` with the generator registered for
         ``config["data"]["type"]``. Cache them to that file and return them.

    ``test_size`` is ``num_samples * (1 - train_ratio - val_ratio)``,
    truncated to an int.

    Args:
        config (dict): Configuration dictionary. Reads ``save.model_name``,
            ``data.type``, ``data.save_dir``, ``data.num_samples``,
            ``input_len``, ``output_len``, ``train_ratio`` and ``val_ratio``.
        data_path (str, optional): Path to existing data directory. Defaults to None.

    Returns:
        tuple: A pair of torch.Tensor containing:
            - data: Generated or loaded data tensor
            - timestamps: Corresponding timestamps tensor

    Raises:
        ValueError: If the train/val ratios leave no test samples, or if the
            specified data type is not supported.
    """
    if data_path is not None:
        return load_existing_data(data_path)

    model_name = config["save"]["model_name"]
    data_type = config["data"]["type"]
    save_dir = os.path.join(config["data"]["save_dir"], f"{model_name}_{data_type}")
    data_file = os.path.join(save_dir, 'test_data.pt')
    if os.path.exists(data_file):
        print(f"Loading existing test data from {data_file}...")
        saved_data = torch.load(data_file)
        return saved_data['data'], saved_data['timestamps']

    print(f"Generating new {data_type} test data...")
    total_len = config["input_len"] + config["output_len"]
    num_samples = config["data"]["num_samples"]
    test_ratio = 1 - config["train_ratio"] - config["val_ratio"]
    # round(..., 9) strips float noise before truncating. Plain int() turns e.g.
    # 100 * (1 - 0.8 - 0.1) == 9.999999999999995 into 9 instead of 10.
    # Genuine fractional parts are still truncated, as int() did before.
    test_size = int(round(num_samples * test_ratio, 9))
    if test_size <= 0:
        raise ValueError(
            f"No test samples: num_samples={num_samples}, "
            f"train_ratio={config['train_ratio']}, val_ratio={config['val_ratio']} "
            f"leave a test size of {test_size}"
        )

    # Generator functions come from `src.data.datagen` (star-imported at file top).
    # NOTE(review): "gbm" maps to generate_gbm_data_multi_two, while "gbm_multi"
    # maps to generate_gbm_data_multi. Confirm this pairing is intended.
    data_generators = {
        "gbm": generate_gbm_data_multi_two,
        "sde": generate_sde_data,
        "gbm_multi": generate_gbm_data_multi,
        "two_phase_gbm": generate_two_phase_gbm,
        "generate_gbm_data_three": generate_gbm_data_three,
        "generate_three_peak_switching_sde_data": generate_three_peak_switching_sde_data,
        "generate_hybrid_gbm_ou_data": generate_hybrid_gbm_ou_data,
        "generate_hybrid_gbm_cir_data": generate_hybrid_gbm_cir_data,
        "generate_hybrid_gbm_ou_cir_data": generate_hybrid_gbm_ou_cir_data,
        "generate_hybrid_ou_gbm_data": generate_hybrid_ou_gbm_data,
        "generate_hybrid_ou_cir_data": generate_hybrid_ou_cir_data,
        "generate_hybrid_ou_gbm_cir_data": generate_hybrid_ou_gbm_cir_data,
        "generate_hybrid_cir_ou_gbm_data": generate_hybrid_cir_ou_gbm_data,
        "generate_hybrid_cir_ou_data": generate_hybrid_cir_ou_data,
        "generate_hybrid_cir_gbm_data": generate_hybrid_cir_gbm_data,
    }
    if data_type not in data_generators:
        raise ValueError(f"Unsupported data type: {data_type}")

    data, timestamps = data_generators[data_type](
        num_samples=test_size,
        seq_len=total_len,
    )
    # Cache the freshly generated data so later runs take the load path above.
    os.makedirs(save_dir, exist_ok=True)
    torch.save({
        'data': data,
        'timestamps': timestamps,
    }, data_file)
    return data, timestamps

def create_test_data_loader(test_data, test_timestamps, input_len, output_len, d_inner, batch_size):
    """Wrap test tensors in a ``TimeSeriesDataset`` and return a DataLoader over it.

    All sequence/model parameters are forwarded unchanged to
    ``src.data.dataset.TimeSeriesDataset``. The loader is built with
    ``shuffle=False``, so batches follow dataset order.

    Args:
        test_data (torch.Tensor): Test data tensor.
        test_timestamps (torch.Tensor): Test timestamps tensor.
        input_len (int): Length of the input window.
        output_len (int): Length of the output window.
        d_inner (int): Inner dimension size forwarded to the dataset.
        batch_size (int): Number of samples per batch.

    Returns:
        torch.utils.data.DataLoader: Non-shuffling loader over the test dataset.
    """
    # Imported lazily, as in the original, so the module is only needed when a loader is built.
    from src.data.dataset import TimeSeriesDataset

    return torch.utils.data.DataLoader(
        TimeSeriesDataset(test_data, test_timestamps, input_len, output_len, d_inner),
        batch_size=batch_size,
        shuffle=False,
    )

def main():
    """Main function for model testing.

    Steps:
    1. Parse command line arguments.
    2. Load the YAML configuration. Store ``--test_gpu`` under
       ``config['test_gpu']`` and ``--checkpoint_dir`` under
       ``config['model_save_path']``.
    3. Create the model with ``create_mamba_model`` from config hyperparameters.
    4. Load data from ``--data_path``, or generate/load cached test data.
    5. Build a non-shuffling test DataLoader.
    6. Run ``test_model(model, test_loader, config)``.
    7. Report where evaluation results are expected to be saved.
    """
    args = parse_args()
    config = load_config(args.config)
    # CLI overrides reach test_model through the config dict (its only settings input here).
    config['test_gpu'] = args.test_gpu
    config['model_save_path'] = args.checkpoint_dir
    model = create_mamba_model(
        input_dim=config["input_dim"],
        output_dim=config["output_dim"],
        input_len=config["input_len"],
        output_len=config["output_len"],
        d_model=config["d_model"],
        n_layer=config["n_layer"]
    )
    # Assumes the returned model exposes its hyperparameters as `model.args`.
    d_inner = model.args.d_inner
    data, timestamps = generate_or_load_data(config, args.data_path)
    # Same `is not None` test as generate_or_load_data. Both then agree on whether the
    # data came from --data_path, including the edge case of an empty string.
    if args.data_path is not None:
        test_data, test_timestamps = data, timestamps
    else:
        # NOTE(review): generate_or_load_data already sizes generated data to the test
        # fraction, num_samples * (1 - train_ratio - val_ratio). Splitting it again here
        # keeps only that fraction of the test fraction. Confirm whether the
        # generated/cached test data should be used whole.
        _, _, test_data, _, _, test_timestamps = load_and_split_data(
            data, timestamps, config["train_ratio"], config["val_ratio"]
        )
    test_loader = create_test_data_loader(
        test_data, test_timestamps,
        config["input_len"], config["output_len"],
        d_inner, config["batch_size"]
    )
    test_model(model, test_loader, config)
    # Path reported here is assumed to match where test_model writes results. Confirm in src.eval.
    print(f"Test finished. Results saved to {os.path.join(args.checkpoint_dir, 'eval_results')}")

if __name__ == "__main__":
    main()
