#!/usr/bin/env python3
import subprocess
import sys

# Exact names of the multi-step evaluation directories to download.
# Every name has the form
#   root_aligned_steeredtrue_uspto_hard_reaction_type_guided_guidance<G>_length<L>_<timestamp>
# so only the varying parts (guidance, length, timestamp) are listed.
# Guidance values are kept as strings so "1.0" and "2" are rendered verbatim.
multi_step_evaluations = [
    f"root_aligned_steeredtrue_uspto_hard_reaction_type_guided_guidance{guidance}_length{length}_{timestamp}"
    for guidance, length, timestamp in (
        ("1.5", 10, "20250819_160803"),
        ("1.5", 10, "20250820_113606"),
        ("1.5", 12, "20250820_113611"),
        ("1.5", 15, "20250820_113617"),
        ("2", 5, "20250820_113627"),
        ("2", 10, "20250820_113633"),
        ("2", 12, "20250820_113638"),
        ("2", 15, "20250820_113643"),
        ("1.0", 12, "20250820_113548"),
        ("1.0", 12, "20250820_113542"),
        ("1.0", 10, "20250820_113534"),
        ("1.0", 5, "20250820_113601"),
        ("1.0", 5, "20250820_113554"),
        ("0.5", 15, "20250820_113509"),
        ("0.5", 12, "20250820_113502"),
        ("0.5", 10, "20250820_113443"),
    )
]

# Remote directory patterns for the single-step experiments. A pattern that
# contains `*` is listed on the remote host first and every match is
# downloaded; a pattern without `*` is treated as an exact directory name.
# Only the entries of this list are active. To fetch an earlier batch again,
# move its pattern from the history below into the list.
single_step_patterns = [
    'reaction_type_rsmiles_uspto_190_fraction1.0_thresh500_reaction*_time20250923_103438',
]

# Patterns used for earlier downloads, in the order they were added to this
# file (inactive, kept for reference only):
#   single_step_datauspto_190first_reactions_reaction*
#   single_step_datauspto_190first_reactions_similarity_reaction*_guidance30_length*
#   neuralsym_uspto_190first_reactions_similarity_reaction*_steeredtrue_guidance0_length10_results100_time20250912_173155
#   single_step_datauspto_190first_reactions_similarity_reaction*_steeredtrue_guidance30_length10_*
#   rootaligned_USPTO_full_no_overlap_uspto_190first_reactions_similarity_reaction*_steeredfalse_guidance0_length10_results100_time20250918_152637
#   rootaligned_rsmiles_uspto_190_fraction0.4_thresh10_dropped_PtoR_aug5_uspto_*
#   rootaligned_rsmiles_uspto_190_fraction0.4_thresh10_dropped_PtoR_aug5_uspto_190first_reactions_nonli*
#   neuralsym_fraction0.4_thresh10_uspto_190first_reactions_nonlinear_*
#   rootaligned_fraction1.0_thresh500_uspto_190first_reactions_nonlinear_with_targets_similarity_reaction*_steeredfalse_guidance0_length0_results100_time20250921_184336
#   rootaligned_fraction1.0_thresh500_uspto_190first_reactions_with_targets_similarity_reaction*_steeredfalse_guidance0_length0_results100_time20250921_194210
#   rootaligned_fraction0.4_thresh10_uspto_190_nonlinear_similarity_reaction*_steeredtrue_guidance30_length10_results100_time20250921_215333
#   rootaligned_fraction1.0_thresh500_uspto_190_nonlinear_similarity_reaction*_steeredtrue_guidance30_length10_results100_time20250922_103915
#   rootaligned_fraction1.0_thresh500_uspto_190_nonlinear_reaction_type_reaction*_steeredtrue_guidance0.3_length15_results100_time20250922_180831
# Both transfer endpoints end in the same experiment sub-directory, so it is
# spelled out once and appended to each root.
_experiments_subdir = "multiguide/experiments/reaction_type"
# Remote root used earlier: /scratch/project_462000833/multiguide/experiments
base_remote = f"/scratch/project_2007775/{_experiments_subdir}"
local_dest = f"/Users/laabidn1/{_experiments_subdir}"

def _list_remote_matches(pattern):
    """Return the remote paths under ``base_remote`` that match ``pattern``.

    The pattern is expanded by the remote shell via ``ls -d`` over ssh. When
    nothing matches, the remote command prints the ``NO_MATCH`` sentinel
    instead, which is filtered out so the result is an empty list.

    Raises:
        subprocess.CalledProcessError: if the ssh command exits non-zero.
    """
    find_cmd = ["ssh", "puhti", f"ls -d {base_remote}/{pattern} 2>/dev/null || echo 'NO_MATCH'"]
    result = subprocess.run(find_cmd, capture_output=True, text=True, check=True)
    matches = []
    for line in result.stdout.strip().split('\n'):
        candidate = line.strip()
        if candidate and candidate != 'NO_MATCH':
            matches.append(candidate)
    return matches


def _scp_directory(remote_path, name, progress, indent=""):
    """Copy one remote directory into ``local_dest`` with ``scp -r``.

    Args:
        remote_path: scp source in ``host:path`` form.
        name: Directory name shown in the progress messages.
        progress: Counter text shown in brackets, e.g. ``"3/16"``.
        indent: Prefix for every message (used for nested listings).

    Returns:
        True if scp exited with status 0, False otherwise.
    """
    print(f"{indent}[{progress}] Downloading {name}...")
    try:
        subprocess.run(["scp", "-r", remote_path, f"{local_dest}/"], check=True)
    except subprocess.CalledProcessError:
        print(f"{indent}✗ Failed to download {name}")
        return False
    print(f"{indent}✓ Successfully downloaded {name}")
    return True


def download_directories(directories, description):
    """Download a list of remote directories (or wildcard patterns) via scp.

    Entries containing ``*`` are first resolved on the remote host and every
    matching directory is downloaded; other entries are copied as exact
    directory names under ``base_remote``.

    Args:
        directories: Sequence of directory names or ``*`` patterns.
        description: Label used in the start-up message.

    Returns:
        A ``(successful, failed)`` tuple. Each downloaded directory counts
        once; a pattern with no matches or a failed remote listing counts as
        one failure.

    Raises:
        SystemExit: with status 1 if the user interrupts the run (Ctrl-C).
    """
    total = len(directories)
    print(f"Starting download of {total} {description} directories...")
    successful = 0
    failed = 0

    # The handler must enclose the loop: the subprocess calls inside it are
    # where a Ctrl-C actually arrives.
    try:
        for i, directory in enumerate(directories, 1):
            if '*' not in directory:
                # Exact directory name: copy it directly.
                if _scp_directory(f"puhti:{base_remote}/{directory}", directory, f"{i}/{total}"):
                    successful += 1
                else:
                    failed += 1
                continue

            # Wildcard pattern: resolve it on the remote host first.
            print(f"[{i}/{total}] Finding matches for {directory}...")
            try:
                matches = _list_remote_matches(directory)
            except subprocess.CalledProcessError:
                print(f"✗ Failed to find matches for {directory}")
                failed += 1
                continue

            if not matches:
                print(f"✗ No matches found for pattern {directory}")
                failed += 1
                continue

            for j, match in enumerate(matches, 1):
                # Show only the final path component in progress messages.
                dir_name = match.split('/')[-1]
                if _scp_directory(f"puhti:{match}", dir_name, f"{j}/{len(matches)}", indent="  "):
                    successful += 1
                else:
                    failed += 1
    except KeyboardInterrupt:
        print("\nDownload interrupted by user")
        sys.exit(1)

    return successful, failed

def main():
    """Prompt for an experiment group, download it, and print the totals.

    Choice "1" downloads the multi-step directories, "2" the single-step
    patterns, and "3" both (multi-step first, separated by a divider line).
    Any other input prints an error and exits with status 1.
    """
    for menu_line in (
        "Choose which experiments to download:",
        "1. Multi-step evaluations only",
        "2. Single-step evaluations only",
        "3. Both multi-step and single-step evaluations",
    ):
        print(menu_line)

    choice = input("Enter your choice (1-3): ").strip()

    # Map each menu choice to the ordered list of download jobs it triggers.
    multi_step_job = (multi_step_evaluations, "multi-step")
    single_step_job = (single_step_patterns, "single-step")
    jobs_by_choice = {
        "1": [multi_step_job],
        "2": [single_step_job],
        "3": [multi_step_job, single_step_job],
    }

    if choice not in jobs_by_choice:
        print("Invalid choice. Exiting.")
        sys.exit(1)

    total_successful = 0
    total_failed = 0
    for position, (targets, label) in enumerate(jobs_by_choice[choice]):
        if position > 0:
            # Divider between consecutive download batches.
            print("\n" + "="*50)
        ok_count, failed_count = download_directories(targets, label)
        total_successful += ok_count
        total_failed += failed_count

    print("\nAll downloads complete!")
    print(f"Total successful: {total_successful}")
    print(f"Total failed: {total_failed}")

# Start the interactive downloader only when this file is executed as a
# script; importing it as a module must not prompt or start transfers.
if __name__ == "__main__":
    main()