import numpy as np
import torch
from _reckernel import reckernel
from machina.utils import get_device

def sparsify(gp_net, epis, n_kq, use_thin=False, use_herd=False):
    """Select a weighted subset of episodes from their pairwise Gram matrix.

    Every episode is mapped to a feature tensor with ``gp_net.input_feature``.
    Entry ``(i, j)`` of the Gram matrix is the sum of all elements returned by
    ``gp_net.inner_product`` for episodes ``i`` and ``j`` (called with a
    single argument for the diagonal). The matrix is then compressed by
    ``matrix_recombination``.

    Parameters
    ----------
    gp_net : object providing ``input_feature``, ``inner_product``,
        ``discount_grad`` and ``gamma``
    epis : sequence of episode dicts (with ``'obs'`` and ``'acs'`` arrays)
    n_kq : target size forwarded to ``matrix_recombination``
    use_thin, use_herd : bool
        Forwarded unchanged to ``matrix_recombination``.

    Returns
    -------
    idx : indices of the selected episodes
    weights : torch.Tensor
        Per-step weights produced by ``weight_convert``.
    """
    def _summed_inner_product(*feats):
        # Collapse the inner-product tensor into a single Gram entry.
        return np.sum(gp_net.inner_product(*feats).detach().cpu().numpy())

    n_epis = len(epis)
    gram = np.ones((n_epis, n_epis))
    feats = [
        gp_net.input_feature(torch_episode(epi)).to(get_device())
        for epi in epis
    ]
    for row, feat_row in enumerate(feats):
        gram[row, row] = _summed_inner_product(feat_row)
        # Only the lower triangle is computed; it is mirrored to the upper one.
        for col in range(row):
            gram[row, col] = gram[col, row] = _summed_inner_product(
                feat_row, feats[col])

    idx, wei = matrix_recombination(
        gram, n_kq, use_thin=use_thin, use_herd=use_herd)
    weights = weight_convert(epis, idx, wei, gp_net.discount_grad, gp_net.gamma)
    return idx, weights

def gp_train(gp_net, optim_gp, traj, wei=None, cv_only=False, batch_size=32, epoch=10):
    """Run ``epoch`` gradient steps on the GP network.

    The samples of all episodes in ``traj.current_epis`` are concatenated,
    shuffled once and cut into ``epoch`` equally sized, disjoint chunks. Each
    step uses one chunk: the likelihood loss is evaluated on at most
    ``batch_size`` samples of the chunk, and, if ``gp_net.ctrl_var`` is set,
    the weighted squared-error loss of ``gp_monte_carlo`` on the whole chunk
    is added.

    Parameters
    ----------
    gp_net : object providing ``rew_gp``, ``ctrl_var``, ``gamma_array``,
        ``input_feature``, ``likelihood`` and ``mean_func``
    optim_gp : optimizer exposing ``zero_grad()`` and ``step()``
    traj : object whose ``current_epis`` is an iterable of episode dicts
    wei : torch.Tensor, optional
        Per-sample weights; multiplied element-wise with the matching
        ``gp_net.gamma_array`` entries. Only used by the control-variate loss.
    cv_only : bool
        If True the likelihood term is left out.
    batch_size : int
        Upper bound on the number of samples used for the likelihood term.
    epoch : int
        Number of gradient steps (and number of chunks).

    Returns
    -------
    None
    """
    # With the likelihood disabled and no control variate there is no loss
    # term at all, so there is nothing to back-propagate.
    if cv_only and not gp_net.ctrl_var:
        return

    # Targets are rewards or advantages depending on the network's mode.
    gp_obj = 'rews' if gp_net.rew_gp else 'advs'
    raw_inputs = []
    raw_obj = []
    wei_gamma = []
    for epi in traj.current_epis:
        feat = gp_net.input_feature(torch_episode(epi)).detach()
        wei_gamma.append(gp_net.gamma_array[:len(feat)].view(-1))
        raw_inputs.append(feat)
        raw_obj.append(torch.from_numpy(epi[gp_obj]).detach())
    x = torch.cat(raw_inputs).to(get_device())
    y = torch.cat(raw_obj).to(get_device())
    w = torch.cat(wei_gamma).to(get_device())
    wei = w if wei is None else wei * w

    idx_all = torch.randperm(len(x)).to(get_device())
    i_len = len(idx_all) // epoch
    if i_len == 0:
        # Fewer samples than steps: every chunk would be empty and the
        # weighted mean in the control-variate loss would be 0/0 (NaN),
        # which would corrupt the parameters on the optimizer step.
        return

    for i in range(epoch):
        idx_mf = idx_all[i * i_len:(i + 1) * i_len]
        idx_llh = idx_mf[:min(i_len, batch_size)]
        loss = 0 if cv_only else gp_net.likelihood(x[idx_llh], y[idx_llh])
        if gp_net.ctrl_var:
            loss = loss + gp_monte_carlo(
                gp_net, x[idx_mf], y[idx_mf], wei[idx_mf])
        optim_gp.zero_grad()
        loss.backward()
        optim_gp.step()


def torch_episode(epi):
    """Return an episode's observations and actions as device tensors.

    Parameters
    ----------
    epi : dict
        Must hold NumPy arrays under the keys ``'obs'`` and ``'acs'``.

    Returns
    -------
    tuple of torch.Tensor
        ``(obs, acs)``, both moved to the device given by ``get_device()``.
    """
    target = get_device()
    return tuple(
        torch.from_numpy(epi[key]).to(target) for key in ('obs', 'acs')
    )

def matrix_recombination(kernel_matrix, target_size, use_thin=False, use_herd=False):
    """Compress a Gram matrix into a small set of indices and weights.

    Parameters
    ----------
    kernel_matrix : np.ndarray
        Square Gram matrix.
    target_size : int
        Requested size, forwarded to the ``reckernel`` object.
    use_thin, use_herd : bool
        If either is set, the result of ``reckernel.subsample`` is returned
        unchanged.

    Returns
    -------
    idx, weights
        Without thinning/herding: the indices chosen by
        ``reckernel.recombination`` and weights obtained by solving
        ``K[idx, idx] w = row-mean(K[idx, :])`` and passing the solution
        through ``probify``. The weights proposed by ``recombination``
        itself are discarded.
    """
    n_points = kernel_matrix.shape[0]
    kernel_fn, diag_fn = set_kernel(kernel_matrix, use_thin=use_thin)
    compressor = reckernel(kernel=kernel_fn, kernel_diag=diag_fn, use_thin=use_thin)
    candidates = np.arange(n_points)

    if not (use_thin or use_herd):
        chosen, _ = compressor.recombination(
            candidates, np.arange(n_points), target_size)
        # Gram matrix restricted to the selected points.
        sub_gram = kernel_matrix[:, chosen][chosen]
        # Average of each selected row over all points, as a column vector.
        mean_embedding = (
            kernel_matrix[chosen, :] @ np.ones((n_points, 1)) / n_points
        )
        solved = np.linalg.solve(sub_gram, mean_embedding)
        return chosen, probify(solved)

    return compressor.subsample(candidates, target_size)

def set_kernel(kernel_matrix, use_thin=False):
    """Build kernel look-up callables backed by a precomputed Gram matrix.

    Parameters
    ----------
    kernel_matrix : np.ndarray
        Square Gram matrix that the callables index into.
    use_thin : bool
        Selects the flavour of ``k``:

        * True: for equally long ``x`` and ``y`` return the 1-D float array
          of ``kernel_matrix[x[i], y[i]]``; otherwise return
          ``kernel_matrix[x, y]`` (NumPy indexing rules apply).
        * False: return the sub-matrix with rows ``x`` and columns ``y``;
          scalar arguments are treated as one-element index arrays.

    Returns
    -------
    k : callable ``(x, y) -> np.ndarray``
    k_diag : callable ``x -> np.ndarray``
        Diagonal entries of ``kernel_matrix`` at positions ``x``.
    """
    def _as_index(value):
        # A bare scalar is wrapped so that indexing always yields an array.
        return np.array([value]) if np.isscalar(value) else value

    if use_thin:
        def k(x, y):
            if len(x) != len(y):
                return kernel_matrix[x, y]
            paired = np.empty(len(x))
            for pos, (row, col) in enumerate(zip(x, y)):
                paired[pos] = kernel_matrix[row, col].item()
            return paired
    else:
        def k(x, y):
            rows = _as_index(x)
            cols = _as_index(y)
            return kernel_matrix[:, cols][rows]

    def k_diag(x):
        return np.diag(kernel_matrix)[_as_index(x)]

    return k, k_diag

def probify(init_weights):
    """Turn a weight array into non-negative weights that sum to one.

    If the smallest weight is negative, all weights are shifted so that the
    smallest becomes zero; the result is then divided by its sum. When that
    sum is exactly zero (all weights equal after the shift, e.g. a single
    negative weight) the weights are set to the uniform distribution instead
    of dividing 0 by 0.

    Parameters
    ----------
    init_weights : np.ndarray
        Floating-point array of weights. It is modified in place.

    Returns
    -------
    np.ndarray
        The same array object, now normalized.
    """
    min_init = np.amin(init_weights)
    if min_init < 0:
        init_weights -= min_init
    total = np.sum(init_weights)
    if total == 0:
        # After the shift every entry is >= 0, so a zero sum means all
        # entries are zero: equal weights, i.e. the uniform distribution.
        init_weights[...] = 1.0 / init_weights.size
    else:
        init_weights /= total
    return init_weights

def weight_convert(epis, idx, wei, discount_grad=True, gamma=1.):
    """Expand per-episode weights into one weight per time step.

    Parameters
    ----------
    epis : sequence of episode dicts
        The length of ``epis[i]['acs']`` gives the number of steps of
        episode ``i``.
    idx : sequence of int
        ``idx[j]`` is the episode that weight ``wei[j]`` belongs to.
    wei : np.ndarray
        One weight per selected episode.
    discount_grad : bool
        If True, step ``t`` of every episode is additionally scaled by
        ``gamma ** t``.
    gamma : float
        Discount factor used when ``discount_grad`` is set.

    Returns
    -------
    torch.Tensor
        Concatenation of all per-step weights, in the order of ``idx``,
        moved to the device given by ``get_device()``.
    """
    episode_weights = torch.from_numpy(wei)
    per_step = []
    for pos, episode_weight in enumerate(episode_weights):
        n_steps = len(epis[idx[pos]]['acs'])
        # Broadcast the episode's weight over all of its steps.
        step_weights = episode_weight * torch.ones(n_steps)
        if discount_grad:
            step_weights *= (gamma ** torch.arange(n_steps))
        per_step.append(step_weights)
    return torch.cat(per_step).to(get_device())

def gp_monte_carlo(gp_net, sa_pair, rets, wei=None, clip_param=0.2, clip=False, old_mfs=None):
    """
    Weighted squared-error loss between the GP mean function and targets.

    Parameters
    ----------
    gp_net : object providing ``mean_func``
    sa_pair : torch.Tensor
        Inputs passed to ``gp_net.mean_func``; its output is flattened.
    rets : torch.Tensor
        Regression targets, one per flattened prediction.
    wei : torch.Tensor, optional
        Per-sample weights. Defaults to all ones.
    clip_param : float
        Maximum allowed deviation of the prediction from ``old_mfs`` when
        ``clip`` is True.
    clip : bool
        If True, use the larger of the plain and the clipped squared error
        for each sample. Requires ``old_mfs``.
    old_mfs : torch.Tensor, optional
        Reference predictions used for clipping.

    Returns
    -------
    torch.Tensor
        ``0.5 * sum(weight * error) / sum(weight)``.
    """
    preds = gp_net.mean_func(sa_pair).view(-1)
    weights = wei if wei is not None else torch.ones_like(rets)

    sq_err = (preds - rets) ** 2
    if clip:
        # Keep the prediction within clip_param of the reference values and
        # take the more pessimistic of the two errors.
        clipped_preds = old_mfs + torch.clamp(
            preds - old_mfs, -clip_param, clip_param)
        sq_err = torch.max(sq_err, (clipped_preds - rets) ** 2)

    return 0.5 * torch.sum(sq_err * weights) / torch.sum(weights)