# %%
import math
import re
import torch
from torch import Tensor as TT
from torch.utils.data import DataLoader
from typing import Callable, Union
# Annotation aliases for tensor-valued callables taking two tensors / one tensor.
TensFunc2 = Callable[[TT, TT], TT]
TensFunc = Callable[[TT], TT]
# Module-level standard normal N(0, 1) distribution object.
normal_dist = torch.distributions.Normal(loc=0, scale=1)


def sortfilesby(files, string: str):
    """Sort file names by a number captured from each name.

    Args:
        files: Iterable of file names. It is materialised once, so generators/iterators work.
        string (str): Regular expression whose first capture group matches the number to sort by,
            e.g. ``r"epoch_([0-9]+)"``.

    Returns:
        tuple[list, list[float]]: The file names in ascending order of the captured number (ties broken
            by file name), and the sorted numbers themselves.

    Raises:
        ValueError: If ``string`` does not match one of the file names.
    """
    # Materialise first: iterating ``files`` twice would silently drop everything for an iterator.
    files = list(files)
    keys = []
    for file in files:
        match = re.search(string, file)
        if match is None:
            raise ValueError(f"Pattern {string!r} does not match file name {file!r}")
        keys.append(float(match.group(1)))
    # One sort yields both outputs; the (key, file) order gives the same result as sorting each separately.
    pairs = sorted(zip(keys, files))
    return [file for _, file in pairs], [key for key, _ in pairs]


def g_error(y: TT, x: TT, g_hat: TensFunc2, g_true: TensFunc2, summary_func: TensFunc = torch.mean,
            batch_size=None) -> TT:
    """Summarise the absolute error of ``g_hat`` against ``g_true`` evaluated at ``(y, x)``.

    Args:
        y (TT): First argument passed to both functions.
        x (TT): Second argument passed to both functions; should have the same leading dimension as ``y``
            when batching.
        g_hat (TensFunc2): Estimated function, called as ``g_hat(y_batch, x_batch)``.
        g_true (TensFunc2): True function, called as ``g_true(y_batch, x_batch)``.
        summary_func (TensFunc, optional): Applied to the concatenated absolute errors. Defaults to torch.mean.
        batch_size (Union[int, None], optional): Split ``y`` and ``x`` along dim 0 into chunks of this size.
            If None, the full data is used in one call. Defaults to None.

    Returns:
        TT: ``summary_func`` of ``|g_hat - g_true|`` over all rows.
    """
    if batch_size is None:
        y_batches, x_batches = [y], [x]
    else:
        y_batches, x_batches = torch.split(y, batch_size), torch.split(x, batch_size)

    errors = []
    for y_batch, x_batch in zip(y_batches, x_batches):
        estimate = g_hat(y_batch, x_batch)
        truth = g_true(y_batch, x_batch)
        diff = torch.abs(estimate.squeeze() - truth.squeeze())
        # squeeze() turns a single-row batch into a 0-dim tensor, which torch.cat rejects.
        errors.append(torch.atleast_1d(diff))
    return summary_func(torch.cat(errors))


def abs_diff(y1: torch.Tensor, y2: torch.Tensor) -> torch.Tensor:
    """Element-wise absolute difference ``|y1 - y2|`` (inputs broadcast as for subtraction)."""
    difference = y1 - y2
    return torch.abs(difference)


def gen_error(data: Union[TT, list], h_hat: Callable, h_true: Callable,
              summary_func: TensFunc = torch.mean, diff_func: TensFunc2 = abs_diff, batch_size=None) -> TT:
    """Summarise the error of an estimated function against the true one on test data, with optional batching.

    Both functions are called as ``h(*batch)``, where ``batch`` holds one slice (or the whole) of every tensor
    in ``data``. Per-batch errors are concatenated along dim 0 and then summarised.

    Args:
        data (Union[TT, list]): The test data to obtain errors on: a single tensor, or a list of tensors that
            are passed to the functions as positional arguments.
        h_hat (Callable): Estimated function.
        h_true (Callable): True function.
        summary_func (Callable, optional): Function to summarise errors. Defaults to torch.mean.
        diff_func (Callable, optional): Function giving errors from the true values and the estimates, passed
            in that order. Defaults to abs_diff.
        batch_size (Union[int, None], optional): Size of batch for computing error. If None, the full data is
            used. Defaults to None.

    Returns:
        TT: ``summary_func`` applied to the concatenated errors (a single number with the default torch.mean).
            Returns None if ``diff_func`` raises a RuntimeError; the error message is printed.
    """
    # isinstance (not ``type(...) is``) so Tensor subclasses such as nn.Parameter count as a single tensor.
    if isinstance(data, TT):
        data = [data]
    if batch_size is None:
        split_data = [[datum] for datum in data]
    else:
        split_data = [torch.split(datum, batch_size) for datum in data]

    errors = []
    for batch in zip(*split_data):
        y_fake = h_hat(*batch)
        y_true = h_true(*batch)
        try:
            diffs = diff_func(y_true, y_fake)
        except RuntimeError as e:
            # Best-effort: report the failure (e.g. a shape mismatch) and return None rather than raise.
            print(e)
            return None
        # A 0-dim error (e.g. from a single-row batch) cannot be passed to torch.cat.
        errors.append(torch.atleast_1d(diffs))
    return summary_func(torch.cat(errors))


def torch_normpdf(x: torch.Tensor) -> torch.Tensor:
    """Standard normal density ``exp(-x**2 / 2) / sqrt(2 * pi)``, element-wise."""
    normaliser = math.sqrt(2 * math.pi)
    exponent = -0.5 * x**2
    return torch.exp(exponent) / normaliser


def torch_normcdf(x: torch.Tensor) -> torch.Tensor:
    """Standard normal CDF Phi(x), element-wise.

    Computed as ``0.5 * erfc(-x / sqrt(2))``, which equals ``0.5 * (1 + erf(x / sqrt(2)))`` mathematically.
    The erfc form avoids the cancellation in ``1 + erf(...)`` for very negative ``x``. For example, in float32
    the erf form returns exactly 0 for Phi(-10), while erfc keeps about 7.6e-24.
    """
    return 0.5 * torch.erfc(-x / math.sqrt(2))


def torch_normicdf(prob: torch.Tensor) -> torch.Tensor:
    """Standard normal quantile function Phi^{-1}(prob), via ``sqrt(2) * erfinv(2 * prob - 1)``.

    When ``prob`` is within floating-point epsilon of 0 or 1, ``2 * prob - 1`` rounds to -1 or 1.
    The result then saturates to -inf or inf.
    """
    centred = 2 * prob - 1
    return torch.erfinv(centred) * math.sqrt(2)


def get_true_g(gs_0: list, gs_1: list):
    """Build ``g(y, x) = s1(x) * (y - m0(x)) / s0(x) + m1(x)``.

    Args:
        gs_0 (list): ``[m0, s0]``, functions of ``x`` used to shift and divide ``y``.
        gs_1 (list): ``[m1, s1]``, functions of ``x`` used to rescale and shift the result.

    Returns:
        Callable: ``true_g(y, x)``. The functions are looked up from the lists on every call.
    """
    def true_g(y, x):
        scale_new = gs_1[1](x)
        centred = y - gs_0[0](x)
        return scale_new * centred / gs_0[1](x) + gs_1[0](x)
    return true_g


def get_true_h(hs_0: list, hs_1: list, base_cdf: TensFunc = torch_normcdf):
    """Build ``h(y0, y1, x) = F((y1 - m1(x)) / s1(x)) - F((y0 - m0(x)) / s0(x))``.

    Args:
        hs_0 (list): ``[m0, s0]``, functions of ``x`` used to shift and divide ``y0``.
        hs_1 (list): ``[m1, s1]``, functions of ``x`` used to shift and divide ``y1``.
        base_cdf (TensFunc, optional): One-argument CDF ``F`` applied after standardising.
            Defaults to torch_normcdf.

    Returns:
        Callable: ``true_h(y0, y1, x)``. The functions are looked up from the lists on every call.
    """
    def true_h(y0, y1, x):
        # transform y0 and y1
        y0_t = (y0 - hs_0[0](x))/hs_0[1](x)
        y1_t = (y1 - hs_1[0](x))/hs_1[1](x)
        # get cdfs
        cdf0 = base_cdf(y0_t)
        cdf1 = base_cdf(y1_t)
        # get the difference
        return cdf1 - cdf0
    return true_h


def torch_nanstd(X: torch.Tensor, dim, correction: int = 0):
    """Standard deviation of ``X`` along ``dim``, ignoring NaNs.

    Args:
        X (torch.Tensor): Input values; NaN entries are excluded.
        dim: Dimension(s) to reduce over (as accepted by ``torch.nanmean``).
        correction (int, optional): Subtracted from the per-slice count of non-NaN values in the
            denominator, as in ``torch.std``. The default 0 gives the population (biased) std and matches
            the previous behaviour. Use 1 for the unbiased sample std.

    Returns:
        torch.Tensor: The reduced std. Slices with no usable values (or ``count <= correction``) give NaN
            or inf.
    """
    mean = torch.nanmean(X, dim=dim, keepdim=True)
    squared_diff = (X - mean) ** 2
    n_valid = torch.sum(~torch.isnan(X), dim=dim)
    # nansum / n_valid is exactly nanmean; the correction generalises it to ddof != 0.
    return torch.sqrt(torch.nansum(squared_diff, dim=dim) / (n_valid - correction))


def my_all(X: torch.Tensor, dim=None):
    """Check that every entry of a boolean tensor is True along the given dimension(s).

    Implemented as "sum along ``dim`` equals the number of reduced entries".

    Args:
        X (torch.Tensor): Tensor to test (intended to be boolean).
        dim (int | Sequence[int] | None, optional): Dimension(s) to reduce. An int, a sequence of ints, or
            None for all dimensions. Defaults to None.

    Returns:
        torch.Tensor: Boolean tensor with the reduced dimensions removed.
    """
    if dim is None:
        dim = tuple(range(len(X.shape)))
    elif isinstance(dim, int):
        # A bare int previously crashed on ``list(dim)``.
        dim = (dim,)
    else:
        dim = tuple(dim)
    n_reduced = torch.prod(torch.tensor(X.shape)[list(dim)])
    return torch.sum(X, dim=dim) == n_reduced


def my_any(X: torch.Tensor, dim):
    """Return True where the sum of ``X`` along ``dim`` is positive ("any True" for boolean ``X``)."""
    return X.sum(dim=dim).gt(0)


def my_allclose(X: torch.Tensor, Y: torch.Tensor, rtol=1e-05, atol=1e-08, dim=None):
    """Check ``|X - Y| <= atol + rtol * |Y|`` for all entries along ``dim``.

    Uses the same non-strict inequality as ``torch.allclose``. Exactly equal entries, including equal
    infinities, count as close; NaN never does. Unlike ``torch.allclose`` it can reduce over selected dims
    only, and it accepts inputs of different dtypes.

    Args:
        X (torch.Tensor): Values to test.
        Y (torch.Tensor): Reference values (the relative tolerance scales with ``|Y|``).
        rtol (float, optional): Relative tolerance. Defaults to 1e-05.
        atol (float, optional): Absolute tolerance. Defaults to 1e-08.
        dim (optional): Dimension(s) to reduce over; None reduces over all dims of the broadcast comparison.

    Returns:
        torch.Tensor: Boolean tensor from ``my_all``.
    """
    # ``X == Y`` keeps exact matches (e.g. equal infs, or atol == rtol == 0) close; ``<=`` matches torch.
    logic_mat = (X == Y) | (torch.abs(X - Y) <= atol + rtol * torch.abs(Y))
    if dim is None:
        # Use the broadcast shape so every dimension of the comparison is reduced.
        dim = tuple(range(len(logic_mat.shape)))
    return my_all(logic_mat, dim=dim)


def dist_to_symettry(A: torch.Tensor):
    """Half the Frobenius norm of ``A - A^T`` over the last two dims (0 for symmetric matrices).

    Leading dimensions are treated as batch dimensions.
    """
    antisym = A - A.transpose(-2, -1)
    return 0.5 * torch.norm(antisym, dim=(-1, -2))


def torch_nancov(X: torch.Tensor):
    """Covariance matrix of the rows of ``X`` using pairwise-complete observations.

    Rows are variables and columns are observations (the ``torch.cov`` convention).
    - Each off-diagonal entry (i, j) uses only the columns where neither row i nor row j is NaN.
    - Each diagonal entry is the variance of row i over all of its non-NaN columns. It does not depend on
      the other rows or on the order in which pairs are visited.

    Different entries use different column subsets, so the result need not be positive semi-definite.

    Args:
        X (torch.Tensor): Tensor of shape (d, n), possibly containing NaNs.

    Returns:
        torch.Tensor: Symmetric (d, d) matrix of unbiased covariances (as ``torch.cov``). The dtype follows
            ``X`` for floating inputs, otherwise the default dtype.
    """
    d = X.shape[0]
    nan_mask = torch.isnan(X)
    dtype = X.dtype if X.is_floating_point() else torch.get_default_dtype()
    cov_mat = torch.empty((d, d), dtype=dtype, device=X.device)
    for i in range(d):
        # Every entry is written, so no uninitialised torch.empty memory leaks out (e.g. when d == 1).
        cov_mat[i, i] = torch.cov(X[i, ~nan_mask[i]])
        for j in range(i):
            both_observed = ~(nan_mask[i] | nan_mask[j])
            pair_cov = torch.cov(X[[i, j]][:, both_observed])
            cov_mat[i, j] = cov_mat[j, i] = pair_cov[0, 1]
    return cov_mat


def get_ci(vec: torch.Tensor, dim: int, alpha: float = 0.05, na_rm=False, verbose=False) -> torch.Tensor:
    """Get the mean and a normal-approximation (1 - alpha) CI for the mean along one dimension.

    The standard error uses the unbiased (n - 1) sample standard deviation in both branches. ``na_rm=True``
    on data without NaNs therefore gives the same result as ``na_rm=False``.

    Args:
        vec (Tensor): The values you want the mean and C.I. for the mean from.
        dim (int): The dimension to calculate the mean and CI over.
        alpha (float, optional): Significance level; the CI covers 1 - alpha. Defaults to 0.05.
        na_rm (bool, optional): If True, ignore NaNs and use the per-slice count of non-NaN values as the
            sample size. Defaults to False.
        verbose (bool, optional): Whether or not to print the estimate and CI. Defaults to False.

    Returns:
        torch.Tensor: ``torch.stack([mean, ci_low, ci_up])``, i.e. the three statistics along a new leading
            dimension of size 3.
    """
    # Two-sided standard-normal critical value, e.g. ~1.96 for alpha = 0.05.
    z = torch.distributions.Normal(0, 1).icdf(torch.tensor(1 - alpha / 2))
    if na_rm:
        n_samples = torch.sum(~torch.isnan(vec), dim=dim)
        mean = torch.nanmean(vec, dim=dim)
        squared_dev = (vec - torch.nanmean(vec, dim=dim, keepdim=True)) ** 2
        # Unbiased variance, matching torch.std in the na_rm=False branch.
        sample_sd = torch.sqrt(torch.nansum(squared_dev, dim=dim) / (n_samples - 1))
        se = sample_sd / n_samples ** 0.5
    else:
        n = vec.shape[dim]
        mean = torch.mean(vec, dim=dim)
        se = torch.std(vec, dim=dim) / n ** 0.5
    ci_up = mean + z * se
    ci_low = mean - z * se
    if verbose:
        # A single statistic (0-dim or one element) is printed as a formatted float, otherwise as tensors.
        if mean.numel() == 1:
            print(f"Our Estimated Expected Power is: {mean.item():4.3f}")
            print(f"With ci({ci_low.item():4.3f}, {ci_up.item():4.3f})")
        else:
            print(f"Our Estimated Expected Power is: {mean}")
            print(f"With ci({ci_low}, {ci_up})")
    return torch.stack([mean, ci_low, ci_up], dim=0)


def recursive_tensorize(list_of_lists) -> torch.Tensor:
    """Recursively convert nested lists/tuples to a tensor by stacking along dim 0 at each level.

    - Lists or tuples whose first element is a list, tuple or tensor are converted element-wise and stacked,
      so siblings must end up with equal shapes.
    - Lists or tuples of plain scalars go through ``torch.tensor``.
    - An empty list or tuple gives an empty tensor.
    - Anything that is not a list or tuple (e.g. a tensor leaf) is returned unchanged.

    Args:
        list_of_lists (list[list[torch.Tensor]]): A nested list (or tuple) of lists of tensors to convert.

    Returns:
        torch.Tensor: The converted tensor.
    """
    if not isinstance(list_of_lists, (list, tuple)):
        return list_of_lists
    if len(list_of_lists) == 0:
        return torch.tensor(list_of_lists)
    if isinstance(list_of_lists[0], (list, tuple, torch.Tensor)):
        # Tuples are handled at every level, so nested tuples are converted rather than passed to stack as-is.
        return torch.stack([recursive_tensorize(item) for item in list_of_lists])
    return torch.tensor(list_of_lists)


def list_of_dicts_transpose(list_of_dicts):
    """Turn a list of dicts into a dict of lists.

    Keys are taken from the first dict. Every dict must contain those keys (a missing one raises KeyError).

    Args:
        list_of_dicts (list[dict]): Records to transpose.

    Returns:
        dict: ``{key: [d[key] for d in list_of_dicts]}``, or ``{}`` for an empty list.
    """
    if not list_of_dicts:
        return {}
    keys = list_of_dicts[0].keys()
    return {key: [dic[key] for dic in list_of_dicts] for key in keys}


def infinite_loader(dataset, batch_size, seed=None):
    """Yield shuffled batches from ``dataset`` forever.

    This is a generator, so all the work below (including the checks) happens on the first ``next()``.

    Args:
        dataset: Map-style dataset (needs ``__len__``) passed to ``DataLoader``.
        batch_size: ``DataLoader`` batch size.
            - Incomplete final batches are dropped.
            - If it equals ``len(dataset)``, the single full batch is loaded once and yielded repeatedly.
            - If None, ``DataLoader`` yields individual samples.
        seed (int, optional): Seed for the shuffling generator. If None, one is drawn from torch's global
            RNG, so it is reproducible under ``torch.manual_seed``. The global RNG is otherwise left alone.

    Yields:
        Batches as produced by the ``DataLoader``.

    Raises:
        ValueError: If the loader would produce no batches (e.g. ``batch_size > len(dataset)``). Cycling over
            it would otherwise loop forever without yielding.
    """
    if seed is None:
        # Draw from the global RNG instead of re-seeding it (torch.seed() would replace the user's seed with
        # a non-deterministic one and break reproducibility of everything after this call).
        seed = torch.randint(0, 1000000, (1,)).item()
    generator = torch.Generator()
    generator.manual_seed(seed)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, generator=generator)
    if len(loader) == 0:
        raise ValueError(
            f"DataLoader yields no batches (batch_size={batch_size}, len(dataset)={len(dataset)})"
        )
    if batch_size == len(dataset):
        # Full-batch mode: load the one batch once and yield the same data over and over again.
        full_batch = next(iter(loader))
        while True:
            yield full_batch
    else:
        # Otherwise cycle through the loader; each pass reshuffles via the seeded generator.
        while True:
            yield from loader
