import argparse
import threading
import time

import torch


# Countdown timer: prints the remaining MM:SS once per second, then a final message.
def countdown(t):
    """
    Print a live ``MM:SS`` countdown, one tick per second, then a final message.

    Minutes are not wrapped into hours, so one day is shown as ``1440:00``.

    Parameters:
    t (int | float): Number of seconds to count down. Fractional values are
        truncated; zero or negative values skip straight to the final message
        instead of looping forever.
    """
    remaining = int(t)
    while remaining > 0:
        mins, secs = divmod(remaining, 60)
        # "\r" returns the cursor so each tick overwrites the previous one.
        print("{:02d}:{:02d}".format(mins, secs), end="\r")
        time.sleep(1)
        remaining -= 1

    print("Fire in the hole!!")


def occupy_gpu_memory(gb_to_occupy, cuda="1"):
    """
    Occupies approximately `gb_to_occupy` GB of GPU memory.

    The tensor is created directly on the target device, so no host-side
    copy of the buffer is ever materialized.

    Parameters:
    gb_to_occupy (float): The amount of memory to occupy in gigabytes
        (1 GB = 1e9 bytes). Must be non-negative.
    cuda (str | int | torch.device): CUDA device index (e.g. "1" or 1), or an
        explicit ``torch.device`` to allocate on.

    Returns:
    torch.Tensor: A 1-D float32 tensor of zeros. The memory stays occupied
        only as long as a reference to this tensor is kept.

    Raises:
    ValueError: If `gb_to_occupy` is negative.
    """
    if gb_to_occupy < 0:
        raise ValueError(f"gb_to_occupy must be non-negative, got {gb_to_occupy}")

    if isinstance(cuda, torch.device):
        device, label = cuda, str(cuda)
    else:
        device, label = torch.device(f"cuda:{cuda}"), f"CUDA:{cuda}"

    # float32 takes 4 bytes, and we multiply by 1e9 to convert gigabytes to bytes.
    num_elements = int(gb_to_occupy * (1e9 / 4))

    # Allocate directly on the device: `torch.zeros(n).to(device)` would first
    # build the whole buffer in host RAM and then copy it over.
    tensor = torch.zeros(num_elements, dtype=torch.float32, device=device)
    print(f"Allocated ~{gb_to_occupy} GB of GPU memory on {label}.")
    return tensor


def occupy_gpu_computation(tensor, duration_sec=60, size=5000):
    """
    Performs heavy computation on the GPU for a specified duration while keeping the GPU memory occupied.

    Repeatedly multiplies two random ``size x size`` matrices on the same
    device as `tensor` until `duration_sec` seconds have elapsed. Each product
    is written into one preallocated buffer, so the loop does not request new
    memory per iteration.

    Parameters:
    tensor (torch.Tensor): The tensor occupying GPU memory; only its device is used.
    duration_sec (float): Duration in seconds to perform computation.
    size (int): Side length of the square matrices. Larger values do more work
        per kernel launch; adjust based on your GPU capability.
    """
    print(f"Starting computation for {duration_sec} seconds...")
    # Monotonic clock: immune to wall-clock adjustments during multi-day runs.
    start_time = time.monotonic()
    device = tensor.device
    matrix1 = torch.rand(size, size, device=device)
    matrix2 = torch.rand(size, size, device=device)
    result = torch.empty(size, size, device=device)

    while (time.monotonic() - start_time) < duration_sec:
        torch.matmul(matrix1, matrix2, out=result)

    # CUDA kernel launches are asynchronous; wait for queued work so the
    # completion message is only printed once the GPU is actually done.
    if device.type == "cuda":
        torch.cuda.synchronize(device)
    print("Completed computation.")


def free_gpu_memory(tensor):
    """
    Drops a reference to tensor(s) and releases PyTorch's cached CUDA blocks.

    Note: ``del`` only removes this function's own reference. The memory is
    actually released only once no other reference (e.g. a list entry or a
    variable in the caller) keeps the tensor alive. Passing a list clears
    that list in place, so the list no longer pins its tensors.

    Parameters:
    tensor (torch.Tensor | list[torch.Tensor]): The tensor to free, or a list
        of tensors, which is emptied in place.
    """
    if isinstance(tensor, list):
        tensor.clear()
    del tensor
    # Hand memory no longer referenced by any tensor back to the driver.
    torch.cuda.empty_cache()
    print("Freed the GPU memory occupied by the tensor.")


def main():
    """
    Parse CLI arguments, occupy the requested GPUs, then release them.

    GPUs ``--start`` through ``--end`` (inclusive; ``--end`` defaults to
    ``--start``) each get ``--gb`` GB of memory and a busy matmul loop for
    ``--dur`` days. All GPUs are driven concurrently, one thread each, while
    the main thread shows the countdown. The total runtime is therefore about
    ``--dur`` regardless of how many GPUs are used.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", default="sst5")  # accepted but unused here
    parser.add_argument("--start", type=int, default=0, help="cuda number")
    parser.add_argument(
        "--end",
        type=int,
        default=None,
        help="last cuda number (inclusive); defaults to --start",
    )
    parser.add_argument("--gb", type=float, default=20, help="GB to occupy per GPU")
    parser.add_argument("--dur", type=float, default=1, help="days")
    args = parser.parse_args()

    if args.end is None:
        args.end = args.start
    if args.end < args.start:
        parser.error("--end must be >= --start")

    tensors = [
        occupy_gpu_memory(args.gb, str(num))
        for num in range(args.start, args.end + 1)
    ]

    dur_sec = int(args.dur * 24 * 60 * 60)

    # One worker per GPU so they are busy at the same time rather than one
    # after another. Daemon threads let Ctrl+C end the script.
    workers = [
        threading.Thread(
            target=occupy_gpu_computation, args=(tensor, dur_sec), daemon=True
        )
        for tensor in tensors
    ]
    for worker in workers:
        worker.start()
    countdown(dur_sec)
    for worker in workers:
        worker.join()

    # Pop each tensor out of the list first: while the list still holds it,
    # free_gpu_memory cannot actually release the memory.
    while tensors:
        free_gpu_memory(tensors.pop())
    torch.cuda.empty_cache()


if __name__ == "__main__":
    main()
