from typing import Callable, Tuple, Iterable, List

import json
import torch
import numpy as np
from .digest_methods import model_weight_norm


def pearsonr(x, y):
    """
    Pearson correlation coefficient of two 1D tensors, modelled on
    `scipy.stats.pearsonr` (only the coefficient is computed, no p-value).

    Arguments
    ---------
    x : 1D torch.Tensor
    y : 1D torch.Tensor
        must have the same number of elements as x (required by the dot product)

    Returns
    -------
    r_val : torch.Tensor
        zero-dimensional tensor holding the correlation coefficient of x and y;
        wrap it in `float(...)` if a Python number is needed

    Scipy docs ref:
        https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.pearsonr.html

    Scipy code ref:
        https://github.com/scipy/scipy/blob/v0.19.0/scipy/stats/stats.py#L2975-L3033

    Example:
        >> x = np.random.randn(100)
        >> y = np.random.randn(100)
        >> sp_corr = scipy.stats.pearsonr(x, y)[0]
        >> th_corr = pearsonr(torch.from_numpy(x), torch.from_numpy(y))
        >> np.allclose(sp_corr, th_corr)
    """
    # Centre both vectors on their means.
    x_centered = x - torch.mean(x)
    y_centered = y - torch.mean(y)

    # Un-normalised covariance over the product of the Euclidean norms.
    covariance = torch.dot(x_centered, y_centered)
    scale = torch.norm(x_centered, 2) * torch.norm(y_centered, 2)
    return covariance / scale


def corrcoef(x):
    """
    Mimics `np.corrcoef`: correlation coefficients between the rows of `x`.

    Arguments
    ---------
    x : 2D torch.Tensor
        each row is one variable, each column one observation

    Returns
    -------
    c : torch.Tensor
        if x.size() = (5, 100), then return val will be of size (5,5)

    Numpy docs ref:
        https://docs.scipy.org/doc/numpy/reference/generated/numpy.corrcoef.html
    Numpy code ref:
        https://github.com/numpy/numpy/blob/v1.12.0/numpy/lib/function_base.py#L2933-L3013

    Example:
        >> x = np.random.randn(5,120)
        # result is a (5,5) matrix of correlations between rows
        >> np_corr = np.corrcoef(x)
        >> th_corr = corrcoef(torch.from_numpy(x))
        >> np.allclose(np_corr, th_corr.numpy())
        # [out]: True
    """
    # calculate covariance matrix of rows
    # keepdim=True keeps the mean as a column of shape (rows, 1) so it is
    # subtracted per row; without it the mean cannot be broadcast against a
    # non-square x and is subtracted along the wrong axis for a square one.
    mean_x = torch.mean(x, 1, keepdim=True)
    xm = x - mean_x
    c = xm.mm(xm.t())
    c = c / (x.size(1) - 1)

    # normalize covariance matrix: divide entry (i, j) by stddev[i] * stddev[j]
    stddev = torch.sqrt(torch.diag(c))
    c = c / stddev.unsqueeze(0)
    c = c / stddev.unsqueeze(1)

    # clamp between -1 and 1
    # probably not necessary but numpy does it
    return torch.clamp(c, -1.0, 1.0)


def get_whole_dataset(dataloader, device):
    """
    Load the dataset behind `dataloader` as a single batch on `device`.

    A fresh, unshuffled DataLoader is built over `dataloader.dataset` with the
    batch size set to the dataset length; the batching settings of the
    `dataloader` that was passed in are not used.

    Arguments
    ---------
    dataloader : object exposing a `.dataset` attribute whose items unpack
        into a (data, target) pair
    device : device passed to `Tensor.to`

    Returns
    -------
    list
        `[data, target]` taken from the last batch the loader produced
    """
    samples = dataloader.dataset
    full_loader = torch.utils.data.DataLoader(samples, batch_size=len(samples), shuffle=False)
    print(f"Length of dataset is {len(samples)}")
    for data, target in full_loader:
        data = data.to(device)
        target = target.to(device)
    return [data, target]


def optimal_path_information(model: torch.nn.Module, loss, dataloader, device, batch_selection_frequency,
                             **kwargs):
    """Collect weight-norm (and optionally tangent-kernel) statistics for one epoch.

    Args:
        model: model to inspect
        loss: not used by this function
        dataloader: data loader forwarded to `model_weight_norm`; its dataset is
            also loaded in full when tangent-kernel statistics are requested
        device: device forwarded to `model_weight_norm` and used for the data
        batch_selection_frequency: not used by this function
        **kwargs: must contain "epoch" and "variable_name" (a KeyError is raised
            otherwise); the optional "Save info" flag (default False) adds the
            output of `ntk_and_nabla_y` to the statistics

    Returns:
        An empty dict when the epoch is None; otherwise a dict that maps
        "tests>{variable_name}>epoch {epoch}>{statistic}" to a float.

    """
    epoch = kwargs["epoch"]
    task_name = kwargs["variable_name"]
    save_tk_info = kwargs.get("Save info", False)
    if epoch is None:
        return {}

    # NOTE(review): the meaning of the positional 800 is defined by
    # model_weight_norm and is not visible from this file.
    stats = model_weight_norm(model, None, dataloader, device, 800, save_weights=False)
    if save_tk_info:
        inputs = get_whole_dataset(dataloader, device)[0]
        stats.update(ntk_and_nabla_y(model, inputs))

    return {
        "tests>" + task_name + ">epoch " + str(epoch) + ">" + name: float(value)
        for name, value in stats.items()
    }


def ntk_and_nabla_y(model: torch.nn.Module, data: torch.Tensor) -> dict:
    """
    Compute summary statistics of the tangent kernel and of the gradient of the
    model output w.r.t. the parameters.

    One gradient row is obtained per data point from `tk_loop_op`; the rows are
    stacked into a matrix `y_grad` and the kernel is `y_grad @ y_grad.T`.

    Args:
        model: model whose parameters are differentiated
        data: input tensor, indexed along its first dimension (one entry per
            data point)

    Returns:
        dict of floats with keys "tk_norm" (torch.norm of the kernel),
        "tk_trace" (trace of the kernel) and "y_grad_norm" (torch.norm of the
        stacked gradient matrix)
    """
    model.zero_grad()
    D = len(data)
    n = sum(p.numel() for p in model.parameters())

    # Plain buffer on purpose: rows are written in place, which autograd
    # rejects for a leaf tensor created with requires_grad=True while grad mode
    # is enabled. Only float summaries leave this function, so no graph is needed.
    y_grad = torch.zeros(D, n)

    # One gradient row per data point; the prediction norm that tk_loop_op
    # also returns is not needed here.
    for i in range(D):
        y_grad[i], _ = tk_loop_op(model, data, i, n)

    # Compute NTK
    tk = torch.mm(y_grad, y_grad.T)
    return {"tk_norm": float(torch.norm(tk)), "tk_trace": float(torch.trace(tk)),
            "y_grad_norm": float(torch.norm(y_grad))}


def tk_loop_op(model: torch.nn.Module, data: torch.Tensor, i: int, n: int) \
        -> Tuple[torch.Tensor, float]:
    """
    Inner loop of tangent kernel computation. Computes the gradient of the model prediction w.r.t. parameters
    for a single data point.

    Gradients are enabled locally, so this also works when the caller runs
    under `torch.no_grad()`.

    :param model: model whose parameters are differentiated; its output for one data point must hold a
        single element, since no `grad_outputs` is supplied to `torch.autograd.grad`
    :param data: input tensor, indexed along its first dimension
    :param i: index of the data point to evaluate
    :param n: total number of model parameters (length of the flattened gradient)
    :return: tuple of the detached flattened gradient as a CPU tensor of shape (1, n) and the norm of the
        prediction as a float
    """
    with torch.enable_grad():
        model.zero_grad()
        # Keep a leading batch dimension of size 1 for the selected data point.
        data_i = data[i].unsqueeze(0)
        pred_target_i = model(data_i)
        # Only first-order gradients are needed and the result is detached
        # right away, so no double-backward graph is built or retained.
        pred_grad_tuple = torch.autograd.grad(pred_target_i, model.parameters())
        pred_grad_vec = torch.cat([g.contiguous().view(-1) for g in pred_grad_tuple], dim=0).detach()

    jacobi = torch.zeros((1, n))
    jacobi[0] = pred_grad_vec
    return jacobi.cpu(), float(torch.norm(pred_target_i.detach()))


def dydtheta_loop_op(model: torch.nn.Module, data: torch.Tensor, z_dot: torch.Tensor, i: int) \
        -> Tuple[torch.Tensor, torch.Tensor, float]:
    """
    Inner loop to calculate dy/dtheta and d(dy/dtheta)/dt for a single datapoint.

    The time derivative is obtained by differentiating the flattened dy/dtheta
    once more w.r.t. the parameters, weighted by `z_dot`. Gradients are enabled
    locally, so this also works when the caller runs under `torch.no_grad()`.

    Args:
        model: model; its output for one datapoint must hold a single element
        data: dataset, indexed along its first dimension
        z_dot: model parameters change w.r.t time, one entry per flattened
            parameter
        i: index of the datapoint

    Returns:
        Tuple of
        - d(dy/dtheta)/dt as a detached, flattened CPU tensor,
        - dy/dtheta as a flattened CPU tensor (not detached),
        - norm of the model prediction as a float.

    """
    with torch.enable_grad():
        model.zero_grad()
        # Materialised once: the parameters are differentiated twice and
        # paired with the second-order result below.
        params = list(model.parameters())
        data_i = data[i].unsqueeze(0)
        pred_target_i = model(data_i)
        pred_grad_tuple = torch.autograd.grad(pred_target_i, params, retain_graph=True, create_graph=True)
        pred_grad_vec = torch.cat([g.contiguous().view(-1) for g in pred_grad_tuple], dim=0)

        if pred_grad_vec.requires_grad:
            # A parameter that does not influence dy/dtheta (e.g. an additive
            # output bias) is reported as None; its second-order term is zero.
            dy_grad_pred_tuple = torch.autograd.grad(pred_grad_vec, params, grad_outputs=z_dot,
                                                     retain_graph=True, allow_unused=True)
            dy_grad_pred_vec = torch.cat(
                [torch.zeros_like(p).reshape(-1) if g is None else g.contiguous().view(-1)
                 for g, p in zip(dy_grad_pred_tuple, params)], dim=0).detach()
        else:
            # dy/dtheta does not depend on any parameter, so its derivative
            # w.r.t. the parameters vanishes.
            dy_grad_pred_vec = torch.zeros_like(pred_grad_vec)

    return dy_grad_pred_vec.cpu(), pred_grad_vec.cpu(), float(torch.norm(pred_target_i.detach()))


def dy_grad_dt(model: torch.nn.Module, data: torch.Tensor, dydtheta: torch.Tensor,
               target: torch.Tensor) -> torch.Tensor:
    """
    Calculate and return d(dy/dtheta)/dt
    Args:
        model: model
        data: dataset, one entry per row of `dydtheta`
        dydtheta: current calculated dy/dtheta, of shape
            (number of datapoints, number of parameters)
        target: target labels; reshaped to a column before being subtracted
            from the model predictions

    Returns: d(dy/dtheta)/dt matrix with the same shape as `dydtheta`; the
        returned tensor has `requires_grad` set

    """
    D = dydtheta.size(0)
    n = dydtheta.size(1)

    # Parameter velocity passed on to the per-datapoint computation:
    # dydtheta^T @ (prediction - target), flattened to one entry per parameter.
    pred_target = model(data)
    error = pred_target - target.view(len(target), 1)
    dzdt = torch.flatten(torch.mm(dydtheta.T, error))

    # Plain buffer while filling: in-place row writes into a leaf tensor that
    # already requires grad are rejected by autograd when grad mode is enabled.
    dot_y_grad = torch.zeros(D, n)

    # Loop over data points in batch; only the first element returned by
    # dydtheta_loop_op (the time derivative) is kept.
    for i in range(D):
        dot_y_grad[i] = dydtheta_loop_op(model, data, dzdt, i)[0]

    # Flag set after filling so callers still receive a grad-enabled leaf tensor.
    return dot_y_grad.requires_grad_(True)
