import re
import os
import os.path as pt
import torch


class NanGradientsError(RuntimeError):
    """``RuntimeError`` subclass reserved for NaN-gradient failures.

    It adds no behaviour of its own. It is presumably raised by training code
    when gradients become NaN; it is not raised in this section, so confirm
    with the callers.
    """


def lst_of_lsts(n: int):
    """Return a list holding ``n`` empty lists.

    Each inner list is a separate object, so appending to one does not
    affect the others. A non-positive ``n`` yields ``[]``.
    """
    outer = []
    for _ in range(n):
        outer.append([])
    return outer


def weight_reset(m: torch.nn.Module):
    """Call ``m.reset_parameters()`` if ``m`` has a callable attribute of that name.

    Objects without such an attribute (or whose attribute is not callable)
    are left untouched. Returns ``None``.
    """
    if not callable(getattr(m, "reset_parameters", None)):
        return
    m.reset_parameters()


def int_set_to_str(intset: set[int]) -> str:
    """Encode a set of integers as an ascending, ``+``-separated string.

    Sorting makes the encoding canonical, so equal sets always give the
    same string. For example, ``{3, 1, 2}`` -> ``'1+2+3'``; the empty set
    gives ``''``.
    """
    ordered = sorted(intset)
    return '+'.join(map(str, ordered))


def str_to_int_set(intset_str: str) -> set[int]:
    """Decode a ``+``-separated string of integers into a set.

    This is the inverse of ``int_set_to_str``; for example ``'1+2+3'`` ->
    ``{1, 2, 3}``. Piece order does not matter, so legacy unsorted strings
    such as ``'3+1'`` decode fine.

    An empty string decodes to an empty set. This keeps the round trip
    working for the empty set: joining no items with ``'+'`` gives ``''``.

    :raises ValueError: if any ``+``-separated piece is not an integer literal.
    """
    if not intset_str:
        # int('') would raise; '' is the encoding of set().
        return set()
    return {int(piece) for piece in intset_str.split('+')}


def find_permutation_snapshot(snapshot: str, requested_intset: set[int]) -> str:
    """Resolve ``snapshot`` to a legacy file whose integer set is written in another order.

    Snapshot basenames embed an integer set as a ``+``-joined string right
    after a ``cls`` marker and before an underscore (e.g. ``..._cls1+2_...``).
    Legacy snapshots wrote that string unsorted, so a requested path built
    from the sorted form may not exist even though the same set was saved
    as, say, ``..._cls2+1_...``.

    If ``snapshot`` does not exist, this scans its directory for a file whose
    ``cls<ints>_`` field decodes to ``requested_intset``. It then checks
    whether the requested basename, with only that field swapped for the
    candidate's ordering, exists. Candidates are tried in sorted name order
    for deterministic results.

    :param snapshot: path to the requested snapshot file.
    :param requested_intset: the integer set the snapshot is expected to encode.
    :return: ``snapshot`` unchanged if it exists, if its name has no
        ``cls<ints>_`` field, if its directory is missing, or if no matching
        legacy file is found; otherwise the path of the matching legacy file.
    """
    if pt.exists(snapshot):
        return snapshot

    # Only digits, '+' and '-' can appear in an int_set_to_str encoding. A
    # greedy 'cls.*_' would instead run to the LAST underscore in the name
    # and produce unparseable fields.
    cls_pattern = re.compile(r'cls([-+0-9]+)_')

    directory, basename = pt.split(snapshot)
    requested_match = cls_pattern.search(basename)
    if requested_match is None:
        return snapshot

    try:
        # A bare filename has an empty dirname; os.listdir('') would raise.
        candidates = os.listdir(directory or os.curdir)
    except (FileNotFoundError, NotADirectoryError):
        return snapshot

    # Swap only the matched cls field. str.replace would also rewrite any
    # other occurrence of the same substring elsewhere in the basename.
    prefix = basename[:requested_match.start(1)]
    suffix = basename[requested_match.end(1):]

    for candidate in sorted(candidates):
        candidate_match = cls_pattern.search(candidate)
        if candidate_match is None:
            continue
        candidate_intset_str = candidate_match.group(1)
        try:
            candidate_intset = str_to_int_set(candidate_intset_str)
        except ValueError:
            # Unrelated file whose cls field is not a valid int-set encoding.
            continue
        if candidate_intset != requested_intset:
            continue
        resolved = pt.join(directory, prefix + candidate_intset_str + suffix)
        # Another file may share the set but differ elsewhere in its name;
        # only return a path that actually exists.
        if pt.exists(resolved):
            return resolved

    return snapshot
