import os
import copy
from collections import OrderedDict
import torch
from tqdm import tqdm

# Name of the environment variable that gates recording: generic_record only
# stores tensors while this variable is set to the string '1'.
_enable_env = 'ENABLE_RECORDING'

# Module-wide recorder state (reset by empty()).
_global_seq: int = 0  # sequence index handed to the next newly-seen key
_global_values: dict = dict()  # '{uid}__{var}' -> list of recorded tensors
_global_seq_map: dict = dict()  # '{uid}__{var}' -> sequence index of its first recording


def enable_recording():
    """Switch recording on for the whole process.

    Sets the gating environment variable to '1'; ``generic_record`` checks it
    on every call and records only while it holds that value.
    """
    os.environ.update({_enable_env: '1'})
    print("Recording enabled globally.\n")


def disable_recording():
    """Switch recording off for the whole process.

    Removes the gating environment variable if it is present; calling this
    when recording is already off is harmless.
    """
    if _enable_env in os.environ:
        del os.environ[_enable_env]
    print("Recording disabled globally.\n")


def empty():
    """Drop every recorded tensor and order entry, and restart numbering at 0."""
    global _global_seq
    # The two stores are mutated in place, so only the counter needs `global`.
    for registry in (_global_values, _global_seq_map):
        registry.clear()
    _global_seq = 0

    print("All recordings cleared. \n")


def get_global_recordings():
    """Return a snapshot of all recordings as ``{key: [tensor, ...]}``.

    Both the dict and each per-key list are copied, so later calls to
    ``generic_record`` do not grow the lists held by the caller. The tensor
    objects themselves are shared rather than cloned again: each one was
    already cloned when it was recorded, and deep-copying them here would
    duplicate all recorded memory.
    """
    return {key: list(tensors) for key, tensors in _global_values.items()}


def get_global_recordings_order():
    """Return a copy of the ``{key: sequence_index}`` map.

    The values are plain ints, so this shallow copy is already fully
    independent of later recordings.
    """
    return dict(_global_seq_map)


def generic_record(uid, verbose=False, **kwargs):
    """
    Record named tensors under a unique identifier `uid`.

    Each keyword argument ``name=tensor`` is stored under the key
    ``'{uid}__{name}'``. The first time a key is seen it is assigned the next
    global sequence index; every call appends a detached clone of the tensor
    to that key's list. Arguments whose value is ``None`` are skipped.
    Nothing is recorded unless the environment variable named by
    ``_enable_env`` equals '1'.

    Args:
        uid: identifier of the recording owner (``record`` passes a module's
            ``hex(id(...))`` or a user-supplied string).
        verbose: print one line per recorded tensor.
        **kwargs: tensors to record, keyed by variable name.
    """
    if os.environ.get(_enable_env) != '1':
        return

    # Only the counter is rebound; the two dicts are mutated in place.
    global _global_seq
    for var_name, tensor in kwargs.items():
        if tensor is None:
            continue

        key = f"{uid}__{var_name}"

        if key not in _global_values:
            _global_seq_map[key] = _global_seq
            _global_seq += 1
            _global_values[key] = []

        # Detach before cloning so the stored snapshot does not keep the
        # autograd graph (and every activation it references) alive.
        recorded = tensor.detach().clone()
        _global_values[key].append(recorded)

        if verbose:
            count = len(_global_values[key])
            print(f"{_global_seq_map[key]}: '{key}' recorded ({count} items, shape = {tuple(recorded.shape)})")


def record(parent_object, **kwargs):
    """
    Convenience wrapper around ``generic_record`` that derives the uid from `parent_object`.

    The uid is chosen as follows:
      - ``torch.nn.Module``: its ``hex(id(...))`` address string, which
        ``collect_model_recordings`` later maps back to the module's path;
      - ``str``: used verbatim;
      - anything else: ``str(parent_object)``.

    Example:
        record(self, hidden_states=hidden, outputs=out)
    """
    if isinstance(parent_object, torch.nn.Module):
        uid = hex(id(parent_object))
    else:
        uid = parent_object if isinstance(parent_object, str) else str(parent_object)

    generic_record(uid, **kwargs)


def collect_model_recordings(root_module, use_long_name=False, stack_all=False, verbose=False):
    """
    Merge the global recordings into a flat dict keyed by module path.

    Every recording key ``'{uid}__{var}'`` whose uid equals ``hex(id(m))`` for
    some submodule ``m`` of `root_module` is renamed to ``'model.{path}__{var}'``
    (``'model__{var}'`` for the root module itself). Recordings whose uid
    matches no submodule keep their uid: ``'model.{uid}__{var}'``.

    Args:
        root_module (nn.Module): the top-level model
        use_long_name (bool): prefix each key with its sequence index, i.e.
            ``'{idx}__{module_path}__{var}'``
        stack_all (bool): if True, stack all recorded tensors along a new
            leading dim (they must share a shape); otherwise keep the last one
        verbose (bool): print each collected key and its shape

    Returns:
        recorded_tensors (OrderedDict): mapping of new key -> detached CPU tensor
        recorded_tensors_order (OrderedDict): mapping of sequence index -> new key
    """
    recorded_tensors = OrderedDict()
    recorded_tensors_order = OrderedDict()

    recordings = get_global_recordings()
    recordings_order = get_global_recordings_order()

    # Build the id -> path lookup once instead of re-walking the module tree
    # for every recording key. setdefault keeps the first path seen, matching
    # the first-match behaviour of a linear search over named_modules().
    path_by_id = {}
    for module_path, module in root_module.named_modules():
        path_by_id.setdefault(hex(id(module)), module_path)

    for unique_key, val_list in recordings.items():
        id_str, var = unique_key.split('__', 1)
        idx = recordings_order[unique_key]

        module_path = path_by_id.get(id_str)
        if module_path is None:
            # uid is not a submodule of root_module (e.g. a string uid)
            module_full_path = f"model.{id_str}"
        else:
            module_full_path = f"model.{module_path}" if module_path else "model"

        short_name = f"{module_full_path}__{var}"
        long_name = f"{idx}__{module_full_path}__{var}"
        key = long_name if use_long_name else short_name

        if not val_list:
            # torch.tensor() without data raises TypeError; use an explicit empty tensor.
            tensor = torch.empty(0)
        elif stack_all:
            tensor = torch.stack(val_list)
        else:
            tensor = val_list[-1]
        recorded_tensors[key] = tensor.detach().cpu()
        recorded_tensors_order[idx] = key

        if verbose:
            # tuple(shape) rather than .numpy().shape: .numpy() fails on e.g. bfloat16.
            print(f"{key}: {tuple(recorded_tensors[key].shape)}")

    print("\nAll recordings collected.\n")

    return recorded_tensors, recorded_tensors_order


def group_close_tensors(records: dict[str, torch.Tensor],
                        rtol: float = 1e-5,
                        atol: float = 1e-8,
                        equal_nan: bool = False) -> list[list[str]]:
    """
    Greedily group record keys whose tensors are numerically close.

    Keys are visited in insertion order. Each key not yet grouped starts a new
    group and pulls in every later ungrouped key whose tensor, after
    ``squeeze()`` and ``float()``, has the same shape and passes
    ``torch.allclose`` against the group's first tensor. Membership is tested
    only against that first tensor, so grouping is not transitive.

    Args:
        records: mapping of key -> tensor.
        rtol, atol: tolerances forwarded to ``torch.allclose``.
        equal_nan: forwarded to ``torch.allclose``; if True, NaNs in matching
            positions compare equal (otherwise a tensor containing NaN never
            matches anything).

    Returns:
        A list of groups, each a list of record keys; the first key of each
        group is its representative.
    """
    keys = list(records.keys())
    # Normalise each tensor once instead of once per compared pair.
    normalized = {k: records[k].squeeze().float() for k in keys}
    visited = set()
    groups = []

    for i, k0 in tqdm(enumerate(keys), total=len(keys)):
        if k0 in visited:
            continue

        # start a new group with k0
        group = [k0]
        visited.add(k0)
        arr0 = normalized[k0]

        for k1 in keys[i + 1:]:
            if k1 in visited:
                continue
            arr1 = normalized[k1]
            if arr0.shape != arr1.shape:
                continue
            if torch.allclose(arr0, arr1, rtol=rtol, atol=atol, equal_nan=equal_nan):
                group.append(k1)
                visited.add(k1)

        groups.append(group)

    return groups


def collect_and_organize_model_recordings(model, unique=True):
    """
    Collect recordings for `model`, report groups of numerically-close
    tensors, and split the result into "state" and "matrix" recordings.

    Args:
        model: root module handed to ``collect_model_recordings`` (short key
            names, last recorded value per key).
        unique: if True, keep only the first key of each close-tensor group.

    Returns:
        (records, records_order, state_records, matrix_records):
        state_records holds every key containing '_x', '_y' or '_z';
        matrix_records holds all other keys. Both follow records_order.
    """
    separator = "-" * 100

    print("\n", separator, "\n")
    records, records_order = collect_model_recordings(model, use_long_name=False, stack_all=False)
    print("\n", separator, "\n")

    groups = group_close_tensors(records)
    # The first key of each group stands in for the whole group.
    unique_records = {group[0]: records[group[0]] for group in groups}
    unique_records_order = OrderedDict(
        (seq, name) for seq, name in records_order.items() if name in unique_records
    )

    for number, group in enumerate(groups, start=1):
        print(f"Group {number}:")
        for member in group:
            print("  -", member)

    print("\n", separator, "\n")

    if unique:
        records = copy.copy(unique_records)
        records_order = copy.copy(unique_records_order)

    state_records = OrderedDict()
    matrix_records = OrderedDict()
    # Name-based heuristic: any key containing '_x', '_y' or '_z' is a state recording.
    state_markers = ("_x", "_y", "_z")
    for name in records_order.values():
        bucket = state_records if any(marker in name for marker in state_markers) else matrix_records
        bucket[name] = records[name]

    print("\nstate recordings:\n")
    print("\n".join(state_records.keys()))

    print("\nmatrix recordings:\n")
    print("\n".join(matrix_records.keys()))

    return records, records_order, state_records, matrix_records
