from types import SimpleNamespace
import inspect
import json
import os
import pprint
import sys
import copy
from inspect import currentframe

import torch


class NanRecorder:
    """Process-wide recorder of NaN checks on tensors.

    ``start(True)`` enables recording and clears old entries.
    ``record(t, fname)`` stores whether ``t`` contains a NaN, together with
    ``fname`` and the caller's line number. ``check()`` / ``end()`` evaluate
    all stored flags at once and raise on the first NaN.

    All state lives on the class itself, so every user of ``NanRecorder``
    shares the same buffer.
    """

    enabled = False  # record()/check()/end() are no-ops unless start(True) was called
    a = []  # pending records: SimpleNamespace(is_nan, filename, lineno)
    limit = 16  # NOTE(review): not referenced by this class; kept for compatibility

    @classmethod
    def start(cls, enabled):
        """Enable or disable recording and drop any pending records."""
        cls.enabled = enabled
        cls.a.clear()

    @classmethod
    def lineno(cls, depth=1):
        """Return the line number ``depth`` frames above this call.

        With the default ``depth=1`` this is the line that invoked
        ``lineno()``. Returns None when ``inspect.currentframe()`` is
        unavailable or the stack is shallower than ``depth``.
        """
        frame = inspect.currentframe()
        try:
            for _ in range(depth):
                if frame is None:
                    return None
                frame = frame.f_back
            return None if frame is None else frame.f_lineno
        finally:
            # Drop the frame reference so it cannot form a reference cycle.
            del frame

    @staticmethod
    def _rank():
        """Return the torch.distributed rank, or None if no process group is initialized."""
        dist = torch.distributed
        if dist.is_available() and dist.is_initialized():
            return dist.get_rank()
        return None

    @classmethod
    def record(cls, t, fname):
        """Record whether tensor ``t`` contains a NaN at the caller's location.

        ``fname`` is stored as the file name; the line number recorded is the
        line that called ``record()``. Does nothing while recording is disabled.
        """
        if not cls.enabled:
            return
        # depth=2 skips lineno()'s own frame and this record() frame,
        # yielding the line of the code that called record().
        lineno = cls.lineno(depth=2)
        with torch.no_grad():
            # Keep the flag as a tensor (no .item() here); check() converts all
            # flags together with a single .tolist().
            is_nan = torch.any(torch.isnan(t))
        cls.a.append(SimpleNamespace(
            is_nan=is_nan,
            filename=copy.deepcopy(fname),
            lineno=lineno,
        ))

    @classmethod
    def check(cls):
        """Evaluate all pending records and raise on the first NaN.

        For the earliest record whose tensor contained a NaN, prints its rank
        (None outside torch.distributed), file name and line number, then
        raises ``AssertionError('nan detected')``. Pending records are cleared
        whether or not a NaN was found.
        """
        if not cls.enabled or not cls.a:
            return
        try:
            with torch.no_grad():
                # One concatenation and one .tolist() for every recorded flag.
                flags = torch.cat([s.is_nan.view(1) for s in cls.a]).tolist()
            for s, flag in zip(cls.a, flags):
                if flag:
                    msg = dict(
                        rank=cls._rank(),
                        filename=s.filename,
                        lineno=s.lineno,
                    )
                    msg = f'nan detected {pprint.pformat(msg, indent=4)}'
                    print(msg, flush=True)
                    # Explicit raise instead of `assert` so it survives `python -O`.
                    raise AssertionError('nan detected')
        finally:
            cls.a.clear()

    @classmethod
    def end(cls):
        """Run a final check() and disable recording, even if check() raises."""
        if not cls.enabled:
            return
        try:
            cls.check()
        finally:
            cls.enabled = False
