import argparse
import json
import numpy as np
import pandas as pd  # Add pandas import for CSV handling

from torchvision import transforms

from data_loading.data_loading import get_data
from sklearn.model_selection import train_test_split

from src.utils import mkdir


def main(args):
    """Load a dataset, split it into train/val/test, and save the splits to disk.

    The training portion returned by ``get_data`` is flattened to a 2-D
    feature matrix, re-split into train/validation/test subsets with a fixed
    seed (42), and written to ``args.result_dir`` as ``X_num_{split}.npy`` /
    ``y_{split}.npy`` for ``split`` in train/val/test, together with an
    ``info.json`` metadata file. When ``args.save_as_csv`` is set, the
    training split is additionally written to ``{args.dataset}.csv``.

    Args:
        args: Parsed CLI namespace. Reads ``dataset``, ``result_dir``,
            ``test_split``, ``validation_split`` and ``save_as_csv``; the
            whole namespace is also forwarded to ``get_data``.
    """
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            # Per-channel (x - 0.5) / 0.5.
            transforms.Normalize((0.5,), (0.5,)),
        ]
    )

    # Only the training portion returned by get_data is used; its test
    # portion is discarded and a fresh test split is carved out below.
    train_data, _ = get_data(args, transform)
    # NOTE(review): `.tensors` suggests a torch TensorDataset — confirm.
    X, y = train_data.tensors[0].numpy(), train_data.tensors[1].numpy()

    # Take the first entry of the last axis as the integer label.
    # NOTE(review): assumes y is 3-D, e.g. (n_samples, 1, m) — confirm
    # against get_data.
    y = y[:, :, 0].reshape(-1).astype(int)

    # Flatten to one feature row per label. Unlike np.squeeze, this always
    # yields a 2-D (n_rows, n_features) matrix: it keeps the row axis when
    # there is a single sample and flattens multi-dimensional features,
    # which the CSV export and `n_num_features` below rely on.
    X = X.reshape(len(y), -1)

    print(f"X shape: {X.shape}")
    print(f"y shape: {y.shape}")
    print(f"X samples: {X[:5]}")
    print(f"y samples: {y[:5]}")

    # Two-stage split: the test set is taken from the full data first, then
    # the validation set is taken from what remains. The validation share of
    # the full data is therefore validation_split * (1 - test_split).
    X_temp, X_test, y_temp, y_test = train_test_split(
        X, y, test_size=args.test_split, random_state=42
    )

    X_train, X_val, y_train, y_val = train_test_split(
        X_temp, y_temp, test_size=args.validation_split, random_state=42
    )

    mkdir(args.result_dir)
    # Save every split as .npy files (same file names and write order as before).
    splits = {
        "train": (X_train, y_train),
        "val": (X_val, y_val),
        "test": (X_test, y_test),
    }
    for split_name, (X_split, y_split) in splits.items():
        np.save(f"{args.result_dir}/X_num_{split_name}.npy", X_split)
        np.save(f"{args.result_dir}/y_{split_name}.npy", y_split)

    if args.save_as_csv:
        # NOTE(review): only the training split is exported as CSV although
        # the --save_as_csv help text speaks of "datasets" — confirm intent.
        feature_columns = [f"feature_{i}" for i in range(X_train.shape[1])]

        # hstack promotes the integer labels to the feature dtype.
        train_data_combined = pd.DataFrame(
            np.hstack((X_train, y_train.reshape(-1, 1))),
            columns=feature_columns + ["target"],
        )
        train_data_combined.to_csv(f"{args.result_dir}/{args.dataset}.csv", index=False)

    format_msg = ".npy" if not args.save_as_csv else ".npy and .csv"
    print(f"Data saved to {args.result_dir} in {format_msg} format")

    dataset_metadata = {
        # NOTE(review): y is cast to int above, which suggests a classification
        # target — confirm that "regression" is the intended task type.
        "task_type": "regression",
        "name": args.dataset,
        "id": f"{args.dataset}--id",
        "train_size": len(X_train),
        "val_size": len(X_val),
        "test_size": len(X_test),
        # Every column is treated as numerical; no categorical features.
        "n_num_features": X_train.shape[1],
        "n_cat_features": 0,
    }

    print(X_train.shape)

    # Save metadata as JSON.
    with open(f"{args.result_dir}/info.json", "w", encoding="utf-8") as f:
        json.dump(dataset_metadata, f, indent=4)


if __name__ == "__main__":

    # CLI entry point: parse options, normalise them, and run main().
    parser = argparse.ArgumentParser(
        description="Split a dataset into train/val/test sets and save them to disk."
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="cifar10",
        help="Dataset to store.",
    )
    parser.add_argument(
        "--result_dir",
        type=str,
        default="data/",
        help="Directory to save results",
    )
    parser.add_argument(
        "--test_split",
        type=float,
        default=0.2,
        help="Proportion of data to use for testing.",
    )
    parser.add_argument(
        "--validation_split",
        type=float,
        default=0.1,
        # main() applies this to the data left after the test split.
        help="Proportion of the remaining (non-test) data to use for validation.",
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        default=10000,
        help="Number of samples to generate.",
    )
    parser.add_argument(
        "--save_as_csv",
        action="store_true",
        help="Save datasets as CSV files in addition to .npy files.",
    )
    args = parser.parse_args()
    # Drop trailing slashes: main() builds paths as f"{result_dir}/<file>".
    args.result_dir = args.result_dir.rstrip("/")
    # NOTE(review): main() does not read this directly; it is presumably
    # consumed by get_data (args is forwarded there) — confirm.
    args.sample_size = 1

    main(args)
