import torch
import torch.nn as nn
from torch import Tensor
from typing import Tuple, Union, Callable
import time
import numpy as np  

def cpu_timer(f: Callable, *args, **kwargs):
    """
    Run ``f(*args, **kwargs)`` once and measure its wall-clock duration.

    :param f: the callable to time
    :type f: Callable
    :param args: positional arguments forwarded to ``f``
    :param kwargs: keyword arguments forwarded to ``f``
    :return: ``(elapsed, result)`` where ``elapsed`` is the difference of two
        ``time.perf_counter()`` readings (seconds) and ``result`` is whatever
        ``f`` returned
    :rtype: tuple
    """
    t_begin = time.perf_counter()
    result = f(*args, **kwargs)
    elapsed = time.perf_counter() - t_begin
    return elapsed, result

def cuda_timer(device: Union[torch.device, int], f: Callable, *args, **kwargs):
    """
    Run ``f(*args, **kwargs)`` once and time it with a pair of CUDA events.

    ``device`` is made the current CUDA device, a start event is recorded,
    ``f`` is called, an end event is recorded, and the device is synchronized
    before the elapsed time between the two events is read.

    :param device: the CUDA device to run on
    :type device: Union[torch.device, int]
    :param f: the callable to time
    :type f: Callable
    :param args: positional arguments forwarded to ``f``
    :param kwargs: keyword arguments forwarded to ``f``
    :return: ``(elapsed, result)`` where ``elapsed`` is the value of
        ``torch.cuda.Event.elapsed_time`` (milliseconds) and ``result`` is
        whatever ``f`` returned
    :rtype: tuple
    """
    torch.cuda.set_device(device)
    begin_event = torch.cuda.Event(enable_timing=True)
    finish_event = torch.cuda.Event(enable_timing=True)

    begin_event.record()
    result = f(*args, **kwargs)
    finish_event.record()

    # Both events must have completed before elapsed_time can be queried.
    torch.cuda.synchronize(device)
    elapsed_ms = begin_event.elapsed_time(finish_event)
    return elapsed_ms, result

# def cal_fun_t(training: bool, n: int, device: Union[str, torch.device, int], f: Callable, x: torch.Tensor):
#     x = x.clone().detach().requires_grad_(True)
#     if not x.is_leaf:
#         x.retain_grad()
#     t_list = []
    
#     if device == 'cpu':
#         _, y = cpu_timer(f, x)
#         if training:
#             y_grad = torch.rand_like(y)
#             y.backward(y_grad)

#         for _ in range(n * 2):
#             t = 0
#             t_f, y = cpu_timer(f, x)
#             t += t_f    
#             if training:
#                 y_grad = torch.rand_like(y)
#                 t_b, _ = cpu_timer(device, y.backward, y_grad)
#                 t += t_b
#             t_list.append(t)
            
#     else:
#         _, y = cuda_timer(device, f, x)

#         if training:
#             y_grad = torch.rand_like(y)
#             y.backward(y_grad)

#         for i in range(n * 2):
#             t = 0
#             t_f, y = cuda_timer(device, f, x)
#             t += t_f
            
#             if training:
#                 y_grad = torch.rand_like(y)
#                 t_b, _ = cuda_timer(device, y.backward, y_grad)
#                 t += t_b

#             t_list.append(t)
#     t_list = np.asarray(t_list)
#     return t_list[n:].mean()

def cal_fun_t(training: bool, n: int, device: Union[str, torch.device, int], f: Callable, x: torch.Tensor):
    """
    Measure the average time of ``f(x)`` and, optionally, of its backward pass.

    ``f`` is first executed once without being measured (warm-up). It is then
    executed ``2 * n`` more times with timing; the first ``n`` measurements are
    discarded and the mean of the remaining ``n`` is returned.

    :param training: if ``True``, each measurement also includes the time of
        ``y.backward(y_grad)`` where ``y = f(x)`` and ``y_grad`` is a random
        tensor shaped like ``y``
    :type training: bool
    :param n: the number of measurements that are averaged (``2 * n`` timed
        runs are made in total). With ``n = 0`` there is nothing to average
        and the result is ``nan``
    :type n: int
    :param device: where ``f`` runs. ``'cpu'`` or a CPU ``torch.device``
        selects :func:`cpu_timer`; anything else (including any ``int``) is
        passed to :func:`cuda_timer` as the CUDA device
    :type device: Union[str, torch.device, int]
    :param f: the function to benchmark; called as ``f(x)``
    :type f: Callable
    :param x: the input. It is cloned and detached, so the caller's tensor is
        neither modified nor does it receive gradients
    :type x: torch.Tensor
    :return: the mean measured time. The unit follows the timer in use:
        seconds on CPU (``time.perf_counter``) and milliseconds on CUDA
        (``torch.cuda.Event.elapsed_time``)
    :rtype: numpy.floating
    """
    # A detached clone with requires_grad set is always a leaf tensor, so its
    # gradient is kept without an explicit retain_grad() call.
    x = x.clone().detach().requires_grad_(True)

    # Recognise torch.device('cpu') as well as the plain string 'cpu'.
    # An int is always a CUDA device index.
    on_cpu = (not isinstance(device, int)) and torch.device(device).type == 'cpu'

    if on_cpu:
        def timed(fn, *fn_args):
            return cpu_timer(fn, *fn_args)
    else:
        def timed(fn, *fn_args):
            return cuda_timer(device, fn, *fn_args)

    # Warm-up run: executed but not included in the measurements.
    _, y = timed(f, x)
    if training:
        y.backward(torch.rand_like(y))
        if not on_cpu:
            torch.cuda.empty_cache()

    t_list = []
    for _ in range(n * 2):
        t, y = timed(f, x)
        if training:
            y_grad = torch.rand_like(y)
            t_b, _ = timed(y.backward, y_grad)
            t += t_b
            # Drop the reference before releasing cached CUDA memory.
            del y_grad
            if not on_cpu:
                torch.cuda.empty_cache()
        t_list.append(t)

    # Discard the first n measurements and average the rest.
    return np.asarray(t_list)[n:].mean()

def t_last_multi_step_forward(x_seq: Tensor, single_step_module: Union[nn.Module, list[nn.Module], tuple[nn.Module], nn.Sequential, Callable]):
    """
    .. _t_last_multi_step_forward-cn:

    .. _t_last_multi_step_forward-en:

    :param x_seq: the input tensor with ``shape=[batch_size, ..., T]``
    :type x_seq: Tensor
    :param single_step_module: one or many single-step modules
    :type single_step_module: Union[nn.Module, list[nn.Module], tuple[nn.Module], nn.Sequential, Callable]
    :return: the output tensor with ``shape=[batch_size, ..., T]``
    :rtype: Tensor

    Applies multi-step forward on ``single_step_module``.

    The last dimension of ``x_seq`` is treated as time. Each time slice
    ``x_seq[..., t]`` is passed through ``single_step_module`` (through every
    element in order when it is a ``list``, ``tuple`` or ``nn.Sequential``),
    and the per-step outputs are stacked along a new last dimension.
    """
    if isinstance(single_step_module, (list, tuple, nn.Sequential)):
        def run_step(frame):
            # Feed the slice through every module of the chain, in order.
            for layer in single_step_module:
                frame = layer(frame)
            return frame
    else:
        run_step = single_step_module

    num_steps = x_seq.shape[-1]
    outputs = [run_step(x_seq[..., step]) for step in range(num_steps)]
    return torch.stack(outputs, dim=-1)


def t_last_vmap_forward(x_seq: Tensor, stateless_module: Union[nn.Module, list, tuple, nn.Sequential, Callable]):
    """
    .. _t_last_seq_to_ann_forward-cn:

    .. _t_last_seq_to_ann_forward-en:

    :param x_seq: the input tensor with ``shape=[batch_size, ..., T]``
    :type x_seq: Tensor
    :param stateless_module: one or many stateless modules. A ``list`` or
        ``tuple`` is applied element by element, in order
    :type stateless_module: Union[nn.Module, list, tuple, nn.Sequential, Callable]
    :return: the output tensor with ``shape=[batch_size, ..., T]``
    :rtype: Tensor

    Applies forward on stateless modules over every time-step of ``x_seq``.

    When ``torch.vmap`` is available, the module is mapped over the last
    dimension of ``x_seq`` with ``torch.vmap``; otherwise the time-steps are
    processed one by one by :func:`t_last_multi_step_forward`.

    .. admonition:: Note
        :class: note

        The default shape of sequence data in SpikingJelly is ``shape=[T, batch_size, ...]``. However, this function is used for the other data format where ``shape=[batch_size, ..., T]``. When using ``torch >= 2.0.0``, this function computes in parallel.

    .. admonition:: Note
        :class: note

        This function can not be applied to wrap BN because its running mean/var depends on inputs. The BN can be computed in parallel as long as the input is regarded as ``shape = [N, C, ..]``, which can be implemented by user manually.
    """
    if not hasattr(torch, 'vmap'):
        return t_last_multi_step_forward(x_seq, stateless_module)

    if isinstance(stateless_module, (list, tuple)):
        # torch.vmap needs a single callable, so chain the modules into one.
        stages = tuple(stateless_module)

        def forward_fn(x):
            for stage in stages:
                x = stage(x)
            return x
    else:
        forward_fn = stateless_module

    vmap_f = torch.vmap(forward_fn, in_dims=-1, out_dims=-1)
    return vmap_f(x_seq)
   
