import os
import re
import pytz
import torch
import torch.nn as nn
from datetime import datetime


def get_rundir(stem, root="runs", digit_count=4):
    """Create a fresh, sequentially numbered run directory and return its name.

    Entries in ``root`` named ``<stem><index>`` (where ``<index>`` is exactly
    ``digit_count`` digits) are scanned, and the smallest unused index is
    claimed by creating ``root/<stem><index>``. A ``timestamp`` file holding
    the creation time (America/New_York) is written inside the new directory.

    Args:
        stem: Literal prefix of the run directory name.
        root: Directory holding all runs; created if it does not exist.
        digit_count: Zero-padded width of the numeric suffix.

    Returns:
        The name of the new directory, relative to ``root``.
    """
    # Without this, os.listdir raises on the very first run.
    os.makedirs(root, exist_ok=True)

    # Escape the stem so it is matched literally, and capture the suffix so
    # digits that are part of the stem are never mistaken for the index.
    pattern = re.compile(rf"{re.escape(stem)}(\d{{{digit_count}}})")
    used = set()
    for entry in os.listdir(root):
        match = pattern.fullmatch(entry)
        if match:
            used.add(int(match.group(1)))

    # Smallest non-negative index not taken yet.
    mex = 0
    while mex in used:
        mex += 1

    # Another process may claim the same index between the scan and the
    # mkdir, so keep advancing until the directory creation succeeds.
    while True:
        rundir = f"{stem}{mex:0{digit_count}}"
        try:
            os.makedirs(os.path.join(root, rundir))
        except FileExistsError:
            mex += 1
        else:
            break

    timestamp = datetime.now(pytz.timezone("America/New_York"))
    time_string = timestamp.strftime("%Y-%m-%d %H:%M:%S %Z\n")
    with open(os.path.join(root, rundir, "timestamp"), "w", encoding="utf-8") as f:
        f.write(time_string)
    return rundir

class RelativeMSELoss(nn.Module):
    """Mean squared error scaled element-wise by the squared target.

    Each squared residual is divided by ``max(y_true ** 2, 1)``, so the
    divisor is never smaller than one.
    """

    def __init__(self):
        super().__init__()

    def forward(self, y_pred, y_true):
        """Return the mean of the scaled squared residuals."""
        squared_residual = (y_pred - y_true).square()
        # Floor the divisor at one.
        scale = torch.maximum(y_true.square(), torch.ones_like(y_true))
        return (squared_residual / scale).mean()
