# convert_pfg_data_pt
import glob
import numpy as np
import torch
from pathlib import Path
from tqdm import tqdm
import argparse

def _load_pfg_file(npz_path):
    """Load one ``*_pfg_only.npz`` file and return its samples as float32 tensors.

    The whole file is converted with array operations, so a file either
    contributes all of its samples or none of them.

    Args:
        npz_path (str or Path): Path to the .npz file.

    Returns:
        tuple or None: ``(spatial, nonspatial, target)`` tensors sharing the
        same leading sample dimension, or ``None`` if the file has zero samples.
        Expected shapes (per the producer's conventions, not verified here):
        spatial ``(N, C, H, W)``, nonspatial ``(N, 2)`` ordered
        ``(target_vec_y, target_vec_x)``, target ``(N, 1, H, W)``.

    Raises:
        KeyError: If a required array is missing from the file.
        ValueError: If the arrays disagree on the number of samples.
    """
    # NpzFile keeps the archive open until closed; the context manager
    # releases the handle as soon as the arrays have been read into memory.
    with np.load(npz_path) as data:
        spatial = data["pfg_spatial_input_tensors"].astype(np.float32)
        vec_y = data["pfg_nonspatial_target_vec_y"]
        vec_x = data["pfg_nonspatial_target_vec_x"]
        targets = data["pfg_target_potentials"].astype(np.float32)

    sample_counts = {
        "pfg_spatial_input_tensors": spatial.shape[0],
        "pfg_nonspatial_target_vec_y": vec_y.shape[0],
        "pfg_nonspatial_target_vec_x": vec_x.shape[0],
        "pfg_target_potentials": targets.shape[0],
    }
    if len(set(sample_counts.values())) != 1:
        # Misaligned arrays would pair inputs with the wrong targets.
        raise ValueError(f"inconsistent sample counts {sample_counts}")

    if spatial.shape[0] == 0:
        return None

    # Column order is (target_vec_y, target_vec_x).
    nonspatial = np.stack([vec_y, vec_x], axis=1).astype(np.float32)
    # Insert a channel axis right after the sample axis: (N, H, W) -> (N, 1, H, W).
    targets = targets[:, None, :, :]

    return (
        torch.from_numpy(spatial),
        torch.from_numpy(nonspatial),
        torch.from_numpy(targets),
    )


def convert_pfg_data(data_dirs, out_file):
    """
    Convert PFG datasets from multiple .npz files into a single .pt file.

    Every file matching ``*_pfg_only.npz`` in the given directories is loaded
    (directories in the order given, files sorted within each directory), and
    the samples are concatenated and saved with ``torch.save`` as the tuple
    ``(X_spatial, X_nonspatial, Y_target)`` of float32 tensors.

    The .npz files are expected to be generated by the
    prepare_agent_centric_dataset_pfg_only function. Only the spatial inputs,
    the two target-vector components and the target potentials are exported.

    Files that cannot be read, lack a required key, or contain arrays with
    inconsistent sample counts are skipped as a whole with a warning.
    If no files are found or no samples are collected, nothing is written.

    Args:
        data_dirs (list of str or Path): Directories containing the .npz PFG
            dataset files. A single str/Path is also accepted.
        out_file (str or Path): Path to save the output .pt file. Missing
            parent directories are created.

    Returns:
        None

    Raises:
        RuntimeError: Propagated from ``torch.cat`` if files disagree on the
            per-sample tensor shapes.
    """
    out_file = Path(out_file)

    # A bare string would otherwise be iterated character by character.
    if isinstance(data_dirs, (str, Path)):
        data_dirs = [data_dirs]
    else:
        data_dirs = list(data_dirs)

    # Collect all .npz file paths from all provided directories first.
    all_npz_paths = []
    for data_dir in data_dirs:
        data_dir = Path(data_dir)
        print(f"Searching for files in: {data_dir}")
        all_npz_paths.extend(sorted(data_dir.glob("*_pfg_only.npz")))

    if not all_npz_paths:
        print(f"No '*_pfg_only.npz' files found in any of the specified directories: {data_dirs}")
        return

    print(f"\nFound a total of {len(all_npz_paths)} .npz files to process from all directories.")

    spatial_chunks = []
    nonspatial_chunks = []
    target_chunks = []
    skipped_files = 0

    for npz_file_path in tqdm(all_npz_paths, desc="Processing PFG .npz files"):
        try:
            file_tensors = _load_pfg_file(npz_file_path)
        except KeyError as e:
            print(f"Warning: KeyError {e} in file {npz_file_path}. Skipping this file.")
            skipped_files += 1
            continue
        except Exception as e:
            # Deliberately broad: one unreadable file must not abort a long
            # batch conversion; the failure is reported and counted.
            print(f"Warning: Error processing file {npz_file_path}: {e}. Skipping this file.")
            skipped_files += 1
            continue

        if file_tensors is None:
            # File is valid but holds no samples.
            continue

        file_spatial, file_nonspatial, file_target = file_tensors
        spatial_chunks.append(file_spatial)
        nonspatial_chunks.append(file_nonspatial)
        target_chunks.append(file_target)

    if not spatial_chunks:
        print("No samples were processed. Output file will not be created.")
        return

    # Concatenate the per-file chunks along the sample dimension.
    X_spatial = torch.cat(spatial_chunks, dim=0)
    X_nonspatial = torch.cat(nonspatial_chunks, dim=0)
    Y_target = torch.cat(target_chunks, dim=0)

    total_samples = X_spatial.size(0)
    processed_files = len(all_npz_paths) - skipped_files
    print(f"\nProcessed {processed_files} of {len(all_npz_paths)} files ({skipped_files} skipped).")
    print(f"Total samples collected: {total_samples}")
    print(f"  Spatial input tensor shape: {X_spatial.shape}")
    print(f"  Non-spatial input tensor shape: {X_nonspatial.shape}")
    print(f"  Target potential tensor shape: {Y_target.shape}")

    # Save as a tuple of (X_spatial, X_nonspatial, Y_target); consumers rely
    # on this positional layout.
    out_file.parent.mkdir(parents=True, exist_ok=True)
    torch.save((X_spatial, X_nonspatial, Y_target), out_file)
    print(f"Saved combined preprocessed PFG dataset to {out_file}")

if __name__ == "__main__":
    # Command-line entry point: read the input directories and the output
    # path, then run the conversion.
    cli = argparse.ArgumentParser(
        description="Convert PFG .npz dataset from multiple directories to a single .pt file."
    )
    cli.add_argument(
        "--data_dirs",
        nargs='+',  # one or more directories may be listed after the flag
        type=str,
        required=True,
        help="One or more directories containing the input .npz PFG dataset files.",
    )
    cli.add_argument(
        "--out_file",
        default="pfg_dataset_combined.pt",
        type=str,
        help="Path to save the output .pt file.",
    )
    options = cli.parse_args()

    # args.data_dirs is always a list because of nargs='+'.
    convert_pfg_data(options.data_dirs, options.out_file)