import os
import shlex
import shutil
from typing import Union

import paramiko

def get_run_name(
    topo_scale: int
):
    """
    Build the run name for a given topo scale.

    The returned name is joined under a machine's ``checkpoints_dir`` by the
    download/upload helpers in this module.

    Parameters:
    topo_scale (int): 0 selects the baseline run; any positive value selects
        the run named with ``nesim_config_scale_{topo_scale}``.

    Returns:
    str: The run name.

    Raises:
    ValueError: If ``topo_scale`` is negative.
    """
    if topo_scale > 0:
        run_name = f"apply_nesim_every_n_steps_1_nesim_config_scale_{topo_scale}_shrink_factor_[9.0]_layer_names_all_layers_c_fc_checkpoint_every_n_steps_100_num_warmup_steps_3000_batch_size_16_context_length_1024"
    elif topo_scale == 0:
        # Baseline run. NOTE: it encodes batch_size_8 while scaled runs encode
        # batch_size_16 — presumably mirrors how the runs were launched; confirm.
        run_name = "apply_nesim_every_n_steps_1_nesim_config_baseline_shrink_factor_[9.0]_layer_names_all_layers_c_fc_checkpoint_every_n_steps_100_num_warmup_steps_3000_batch_size_8_context_length_1024"
    else:
        raise ValueError(
            f"Invalid topo_scale: {topo_scale}\n"
            "It should be a non-negative int: 0 for the baseline run, or a positive scale."
        )
    return run_name

    
def run_remote_command(hostname, port=22, username = "XXXX-1", command: str = "ls", password = None):
    """
    Executes a command on a remote machine via SSH and returns its standard output.

    Parameters:
    hostname (str): The IP address or hostname of the remote machine.
    port (int): The port number to connect to (default SSH port is 22).
    username (str): The username for SSH authentication.
    command (str): The command to be executed on the remote machine. It is
        handed to the remote side as a single string, so quote any paths in it.
    password (str | None): Password for SSH authentication. When None, only
        paramiko's other authentication methods (keys/agent per its
        ``connect`` defaults) are available.

    Returns:
    str: The output the command wrote to stdout (stderr is not returned).
    """
    client = paramiko.SSHClient()

    # NOTE(security): AutoAddPolicy accepts unknown host keys without verifying
    # them, which allows man-in-the-middle attacks. Consider
    # load_system_host_keys() together with RejectPolicy for untrusted networks.
    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())

    try:
        client.connect(hostname, port=port, username=username, password=password)
        _stdin, stdout, _stderr = client.exec_command(command)
        # read() with no size reads until EOF, i.e. until the remote command
        # closes its stdout.
        output = stdout.read().decode()
    finally:
        # Release the connection even if connect/exec/read raised.
        client.close()

    return output



def _reset_dir(path):
    """Remove ``path`` if present (directory tree, file or symlink) and recreate it as an empty directory."""
    if os.path.islink(path) or os.path.isfile(path):
        os.remove(path)
    elif os.path.isdir(path):
        shutil.rmtree(path)
    os.makedirs(path, exist_ok=True)


def download_checkpoint_from_remote(machines, machine_name, cache_dir, topo_scale, global_step, run = True):
    """
    Copy ``pytorch_model.bin`` for one run/step from a remote machine into a local directory.

    Parameters:
    machines (dict): Maps machine names to dicts with at least the keys
        ``"username"``, ``"hostname"`` and ``"checkpoints_dir"``.
    machine_name (str): Key of the source machine in ``machines``.
    cache_dir (str): Local directory to download into. It is wiped and
        recreated on every call, including when ``run`` is False.
    topo_scale (int): Passed to ``get_run_name`` to select the run directory.
    global_step (int): Selects the ``checkpoint-{global_step}`` subdirectory.
    run (bool): If True, run scp now; otherwise only return the scp command.

    Returns:
    str: When ``run`` is True, the expected local checkpoint path. It is
        returned even if scp failed, so check that it exists before use.
        When ``run`` is False, the shell command that would do the download.

    Raises:
    ValueError: If ``cache_dir`` is empty.
    """
    if not cache_dir:
        raise ValueError("cache_dir must be a non-empty path")

    machine_details = machines[machine_name]
    machine_ssh_thing = f"{machine_details['username']}@{machine_details['hostname']}"

    run_name = get_run_name(topo_scale=topo_scale)
    checkpoint_path_source = os.path.join(
        machine_details["checkpoints_dir"],
        run_name,
        f"checkpoint-{global_step}",
        "pytorch_model.bin"
    )

    # Reset the cache without a shell: an unquoted `rm -rf {cache_dir}` would
    # word-split paths containing spaces and could delete unrelated paths.
    # This also happens in dry-run mode, because the returned scp command
    # expects cache_dir to already exist as a directory.
    _reset_dir(cache_dir)

    # Quote both operands for the local shell spawned by os.system; run names
    # contain `[9.0]`, which is otherwise a glob pattern.
    remote_source = shlex.quote(f"{machine_ssh_thing}:{checkpoint_path_source}")
    command = f"scp -r {remote_source} {shlex.quote(cache_dir)}"

    if run:
        print("[DOWNLOADING]")
        exit_status = os.system(command)
        if exit_status != 0:
            # Keep the non-raising contract, but make the failure visible.
            print(f"[WARNING] scp exited with status {exit_status}: {command}")
        return os.path.join(cache_dir, "pytorch_model.bin")
    else:
        return command

def upload_checkpoint_to_remote(
    machines, machine_name, checkpoint_path, topo_scale, global_step, run = True, password = None
):
    """
    Upload a local checkpoint file to ``checkpoint-{global_step}`` of a run on a remote machine.

    The file is always stored remotely as ``pytorch_model.bin``, regardless of
    its local file name, so the returned remote path is where it actually lands.

    Parameters:
    machines (dict): Maps machine names to dicts with at least the keys
        ``"username"``, ``"hostname"`` and ``"checkpoints_dir"``.
    machine_name (str): Key of the destination machine in ``machines``.
    checkpoint_path (str): Local path of the checkpoint file to upload.
    topo_scale (int): Passed to ``get_run_name`` to select the run directory.
    global_step (int): Selects the ``checkpoint-{global_step}`` subdirectory.
    run (bool): If True, run scp now; otherwise only return the scp command.
        The remote directory is created over SSH in both cases.
    password (str | None): Used only for the SSH ``mkdir`` done through
        ``run_remote_command``; the scp step does not receive it.

    Returns:
    str: When ``run`` is True, the remote path of the uploaded file. It is
        returned even if scp failed. When ``run`` is False, the scp command.

    Raises:
    AssertionError: If ``checkpoint_path`` does not exist locally.
    """
    machine_details = machines[machine_name]
    run_name = get_run_name(topo_scale=topo_scale)
    machine_ssh_thing = f"{machine_details['username']}@{machine_details['hostname']}"

    # Explicit raise instead of `assert` so the check survives `python -O`;
    # AssertionError is kept so existing `except` clauses in callers still match.
    if not os.path.exists(checkpoint_path):
        raise AssertionError(
            "Invalid checkpoint_path. The most likely reason for this error is that you tried to download a checkpoint which does not exist in the first place, hence nothing was downloaded into the cache dir. Please ssh into the machine and check if it exists"
        )

    destination_dir = os.path.join(
        machine_details["checkpoints_dir"],
        run_name,
        f"checkpoint-{global_step}"
    )
    destination_path = os.path.join(destination_dir, "pytorch_model.bin")

    # The command string is executed on the remote side, so quote the path
    # (run names contain `[9.0]`, a glob pattern).
    run_remote_command(
        hostname=machine_details["hostname"],
        username=machine_details["username"],
        command=f"mkdir -p {shlex.quote(destination_dir)}",
        password=password
    )

    # Name the destination file explicitly so it matches the returned path.
    remote_target = shlex.quote(f"{machine_ssh_thing}:{destination_path}")
    command = f"scp {shlex.quote(checkpoint_path)} {remote_target}"

    if run:
        print("[UPLOADING]")
        exit_status = os.system(command)
        if exit_status != 0:
            # Keep the non-raising contract, but make the failure visible.
            print(f"[WARNING] scp exited with status {exit_status}: {command}")
        return destination_path
    else:
        return command