import random
from functools import wraps

import lightning.pytorch as pl
import numpy as np
import torch
import torch.cuda


class SeedContext:
    """Context manager that temporarily seeds the global random generators.

    On entry the current state of Python's ``random`` module, NumPy's global
    generator and PyTorch's CPU generator (plus every CUDA device generator
    when CUDA is available) is captured, then ``pl.seed_everything(seed)`` is
    applied. On exit the captured states are put back, so code after the
    ``with`` block continues the random stream it had before.

    A ``seed`` of ``None`` turns the context into a no-op: nothing is
    captured, seeded or restored.

    NOTE(review): only the RNG states listed above are restored. Any other
    side effect ``pl.seed_everything`` may have is left in place -- confirm
    whether that matters for callers.

    Args:
        seed: Value forwarded to ``pl.seed_everything``, or ``None`` to
            disable seeding.
    """

    def __init__(self, seed):
        self.seed = seed
        # Filled by __enter__ with the captured RNG states; stays None until
        # then (and for the whole lifetime when seed is None).
        self.original_state = None

    def __enter__(self):
        """Capture the current RNG states and apply the seed.

        Returns:
            This context object, so ``with SeedContext(s) as ctx`` is usable.
        """
        if self.seed is None:
            return self

        # Save the current random state
        cuda_state = None
        if torch.cuda.is_available():
            cuda_state = torch.cuda.get_rng_state_all()
        self.original_state = {
            'random_state': random.getstate(),
            'numpy_state': np.random.get_state(),
            'torch_state': torch.get_rng_state(),
            'torch_cuda_state': cuda_state,
        }

        # Apply the new seed
        try:
            pl.seed_everything(self.seed, verbose=False)
        except TypeError:
            # The installed ``seed_everything`` rejected the ``verbose``
            # keyword; retry with the plain signature. Only TypeError is
            # caught so interrupts and unrelated failures still propagate.
            pl.seed_everything(self.seed)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Restore the RNG states captured by ``__enter__``.

        Returns ``None``, so an exception raised inside the ``with`` body is
        never suppressed.
        """
        saved = self.original_state
        if self.seed is None or saved is None:
            # Nothing was captured, so there is nothing to restore.
            return

        # Restore the original random state
        random.setstate(saved['random_state'])
        np.random.set_state(saved['numpy_state'])
        torch.set_rng_state(saved['torch_state'])

        # Key off what was actually saved rather than re-querying CUDA
        # availability, so a missing snapshot can never be passed on as None.
        cuda_state = saved['torch_cuda_state']
        if cuda_state is not None:
            torch.cuda.set_rng_state_all(cuda_state)

def seed_context(func):
    """Decorate ``func`` so that each call runs inside a ``SeedContext``.

    The seed is read from the call's ``seed`` argument. It is recognised when
    passed by keyword, and also when passed positionally to a parameter named
    ``seed`` in ``func``'s signature. When the caller passes no seed (or
    passes ``None``) the call runs without touching any RNG state; a default
    value declared in ``func``'s signature is deliberately not consulted.

    Args:
        func: The callable to wrap. All arguments, including ``seed``, are
            forwarded to it unchanged.

    Returns:
        A wrapper with ``func``'s metadata that returns whatever ``func``
        returns.
    """
    import inspect  # local import keeps this block self-contained

    # Work out once, at decoration time, where a positional ``seed`` would
    # sit in ``*args``. Stays None if ``func`` has no such parameter or its
    # signature cannot be introspected.
    seed_position = None
    try:
        parameters = tuple(inspect.signature(func).parameters.values())
    except (TypeError, ValueError):
        parameters = ()
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    for position, parameter in enumerate(parameters):
        if parameter.kind not in positional_kinds:
            # Everything from here on cannot be filled from ``*args`` by index.
            break
        if parameter.name == 'seed':
            seed_position = position
            break

    @wraps(func)
    def wrapper(*args, **kwargs):
        if 'seed' in kwargs:
            seed = kwargs['seed']
        elif seed_position is not None and seed_position < len(args):
            seed = args[seed_position]
        else:
            seed = None
        with SeedContext(seed):
            return func(*args, **kwargs)

    return wrapper
