import argparse
import os

import boto3
import sagemaker
from dotenv import load_dotenv
from sagemaker.pytorch import PyTorch


def get_instance_gpus(instance_type: str) -> int:
    """Return the number of GPUs available on a SageMaker instance type.

    Args:
        instance_type: SageMaker instance type name, e.g. ``"ml.g5.12xlarge"``.

    Returns:
        The GPU count from the lookup table below, or 1 when the instance
        type is not listed.
    """
    gpus_dict = {
        "ml.p3.8xlarge": 4,
        "ml.p3.16xlarge": 8,
        "ml.p3dn.24xlarge": 8,
        "ml.p4d.24xlarge": 8,
        "ml.p4de.24xlarge": 8,
        "ml.p5.48xlarge": 8,
        "ml.p2.8xlarge": 8,
        # Was a second "ml.p2.8xlarge" key, which overwrote the entry above
        # with 16 and left the 16xlarge size unlisted.
        "ml.p2.16xlarge": 16,
        "ml.g4dn.12xlarge": 4,
        "ml.g5.12xlarge": 4,
        "ml.g5.24xlarge": 4,
        "ml.g5.48xlarge": 8,
    }

    # Instance types missing from the table are treated as single-GPU.
    return gpus_dict.get(instance_type, 1)


# Pull settings from a local .env file (if present) into the process environment.
load_dotenv()

# Fail early with a clear message: passing RoleName=None to boto3 otherwise
# surfaces as a parameter-validation error that never mentions the variable.
role_name = os.getenv("SAGEMAKER_IAM_ROLE")
if not role_name:
    raise SystemExit("SAGEMAKER_IAM_ROLE environment variable must be set to the name of the SageMaker IAM role.")

# Resolve the role name to the ARN handed to the estimator.
iam_client = boto3.client("iam")
role = iam_client.get_role(RoleName=role_name)["Role"]["Arn"]
sess = sagemaker.Session(default_bucket_prefix="sagemaker")

# Command-line interface. Arguments argparse does not recognise are not
# rejected; parse_known_args() returns them separately in ``env_args``.
parser = argparse.ArgumentParser()
# nargs="?" makes the positional optional. Without it argparse ignores
# ``default`` and insists on an explicit entry point on every invocation.
parser.add_argument("entry_point", nargs="?", type=str, default="run_deepspeed.py", help="Path to the script relative to source directory.")
parser.add_argument("--instance_type", "-i", type=str, default="ml.g5.12xlarge", help="AWS compute instance that is used.")
parser.add_argument("--base_job_name", "-n", type=str, default=None, help="Sagemaker unique job name")
parser.add_argument("--distributed", "-d", type=str, default=None, help="Whether to run the script with the torch.distributed framework")
parser.add_argument("--config_dir", "-c", type=str, default="gpt2", help="Config dir to use when running the script")
parser.add_argument("--wandb_run_id", "-w", type=str, default=None, help="WandB logging unique id")
args, env_args = parser.parse_known_args()

# Decide whether to launch through torch.distributed: an explicit
# --distributed value wins, otherwise it is enabled on multi-GPU instances.
if args.distributed is None:
    distributed = get_instance_gpus(args.instance_type) > 1
else:
    flag = args.distributed.strip().lower()
    # Any value containing "t" ("true", "t", ...) has always counted as enabled;
    # common affirmative spellings without a "t" are accepted as well so that
    # e.g. "-d yes" or "-d 1" no longer silently disables distribution.
    distributed = "t" in flag or flag in {"1", "y", "yes", "on"}

# Environment variables for the training container. Leftover command-line
# arguments of the form KEY=VALUE are added on top of the defaults.
environment = {"OMP_NUM_THREADS": "8", "DATA_DIR": "/opt", "CONFIG_DIR": args.config_dir}
for raw_arg in env_args:
    # Split on the first "=" only so values may themselves contain "=".
    key, separator, value = raw_arg.partition("=")
    if not separator or not key:
        parser.error(f"extra arguments must look like KEY=VALUE, got {raw_arg!r}")
    environment[key] = value

if args.wandb_run_id is not None:
    environment["WANDB_RUN_ID"] = args.wandb_run_id

# Fail early if the checkpoint bucket is not configured, instead of building
# the literal URI "s3://None/data" and submitting a job with it.
checkpoint_bucket = os.getenv("S3_BUCKET")
if not checkpoint_bucket:
    raise SystemExit("S3_BUCKET environment variable must be set to the bucket used for checkpoints.")

# Single-instance PyTorch estimator configured from the CLI arguments and
# the environment assembled above.
training_environment = PyTorch(
    entry_point=args.entry_point,
    source_dir=".",
    instance_type=args.instance_type,
    base_job_name=args.base_job_name,
    instance_count=1,
    framework_version="2.0.1",
    py_version="py310",
    distribution={"torch_distributed": {"enabled": distributed}},
    environment=environment,
    role=role,
    sagemaker_session=sess,
    disable_output_compression=True,
    output_path=f"s3://{sess.default_bucket()}/sagemaker",
    checkpoint_s3_uri=f"s3://{checkpoint_bucket}/data",
    checkpoint_local_path="/opt/data",
    disable_profiler=True,
    max_run=5 * 24 * 60 * 60,  # five days, expressed in seconds
)

# Submit the training job.
training_environment.fit()
