import numpy as np
import pandas as pd

import torch

def torch_groupby(X, C, N):
    """Sum the entries of ``X`` grouped by the integer labels ``C``.

    Args:
        X: 1-D tensor of values to accumulate.
        C: 1-D int64 tensor of labels (``scatter_add_`` requires int64 indices),
            one per entry of ``X``, each in ``[0, N)``.
        N: number of groups, i.e. the length of the result.

    Returns:
        1-D tensor of length ``N`` with the same dtype and device as ``X``;
        entry ``i`` is the sum of ``X[C == i]`` (0 for labels that never occur).
    """
    # Allocate the accumulator in X's dtype: scatter_add_ rejects a src whose
    # dtype differs from the destination, so a hard-coded float32 buffer would
    # break float64 / half-precision inputs.
    result = torch.zeros(N, dtype=X.dtype, device=X.device)
    result.scatter_add_(0, C, X)
    return result

class CorrLenComputer:
    """Bin pairwise point distances and fit correlation-length parameters.

    The constructor computes the Euclidean distance matrix between the rows of
    ``xt`` and assigns every entry to one of ``n_bins + 1`` distance buckets:
    bucket 0 holds exactly-zero distances, and buckets ``1..n_bins`` are the
    uniform bins ``(edge[i-1], edge[i]]`` spanning ``(0, max distance]``.
    Matrices with the same number of entries as the distance matrix can then be
    summarised per distance bucket.
    """

    def __init__(self, xt, n_bins=50):
        """
        Args:
            xt: 2-D tensor of points, one point per row (shape ``(n, d)``).
            n_bins: number of uniform distance bins.
        """
        self.n_bins = n_bins
        # (n, n) Euclidean distances on xt's device, detached from autograd.
        # cdist avoids materialising the (n, d, n) difference tensor; the
        # "donot_use_mm" mode evaluates sqrt(sum((a - b) ** 2)) directly rather
        # than through the less exact matrix-multiplication shortcut.
        points = xt.detach()
        self.distances = torch.cdist(
            points, points, compute_mode='donot_use_mm_for_euclid_dist'
        )
        # Edges share the distances' dtype so the largest distance lands in
        # bucket n_bins instead of rounding past the last edge.
        self.dist_bins = torch.linspace(
            0,
            self.distances.max().item(),
            n_bins + 1,
            dtype=self.distances.dtype,
            device=xt.device,
        )
        # bucketize(right=False): index i satisfies edge[i-1] < d <= edge[i],
        # so indices range over 0..n_bins (n_bins + 1 buckets in total).
        self.dist_idx = torch.bucketize(self.distances.flatten(), self.dist_bins)

        _ones = torch.ones(len(self.dist_idx), device=self.dist_bins.device)
        self.bin_counts = torch_groupby(_ones, self.dist_idx, self.n_bins + 1)

    def _groupby_dist(self, K):
        """Return the per-distance-bucket mean and variance of the entries of ``K``.

        ``K`` is expected to have as many entries as the distance matrix
        (presumably an ``(n, n)`` kernel such as an NTK -- confirm with callers).
        Empty buckets yield NaN for both statistics (0 / 0).
        """
        flat_K = K.flatten()
        sum_K = torch_groupby(flat_K, self.dist_idx, self.n_bins + 1)
        sum_K2 = torch_groupby(flat_K ** 2, self.dist_idx, self.n_bins + 1)

        mean_K = sum_K / self.bin_counts
        # E[K^2] - E[K]^2 can dip just below zero through floating-point
        # cancellation; clamp so a later sqrt() stays real. NaN passes through.
        var_K = (sum_K2 / self.bin_counts - mean_K ** 2).clamp_min(0)

        return mean_K, var_K

    def fit_fort_params(self, ntk):
        """Least-squares fit ``ntk ~ alpha + beta * sqrt(distance)`` over all entries.

        Returns:
            dict with the slope ``beta``, intercept ``alpha`` and the zero
            crossing ``-alpha / beta`` (reported as xi_FORT), as Python floats.

        NOTE(review): the regressor is ``sqrt(distance)``, so xi_FORT is in
        sqrt-distance units; confirm this is intended.
        """
        _y = ntk.flatten()
        _ymean = _y.mean()

        _x = torch.sqrt(self.distances.flatten())
        _xmean = _x.mean()

        _dy = _y - _ymean
        _dx = _x - _xmean

        # Ordinary least squares: beta = cov(x, y) / var(x).
        _beta = ((_dy * _dx).sum() / (_dx * _dx).sum())
        _alpha = _ymean - _beta * _xmean

        return {
            r'$\xi_{FORT}$' : -(_alpha / _beta).item(),
            r'$\alpha_{FORT}$' : _alpha.item(),
            r'$\beta_{FORT}$' : _beta.item()
        }

    def fit_exp_params(self, x, y):
        """Estimate the asymptote and correlation length of a decaying curve.

        ``y`` is rescaled to ``(y - C_inf) / (y[0] - C_inf)``, where ``C_inf``
        is the NaN-ignoring mean of the last 10 samples. NaN gaps are filled by
        linear interpolation, and the first position where the rescaled curve
        crosses 0.5 is found by linear interpolation between samples. That
        half-width at half maximum is divided by ``sqrt(2 ln 2)``, which gives
        ``sigma`` for a curve of the form ``exp(-x**2 / (2 * sigma**2))``.

        Args:
            x: 1-D tensor of sample positions (e.g. distance-bin edges).
            y: 1-D tensor of values at ``x``, same length.

        Returns:
            dict with ``C_inf`` (float) and ``xi_corr`` (float; NaN when the
            rescaled curve never crosses 0.5, e.g. when ``C_inf`` is NaN).
        """
        y_np = y.detach().cpu().numpy()
        # NaN-aware so empty trailing bins do not poison the asymptote.
        c_inf = float(np.nanmean(y_np[-10:]))
        scaled_mean_K = (y_np - c_inf) / (float(y_np[0]) - c_inf)
        k_cutoff = 0.5

        # Fill NaN gaps (e.g. empty distance bins) by linear interpolation.
        df = pd.DataFrame({'time': x.detach().cpu().numpy(), 'values': scaled_mean_K})
        df.set_index('time', inplace=True)
        _X = df.interpolate(method='linear')['values'].to_numpy()
        _T = df.index.values

        # Indices i where (_X > cutoff) flips between samples i and i + 1.
        crossings = np.flatnonzero(np.diff(_X > k_cutoff))
        if crossings.size == 0:
            xi = float('nan')
        else:
            i = crossings[0]
            # Linear interpolation of the first crossing between T[i] and T[i+1].
            half_width = _T[i] + (k_cutoff - _X[i]) * (_T[i + 1] - _T[i]) / (_X[i + 1] - _X[i])
            # The curve peaks at x = 0, so the crossing is already a half-width;
            # no extra factor of 2 is needed to obtain sigma.
            xi = float(half_width / np.sqrt(2 * np.log(2)))
        return {
            r'$C_{\infty}$' : c_inf,
            r'$\xi_{corr}$' : xi
        }
    

def sqrt_trace_covar_loss_grad(r, ntk):
    """Return ``sum_i r[i]**2 * ntk[i, i]`` divided by ``ntk.shape[0]``, as a float.

    Args:
        r: 1-D tensor (``torch.dot`` requires 1-D); presumably a residual
            vector -- confirm with callers.
        ntk: 2-D tensor whose diagonal has the same length as ``r``.

    NOTE(review): despite the ``sqrt_`` prefix, no square root is applied here;
    confirm whether callers take it themselves.
    """
    num_rows = ntk.shape[0]
    weighted_trace = torch.dot(r ** 2, torch.diag(ntk))
    return weighted_trace.item() / num_rows

def norm_mean_loss_grad(r, ntk):
    """Return ``sqrt(r @ ntk @ r) / ntk.shape[0]`` as a Python float.

    Args:
        r: 1-D tensor (presumably a residual vector -- confirm with callers).
        ntk: 2-D tensor compatible with ``r`` on both sides of the product.

    If the quadratic form is negative (``ntk`` not positive semi-definite),
    the square root, and hence the result, is NaN.
    """
    quadratic_form = torch.matmul(torch.matmul(r, ntk), r)
    return quadratic_form.sqrt().item() / ntk.shape[0]