from dotenv import load_dotenv
load_dotenv()
import argparse
import os
from clearml import Dataset
import pandas as pd

# ClearML project searched for datasets named "<task>_sample_<sample_size>".
# Read at import time, after load_dotenv() has had a chance to populate
# os.environ; raises KeyError if the variable is not set.
# NOTE(review): this is required even when --dataset_id is passed, although the
# project name is only used for the name-based lookup — consider
# os.environ.get(...) plus an explicit check in that branch.
CLEARML_PROJECT_NAME = os.environ['CLEARML_PROJECT_NAME']

def get_parser():
    """Build the command-line parser for the dataset download script.

    Options:
        --task         dataset task name (default ``profession``)
        --sample_size  dataset sample size (default 250)
        --dataset_id   explicit ClearML dataset ID; per its help text, task
                       and sample_size are ignored when it is given
        --output_path  local directory for the saved copy
                       (default ``model_data/dataset_samples/``)

    Returns:
        argparse.ArgumentParser: the configured, not-yet-parsed parser.
    """
    # (flag, type, default, help) — a default of None matches argparse's own
    # default for options declared without one.
    option_specs = (
        ('--task', str, 'profession', "Dataset task name"),
        ('--sample_size', int, 250, "Dataset sample size"),
        (
            '--dataset_id',
            str,
            None,
            "ClearML dataset ID (if provided, task and sample_size are ignored)",
        ),
        (
            '--output_path',
            str,
            'model_data/dataset_samples/',
            "Local path to save the downloaded dataset",
        ),
    )

    cli = argparse.ArgumentParser(description="Download a dataset from ClearML")
    for flag, value_type, default_value, help_text in option_specs:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)
    return cli


if __name__ == '__main__':
    # Script-local imports: kept here rather than at the top of the file so the
    # existing load_dotenv()-before-imports ordering is left untouched.
    import shutil
    import sys

    parser = get_parser()
    args = parser.parse_args()

    # Resolve the dataset: an explicit --dataset_id wins; otherwise look up the
    # dataset named "<task>_sample_<sample_size>" in the project, accepting
    # only completed datasets.
    if args.dataset_id:
        print(f"Downloading dataset with ID: {args.dataset_id}")
        dataset = Dataset.get(dataset_id=args.dataset_id)
    else:
        dataset_name = f'{args.task}_sample_{args.sample_size}'
        print(f"Searching for dataset: {dataset_name} in project {CLEARML_PROJECT_NAME}")
        dataset = Dataset.get(
            dataset_project=CLEARML_PROJECT_NAME,
            dataset_name=dataset_name,
            only_completed=True,
            auto_create=False,
        )
        print(f"Found dataset with ID: {dataset.id}")

    # Reuse the already-resolved Dataset object rather than fetching it a
    # second time by ID.
    local_dataset_path = dataset.get_local_copy()
    print(f"Dataset downloaded to: {local_dataset_path}")

    # Only regular files at the top level of the local copy are considered.
    # Sorting makes the choice deterministic (os.listdir order is arbitrary).
    csv_files = sorted(
        name
        for name in os.listdir(local_dataset_path)
        if name.lower().endswith('.csv')
        and os.path.isfile(os.path.join(local_dataset_path, name))
    )
    if not csv_files:
        print("No CSV files found in the downloaded dataset", file=sys.stderr)
        sys.exit(1)
    if len(csv_files) > 1:
        print(f"Found {len(csv_files)} CSV files, using {csv_files[0]}: {csv_files}")

    csv_path = os.path.join(local_dataset_path, csv_files[0])

    # Load the dataset to verify it parses and show a preview.
    df = pd.read_csv(csv_path)
    print(f"Successfully loaded dataset with {len(df)} samples")
    print(f"Sample of data:\n{df.head()}")

    # Copy the original file byte-for-byte. Re-serialising with df.to_csv would
    # round-trip through pandas' type inference and can alter the data
    # (e.g. drop leading zeros from numeric-looking IDs, reformat floats, or
    # write an unnamed index column back with an "Unnamed: 0" header).
    os.makedirs(args.output_path, exist_ok=True)
    output_file = os.path.join(args.output_path, csv_files[0])
    shutil.copyfile(csv_path, output_file)
    print(f"Dataset saved to: {output_file}")
