from typing import Tuple
import sys
from typing import List
import numpy as np
import torch 


def to_torch(x: np.ndarray | torch.Tensor) -> torch.Tensor: 
    if isinstance(x, torch.Tensor):
        return x.to('cuda:0')
    return torch.tensor(x, device='cuda:0')


def batch_roll(x: torch.Tensor, shifts: torch.Tensor) -> torch.Tensor:
    """
    Circularly shift each row of x by the corresponding entry of shifts.

    Row ``b`` of the result satisfies ``out[b, t] == x[b, (t - shifts[b]) % T]``,
    i.e. a positive shift moves elements towards higher indices.

    :param x: tensor of shape B x T
    :param shifts: integer tensor of shape B
    :return: tensor of shape B x T with every row rolled independently
    """

    assert x.ndim == 2 and shifts.ndim == 1 and x.shape[0] == shifts.shape[0]

    length = x.shape[1]

    # build the indices on the same device as x, otherwise gather fails
    # with a device mismatch when x lives on the GPU
    idx = torch.arange(length, device=x.device).unsqueeze(0)
    shifts = shifts.to(x.device).unsqueeze(1)

    # (1 x T) - (B x 1) broadcasts to B x T; the modulo wraps around the row
    shifted_idx = (idx - shifts) % length

    return x.gather(1, shifted_idx)


def list_split(input_list: List, num_splits: int) -> List[List]:
    """ Split a list into ``num_splits`` contiguous sub-lists of near-equal size.

    The first ``len(input_list) % num_splits`` sub-lists receive one extra
    element, so sizes differ by at most one and order is preserved.

    :param input_list: list to split
    :param num_splits: number of sub-lists to produce, between 1 and len(input_list)
    :return: list of ``num_splits`` sub-lists
    :raises ValueError: if num_splits is smaller than 1 or larger than the list
    """

    # guard against a division by zero (and meaningless negative counts)
    if num_splits < 1:
        raise ValueError("Cannot split a list into fewer than one sub-list.")
    if num_splits > len(input_list):
        raise ValueError("Cannot split a list with more splits than its actual size.")

    # base size of each sublist, and how many sublists get one extra element
    base_size, remainder = divmod(len(input_list), num_splits)

    sublists = []
    start = 0
    for i in range(num_splits):
        end = start + base_size + (1 if i < remainder else 0)
        sublists.append(input_list[start:end])
        start = end

    return sublists


def divide_grid(grid: List, n_total: int, task_id: int) -> List:
    """ Divide a grid into n_total tasks and return the task corresponding to task_id.

    :param grid: list of arguments for each call
    :param n_total: number of workers
    :param task_id: id of this worker (zero-based)
    :return: the chunk of the grid assigned to this worker, or an empty list
        when the worker has nothing to do
    """
    # total number of tasks exceeds the grid size: at most one task per grid entry
    n_total = min(n_total, len(grid))
    # nothing to distribute (empty grid or no workers), or task id beyond the
    # last chunk (valid ids are 0 .. n_total - 1): trying to complete null tasks
    if n_total < 1 or task_id >= n_total:
        return []
    # task_id corresponds to a chunk of the grid
    return list_split(grid, n_total)[task_id]


def is_debug() -> bool:
    """ Tell whether a trace function (e.g. the one a debugger installs) is active.

    :return: True if ``sys.gettrace`` exists and returns a truthy value,
        False otherwise
    """
    trace_getter = getattr(sys, 'gettrace', None)
    # some interpreters may not expose gettrace at all: treat as "not debugging"
    return trace_getter is not None and bool(trace_getter())


def standardize(x: torch.Tensor, dims: Tuple, return_stats: bool = False):
    """ Center and scale x using its mean and standard deviation over dims.

    The statistics are computed under ``torch.no_grad()`` and a small constant
    (1e-7) is added to the standard deviation before dividing.

    :param x: tensor to standardize
    :param dims: dimensions over which mean and std are computed (kept as size-1 dims)
    :param return_stats: if True, also return the mean and the (offset) std
    :return: the standardized tensor, or the tuple (standardized, mean, std)
    """
    with torch.no_grad():
        spread, center = torch.std_mean(x, dim=dims, keepdim=True)
        # offset keeps the division finite when the std is zero
        spread = spread + 1e-7

    normalized = (x - center) / spread
    return (normalized, center, spread) if return_stats else normalized


def nrmpe(x_ref, x, dims, stand=False, p=1.0, avg=True, normalize=True):
    """ Compute the (optionally normalized) p-th root of the mean p-th power error.

    :param x_ref: reference tensor
    :param x: tensor compared against the reference
    :param dims: dimensions over which the error is averaged
    :param stand: if True, standardize both tensors over dims before comparing
    :param p: exponent applied to the absolute residual
    :param avg: if True, return the mean of the loss over the remaining dimensions
    :param normalize: if True, divide by the mean of ``|x_ref| ** p`` over dims
    :return: scalar tensor if avg is True, otherwise one value per remaining index
    """
    reference = x_ref.clone()
    estimate = x.clone()
    if stand:
        reference = standardize(reference, dims=dims)
        estimate = standardize(estimate, dims=dims)

    error = (reference - estimate).abs().pow(p).mean(dims)
    if normalize:
        # relative error: scale by the magnitude of the reference
        error = error / reference.abs().pow(p).mean(dims)
    error = error.pow(1 / p)

    return error.mean() if avg else error


def subsample(x, reso):
    """ Subsample the trailing dimensions of x with a regular stride.

    For the i-th entry of reso, the matching trailing dimension of x is sliced
    with step ``dim_size // reso[i]``; leading dimensions are left untouched.

    :param x: array or tensor with at least ``len(reso)`` dimensions
    :param reso: sequence of target resolutions, one per trailing dimension of x
    :return: the strided view/slice of x
    :raises ValueError: if the slicing fails (e.g. a zero resolution, a
        resolution larger than the dimension, or more entries than dimensions)
    """
    try:
        n_dims = len(reso)
        # reso[i] is paired with dimension i - n_dims, so reso aligns with
        # the last n_dims dimensions of x
        slices = tuple(
            slice(None, None, x.shape[i - n_dims] // target)
            for i, target in enumerate(reso)
        )
        return x[(...,) + slices]
    except Exception as err:
        # surface the offending arguments instead of crashing on an undefined name
        raise ValueError(
            f"subsample failed at reso={reso} for input of shape {tuple(x.shape)}"
        ) from err
