"""
Script to compress a directory on a remote machine, copy the archive to the local machine, and decompress it there.
"""

import argparse
import os
import shlex
import subprocess
import sys
import time
from pathlib import Path


def run_remote_command(remote_host, remote_user, command, port=22):
    """
    Execute a command on the remote machine via SSH.

    ``command`` is passed to ssh as a single string, which ssh hands to the
    remote shell; callers are responsible for quoting any paths inside it.

    Args:
        remote_host (str): Remote machine hostname or IP
        remote_user (str): Username for remote machine
        command (str): Command to execute
        port (int): SSH port (default: 22)

    Returns:
        tuple: (success: bool, stdout: str, stderr: str). stdout/stderr are
        stripped of surrounding whitespace on both the success and failure
        paths. If the ssh executable cannot be started, success is False,
        stdout is '' and stderr holds the OS error message.
    """
    ssh_cmd = ['ssh']
    if port != 22:
        ssh_cmd.extend(['-p', str(port)])

    ssh_cmd.extend([f'{remote_user}@{remote_host}', command])

    try:
        result = subprocess.run(ssh_cmd, capture_output=True, text=True, check=True)
    except subprocess.CalledProcessError as e:
        # Strip here as well so callers get identically formatted output
        # regardless of whether the remote command succeeded.
        return False, (e.stdout or '').strip(), (e.stderr or '').strip()
    except OSError as e:
        # e.g. ssh not installed / not on PATH: report it as a failed command
        # rather than crashing, so callers' (success, ...) checks still apply.
        return False, '', str(e)
    return True, result.stdout.strip(), result.stderr.strip()


def compress_remote_directory(remote_host, remote_user, remote_dir, port=22,
                              compression='gzip'):
    """
    Compress a directory on the remote machine.

    The archive is written into the directory's parent and named
    ``<dirname without its last suffix>_<unix timestamp><ext>``.

    Args:
        remote_host (str): Remote machine hostname or IP
        remote_user (str): Username for remote machine
        remote_dir (str): Remote directory path to compress. Relative paths
            are resolved by the remote shell from its starting directory
            (normally the remote user's home for ssh); a leading ``~`` or
            ``~user`` is still expanded by the remote shell.
        port (int): SSH port (default: 22)
        compression (str): Compression type ('gzip', 'bzip2', 'xz')

    Returns:
        tuple: (success: bool, compressed_file_path: str or None). The path
        is None when the remote tar command failed.

    Raises:
        ValueError: If ``compression`` is not a supported type.
    """
    def _shell_path(path):
        # Quote a path for the remote shell. shlex.quote() alone would also
        # suppress tilde expansion, so a leading "~" / "~user" prefix is left
        # unquoted and only the remainder is quoted.
        head, sep, tail = path.partition('/')
        if head.startswith('~') and all(c.isalnum() or c in '._-' for c in head[1:]):
            return head + sep + (shlex.quote(tail) if tail else '')
        return shlex.quote(path)

    # Determine compression command and extension
    if compression == 'gzip':
        compress_cmd = 'tar -czf'
        ext = '.tar.gz'
    elif compression == 'bzip2':
        compress_cmd = 'tar -cjf'
        ext = '.tar.bz2'
    elif compression == 'xz':
        compress_cmd = 'tar -cJf'
        ext = '.tar.xz'
    else:
        raise ValueError(f"Unsupported compression type: {compression}")

    remote_path = Path(remote_dir)
    # Path.stem drops the final suffix (e.g. "run.v2" -> "run"), so the
    # archive is named after the bare directory name.
    compressed_filename = f"{remote_path.stem}_{int(time.time())}{ext}"
    compressed_path = f"{remote_path.parent}/{compressed_filename}"

    # The archive path comes before -C, so tar resolves it and the parent
    # directory from the same starting directory. Running "cd <parent>" first
    # would resolve a relative archive path against <parent> a second time.
    # Every path is quoted so spaces or shell metacharacters in names cannot
    # split arguments or inject commands.
    command = (
        f"{compress_cmd} {_shell_path(compressed_path)} "
        f"-C {_shell_path(str(remote_path.parent))} {_shell_path(remote_path.name)}"
    )

    print(f"Compressing {remote_dir} on remote machine...")
    success, stdout, stderr = run_remote_command(remote_host, remote_user, command, port)

    if success:
        print(f"Successfully compressed to {compressed_path}")
        return True, compressed_path
    print(f"Failed to compress directory: {stderr}")
    return False, None


def copy_compressed_file(remote_host, remote_user, remote_file, local_dir, port=22,
                         anchor='results'):
    """
    Copy the compressed file from remote to local machine using rsync.

    The directories between the first ``anchor`` component of ``remote_file``
    and the file name are recreated under ``local_dir``. For example, with
    the default anchor, ``/x/results/expA/run_1.tar.gz`` lands in
    ``<local_dir>/expA/``. If ``anchor`` does not appear in the path, the
    file is copied directly into ``local_dir``.

    Args:
        remote_host (str): Remote machine hostname or IP
        remote_user (str): Username for remote machine
        remote_file (str): Remote compressed file path
        local_dir (str): Local directory to copy to
        port (int): SSH port (default: 22)
        anchor (str): Path component after which the remote sub-directory
            structure is mirrored locally (default: 'results')

    Returns:
        tuple: (success: bool, local_file_path: str or None)
    """
    folders = remote_file.split('/')
    try:
        start = folders.index(anchor) + 1
    except ValueError:
        # Anchor not present in the path: no sub-directory to mirror.
        add_path = Path('')
    else:
        add_path = Path('/'.join(folders[start:-1]))

    # Ensure local directory exists
    local_path = Path(local_dir) / add_path
    local_path.mkdir(parents=True, exist_ok=True)

    # Build rsync command
    cmd = ['rsync', '-avz', '--progress']

    # Add SSH options for custom port
    if port != 22:
        cmd.extend(['-e', f'ssh -p {port}'])

    # Add source and destination
    source = f"{remote_user}@{remote_host}:{remote_file}"
    local_file_path = local_path / Path(remote_file).name
    cmd.extend([source, str(local_path)])

    print(f"Copying compressed file: {' '.join(cmd)}")

    try:
        subprocess.run(cmd, check=True)
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers rsync itself being missing / not executable.
        print(f"Error during rsync: {e}")
        return False, None
    print("File copied successfully!")
    return True, str(local_file_path)


def decompress_local_file(local_file_path, extract_dir):
    """
    Decompress the local file and extract to specified directory.

    Supported archive name endings: .tar.gz/.tgz, .tar.bz2/.tbz2/.tbz and
    .tar.xz/.txz. Extraction is delegated to the system ``tar``.

    Args:
        local_file_path (str): Path to compressed file
        extract_dir (str): Directory to extract to (created if missing)

    Returns:
        bool: True if successful, False otherwise (unsupported name, tar
        failure, or tar not available)
    """
    local_path = Path(local_file_path)
    extract_path = Path(extract_dir)
    extract_path.mkdir(parents=True, exist_ok=True)

    # Match the whole ending rather than "suffix == '.gz' and '.tar' in name",
    # which also accepted non-tar names such as "data.tarball.gz".
    flags_by_ending = (
        ('.tar.gz', '-xzf'), ('.tgz', '-xzf'),
        ('.tar.bz2', '-xjf'), ('.tbz2', '-xjf'), ('.tbz', '-xjf'),
        ('.tar.xz', '-xJf'), ('.txz', '-xJf'),
    )
    name = local_path.name
    flag = next((f for ending, f in flags_by_ending if name.endswith(ending)), None)
    if flag is None:
        print(f"Unsupported file format: {local_path.suffix}")
        return False

    cmd = ['tar', flag, str(local_path), '-C', str(extract_path)]

    print(f"Decompressing {local_file_path}...")
    try:
        subprocess.run(cmd, check=True)
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers the tar executable being missing / not runnable.
        print(f"Error during decompression: {e}")
        return False
    print(f"Successfully decompressed to {extract_dir}")
    return True


def cleanup_remote_file(remote_host, remote_user, remote_file, port=22):
    """
    Delete the compressed file from the remote machine.

    Args:
        remote_host (str): Remote machine hostname or IP
        remote_user (str): Username for remote machine
        remote_file (str): Remote file path to delete; a leading ``~`` or
            ``~user`` is still expanded by the remote shell
        port (int): SSH port (default: 22)

    Returns:
        bool: True if successful, False otherwise
    """
    head, sep, tail = remote_file.partition('/')
    if head.startswith('~') and all(c.isalnum() or c in '._-' for c in head[1:]):
        # shlex.quote() would suppress tilde expansion, so leave the
        # "~" / "~user" prefix unquoted and quote only the remainder.
        target = head + sep + (shlex.quote(tail) if tail else '')
    else:
        target = shlex.quote(remote_file)

    # Quoting plus "--" ensures a name containing spaces, glob characters or a
    # leading "-" is passed to rm as exactly one file operand.
    command = f"rm -f -- {target}"
    print(f"Cleaning up remote compressed file: {remote_file}")

    success, stdout, stderr = run_remote_command(remote_host, remote_user, command, port)

    if success:
        print("Remote compressed file cleaned up successfully")
        return True
    print(f"Warning: Failed to clean up remote file: {stderr}")
    return False


def main():
    """
    Command-line entry point.

    Compresses a remote directory, pulls the archive down with rsync, and
    unpacks it next to the downloaded file. Unless told to keep them, it then
    deletes the local archive and the remote archive.

    Exits with status 0 on success and 1 when any step fails. The remote
    archive is removed in a ``finally`` clause (unless --keep-remote), so it
    is cleaned up even when a later step exits early.
    """
    parser = argparse.ArgumentParser(description='Compress, copy, and decompress directory from remote machine')

    parser.add_argument('-ru', '--remote_user', help='Username for remote machine')
    parser.add_argument('-rd', '--remote_dir', help='Remote directory path to copy')
    parser.add_argument('-rh', '--remote_host', default='quest.northwestern.edu',
                        help='Remote machine hostname or IP address')
    parser.add_argument('-ld', '--local_dir', default='./quest_results',
                        help='Local destination directory')
    parser.add_argument('-p', '--port', type=int, default=22,
                        help='SSH port (default: 22)')
    parser.add_argument('-c', '--compression', choices=['gzip', 'bzip2', 'xz'],
                        default='gzip', help='Compression type (default: gzip)')
    parser.add_argument('--keep-compressed', action='store_true',
                        help='Keep the compressed file locally after extraction')
    parser.add_argument('--keep-remote', action='store_true',
                        help='Keep the compressed file on remote machine')
    opts = parser.parse_args()

    def abort(message):
        # Report a fatal problem and stop with exit status 1.
        print(message)
        sys.exit(1)

    required = (opts.remote_host, opts.remote_user, opts.remote_dir, opts.local_dir)
    if not all(required):
        abort("Error: All arguments (remote_host, remote_user, remote_dir, local_dir) are required")

    rule = "-" * 50
    print("Starting compress-copy-decompress operation...")
    print(f"Remote: {opts.remote_user}@{opts.remote_host}:{opts.remote_dir}")
    print(f"Local: {opts.local_dir}")
    print(f"Compression: {opts.compression}")
    print(rule)

    # Step 1: build the archive on the remote side.
    compressed_ok, remote_archive = compress_remote_directory(
        opts.remote_host, opts.remote_user, opts.remote_dir,
        opts.port, opts.compression,
    )
    if not compressed_ok:
        abort("Failed to compress remote directory")

    try:
        # Step 2: pull the archive down to this machine.
        copied_ok, local_archive = copy_compressed_file(
            opts.remote_host, opts.remote_user, remote_archive,
            opts.local_dir, opts.port,
        )
        if not copied_ok:
            abort("Failed to copy compressed file")

        # Step 3: unpack into the directory holding the downloaded archive.
        extract_dir = Path(local_archive).parent
        if not decompress_local_file(local_archive, extract_dir):
            abort("Failed to decompress file locally")

        # Step 4: drop the local archive unless asked to keep it.
        if not opts.keep_compressed:
            try:
                os.remove(local_archive)
                print(f"Removed local compressed file: {local_archive}")
            except OSError as err:
                print(f"Warning: Failed to remove local compressed file: {err}")

        print(rule)
        print(f"Successfully copied and extracted {opts.remote_dir} to {extract_dir}")
    finally:
        # Step 5: runs even after abort() above, because SystemExit still
        # passes through this finally clause.
        if not opts.keep_remote:
            cleanup_remote_file(opts.remote_host, opts.remote_user, remote_archive, opts.port)

    sys.exit(0)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()