import os
from sched import scheduler

import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.ticker import MaxNLocator
from torch.distributions import MixtureSameFamily, Normal, Categorical

# from data_config import save_path, mus, pis, d, sigmas, probability_flow
from flow import Flow, ScoreMatch, RectifiedFlow
from mog_util import schedule
from visualize import draw_grid, draw_line, draw_conditional_trajectory, draw_forward, draw_hist, line, draw_grid_line, \
    draw_x_line, draw_t_line, hist_3d, density_evolution_3d, draw_grid_with_plot


def draw_plot(rectified_flow, z0, z1, N=None):
    """Sample from ``rectified_flow`` and produce diagnostic plots and metrics.

    If ``z0.shape[1] >= 2``:
      * sample an ``N``-step trajectory from ``z0`` with ``sample_ode``;
      * sample a second 1000-step trajectory with
        ``config['probability_flow']`` temporarily set to False;
      * scatter the first 50000 points of ``z1`` against the final samples and
        save ``distribution.png``;
      * run ``num_neighbour`` and ``ANSD`` on the final samples;
      * for every trajectory step, draw a sample from
        ``rectified_flow.sde.marginal_prob`` at that time and compare it with
        the generated step via ``estimate_marginal_accuracy`` on the first two
        coordinates. Histograms are saved at selected steps and a line plot of
        all accuracies at the end.

    Otherwise (1-D data) sample a fresh batch of 500000 Gaussian starting
    points, then run the neighbour, intermediate-marginal and
    error-covariance diagnostics, and print the final marginal accuracy.

    Args:
        rectified_flow: a ``Flow`` (asserted) exposing ``config`` (with
            ``'save_path'``), ``name``, ``sample_ode`` and ``sde.marginal_prob``.
        z0: source samples; ``z0.shape[1]`` selects the branch.
        z1: target data samples.
        N: number of sampling steps forwarded to ``sample_ode``.
    """
    save_path = rectified_flow.config['save_path']
    assert isinstance(rectified_flow, Flow)
    if z0.shape[1] >= 2:
        traj = rectified_flow.sample_ode(z0=z0[:], N=N, batch_size=None)

        # Temporarily disable the probability-flow flag for a second sampling
        # pass and restore the caller's value afterwards. The old code always
        # reset it to True, clobbering a caller-set False, and left it False
        # if sampling raised.
        previous_flag = rectified_flow.config.get('probability_flow', True)
        rectified_flow.config['probability_flow'] = False
        try:
            # NOTE(review): traj_sde is only used by commented-out plotting
            # code; this 1000-step pass looks removable.
            traj_sde = rectified_flow.sample_ode(z0=z0[:], N=1000, batch_size=None)
        finally:
            rectified_flow.config['probability_flow'] = previous_flag

        final = traj[-1]
        M = 3
        plt.figure(figsize=(4, 4))
        plt.xlim(-M, M)
        plt.ylim(-M, M)
        plt.scatter(z1[:50000, 0].cpu().numpy(), z1[:50000, 1].cpu().numpy(), label=r'True Data', alpha=0.1, s=1)
        plt.scatter(final[:50000, 0].cpu().numpy(), final[:50000, 1].cpu().numpy(), label='ODE Sampled Data', alpha=0.1, s=1)
        plt.xlabel(r'$x$', fontsize=12)
        plt.ylabel(r'$y$', fontsize=12)
        plt.xticks(fontsize=12)
        plt.yticks(fontsize=12)
        plt.tight_layout()
        plt.savefig(f'{save_path}/{rectified_flow.name}/distribution.png', dpi=300)
        plt.show()
        plt.close()

        num_neighbour(z1, final)

        ANSD(z1[:50000, ], final[:50000, ], f'{save_path}/{rectified_flow.name}/ANSD.png')

        # Step indices (into traj[1:]) at which a histogram is saved.
        ts = [199, 399, 599, 799, 999] if N == 1000 else [19, 39, 59, 79, 99]
        # Time of traj[k] is k / (len(traj) - 1); computed once instead of per step.
        time_grid = torch.linspace(0, 1, len(traj)).to(z1.device)
        mas = []
        for i in range(len(traj) - 1):
            t = time_grid[i + 1]
            mean, std = rectified_flow.sde.marginal_prob(z1, t.repeat(z1.shape[0], 1))
            GT = mean + std * torch.randn_like(mean)
            ma = estimate_marginal_accuracy(GT[:, :2], traj[i + 1][:, :2], min=-3, max=3, num_bins=100)
            if i in ts:
                draw_hist([traj[i + 1][:, 0], GT[:, 0]], ['Generated $p_t$', 'Ground Truth $p_t$'],
                          f'{save_path}/{rectified_flow.name}/accumulated_error/{t.cpu().numpy():.3f}.png',
                          f'Marginal accuracy at {t:.2f}: {ma:.3f}')
            mas.append(ma)
        # Keep the accuracies on the data's device instead of a hard-coded 'cuda'.
        mas = torch.tensor(mas).to(z1.device)
        line(mas, 'accumulated error', f'{save_path}/{rectified_flow.name}/accumulated_error')

    else:
        test_z0 = torch.randn(500000, z0.shape[1]).to(z0.device)
        traj = rectified_flow.sample_ode(z0=test_z0, N=N, use_gt=False)
        # NOTE(review): this ground-truth trajectory is never used below;
        # kept in case sample_ode has side effects the caller relies on.
        _gt_traj = rectified_flow.sample_ode(z0=test_z0, N=N, use_gt=True)
        num_neighbour(z1, traj[-1], distance_threshold=1e-4)
        draw_intermediate_marginal_distribution(rectified_flow, traj, None, z1, hist=True)

        error_covariance_path(rectified_flow, traj)

        marginal_accuracy = estimate_marginal_accuracy(z1, traj[-1], min=-3, max=3, num_bins=100)
        print(f'marginal accuracy: {marginal_accuracy:.6f}')



def error_covariance(rectified_flow):
    """Print (and return) covariances between the velocity error at t=0 and at later t.

    The velocity error at time t is
    ``predicted_velocity(x, t) - velocity(x, t)`` evaluated on the fixed points
    ``x = rectified_flow.config['samples_0']``. For each t in
    ``linspace(0, 0.6, 100)`` this computes
    ``mean(centred(err_0) * centred(err_t))``, where centring subtracts the
    global mean over all elements.

    All tensors are created on ``x.device`` rather than a hard-coded 'cuda'.
    The t=0 probe is float, matching the dtype of the later time values.

    Returns:
        list of 100 Python floats (also printed).
    """
    x = rectified_flow.config['samples_0']

    def velocity_error(t):
        # t is a 0-d tensor; repeat(len(x), 1) gives a (len(x), 1) time column.
        t_col = t.repeat(len(x), 1)
        return rectified_flow.predicted_velocity(x, t_col) - rectified_flow.velocity(x, t_col)

    error_0 = velocity_error(torch.zeros((), device=x.device))
    error_0 = error_0 - torch.mean(error_0)
    covariances = []
    for t in torch.linspace(0, 0.6, 100, device=x.device):
        error_t = velocity_error(t)
        covariance = torch.mean(error_0 * (error_t - torch.mean(error_t)))
        covariances.append(covariance.detach().cpu().item())
    print(f'{covariances},')
    return covariances

def error_covariance_path(rectified_flow, traj):
    """Print (and return) the mean covariance of the velocity error along a trajectory.

    The velocity error at time t is
    ``predicted_velocity(x, t) - velocity(x, t)``. The reference error is
    taken at ``traj[0]`` with t=0 and centred by its global mean. For
    ``i = 0..39`` the error at ``traj[i + 1]`` with
    ``t_i = linspace(0, 0.4, 40)[i]`` is also centred by its *own* mean (as in
    ``error_covariance``), and ``mean(centred_0 * centred_i)`` is recorded.
    The previous code subtracted the reference mean here instead.

    Args:
        rectified_flow: object exposing ``predicted_velocity`` and ``velocity``.
        traj: sequence of at least 41 sample tensors of equal batch size.

    Returns:
        numpy float: the mean of the 40 covariances (also printed).
    """
    x0 = traj[0]

    def velocity_error(x, t):
        # t is a 0-d tensor; repeat(len(x), 1) gives a (len(x), 1) time column.
        t_col = t.repeat(len(x), 1)
        return rectified_flow.predicted_velocity(x, t_col) - rectified_flow.velocity(x, t_col)

    error_0 = velocity_error(x0, torch.zeros((), device=x0.device))
    error_0 = error_0 - torch.mean(error_0)
    covariances = []
    # NOTE(review): step i pairs traj[i + 1] with t = 0.4 * i / 39, which only
    # matches the trajectory's own time grid for specific step counts — confirm.
    for i, t in enumerate(torch.linspace(0, 0.4, 40, device=x0.device)):
        error_t = velocity_error(traj[i + 1], t)
        covariance = torch.mean(error_0 * (error_t - torch.mean(error_t)))
        covariances.append(covariance.detach().cpu().item())
    mean_covariance = np.mean(covariances)
    print(f'{mean_covariance},')
    return mean_covariance


def compute_tv_distance(p, q):
    """Total-variation distance between two discrete distributions.

    ``TV(p, q) = 1/2 * sum_i |p_i - q_i|``. ``p`` and ``q`` are tensors of
    bin probabilities with matching shapes; a 0-d tensor is returned.
    """
    abs_gap = (p - q).abs()
    return abs_gap.sum() * 0.5

#TODO: using analytical histogram for ground truth.
def estimate_marginal_accuracy(samples_mu, samples_pi, min=-3, max=3, num_bins=100):
    """Return ``1 - mean_d TV(hist(mu_d), hist(pi_d))`` over the sample dimensions.

    Each dimension of both sample sets is histogrammed with ``torch.histc``
    using ``num_bins`` equal-width bins on ``[min, max]``. Values outside the
    range are ignored. Each histogram is normalised to sum to 1, and the two
    are compared with ``compute_tv_distance``. A result of 1.0 means the
    binned marginals agree exactly.

    Args:
        samples_mu: reference samples of shape ``(n, d)``; a 1-D tensor is
            treated as ``(n, 1)``.
        samples_pi: samples to compare, ``(m, d)`` or 1-D.
        min, max: histogram range. The names are kept for backward
            compatibility, although they shadow the builtins.
        num_bins: number of bins per dimension.

    Returns:
        0-d tensor holding the marginal accuracy.

    Note:
        If no sample of one set lies in ``[min, max]`` for some dimension, that
        histogram sums to 0 and the result is NaN.
    """
    # Accept 1-D sample vectors as a single dimension.
    if samples_mu.dim() == 1:
        samples_mu = samples_mu.unsqueeze(-1)
    if samples_pi.dim() == 1:
        samples_pi = samples_pi.unsqueeze(-1)
    # torch.histc only supports floating-point inputs.
    if not samples_mu.is_floating_point():
        samples_mu = samples_mu.float()
    if not samples_pi.is_floating_point():
        samples_pi = samples_pi.float()

    tv_distances = []
    for i in range(samples_mu.shape[1]):
        mu_hist = torch.histc(samples_mu[:, i], bins=num_bins, min=min, max=max)
        pi_hist = torch.histc(samples_pi[:, i], bins=num_bins, min=min, max=max)
        mu_hist = mu_hist / mu_hist.sum()
        pi_hist = pi_hist / pi_hist.sum()
        tv_distances.append(compute_tv_distance(mu_hist, pi_hist))

    return 1 - torch.mean(torch.stack(tv_distances))

def draw_intermediate_marginal_distribution(rectified_flow, traj, gt_traj, z1, hist=False):
    """Compare each generated trajectory step with a sample of the true marginal.

    For every step ``i`` (``traj[i + 1]`` at time ``(i + 1) / (len(traj) - 1)``)
    a reference sample ``GT`` is drawn from
    ``rectified_flow.sde.marginal_prob(z1, t)``. The two are compared with
    ``estimate_marginal_accuracy`` on ``[-3, 3]`` with 100 bins.

    When ``hist`` is True, first-coordinate histograms are saved at selected
    steps. These are fixed lists for 101/1001-point trajectories and five
    evenly spaced steps otherwise. ``ANSD`` is also run at the last step.
    The previous code raised UnboundLocalError for other trajectory lengths.
    All accuracies are plotted with ``line`` at the end.

    Args:
        rectified_flow: exposes ``config['save_path']``, ``name`` and
            ``sde.marginal_prob``.
        traj: sequence of sample tensors, one per time step.
        gt_traj: unused; kept for call compatibility.
        z1: data samples used to build the reference marginals.
        hist: whether to save histogram figures.
    """
    save_path = rectified_flow.config['save_path']
    n_steps = len(traj) - 1
    if len(traj) == 101:
        ts = [0, 19, 39, 59, 79, 99]
    elif len(traj) == 1001:
        ts = [199, 399, 599, 799, 999]
    else:
        # Fallback: five evenly spaced steps ending at the final one.
        ts = [n_steps * k // 5 - 1 for k in range(1, 6)]
    # Built once on the data's device instead of a hard-coded 'cuda' per step.
    time_grid = torch.linspace(0, 1, len(traj)).to(z1.device)
    mas = []
    for i in range(n_steps):
        t = time_grid[i + 1]
        mean, std = rectified_flow.sde.marginal_prob(z1, t.repeat(z1.shape[0], 1))
        GT = mean + std * torch.randn_like(mean)
        ma = estimate_marginal_accuracy(GT, traj[i + 1], min=-3, max=3, num_bins=100)
        if hist and i in ts:
            plt.figure(figsize=(6.4, 6.4))
            plt.hist(traj[i + 1][:, :1].detach().cpu().numpy(), bins=200, density=True, alpha=0.3, range=(-3, 3), label='Sampled Data')
            plt.hist(GT[:, :1].detach().cpu().numpy(), bins=200, density=True, alpha=0.3, range=(-3, 3), label='True Data')
            plt.legend(fontsize=20)
            plt.xticks(fontsize=25)
            plt.xlabel(r'$x_t$', fontsize=25)
            plt.ylabel("Density", fontsize=25)
            plt.yticks(fontsize=25)
            plt.tight_layout()
            plt.savefig(f'{save_path}/{rectified_flow.name}/GTvsGenerated_distribution_{t:.2f}.png', dpi=300)
            plt.show()
            plt.close()
            # Neighbour statistics at the final step (99 / 999 for the fixed grids).
            if i == n_steps - 1:
                ANSD(GT[:, :1], traj[i + 1][:50000][:, :1], f'{save_path}/{rectified_flow.name}/ANSD.png')
        mas.append(ma)

    mas = torch.tensor(mas).to(z1.device)
    line(mas, 'accumulated error', f'{save_path}/{rectified_flow.name}/accumulated_error')

def ode_samples_error(rectified_flow: Flow, N):
    """Plot |learned ODE endpoint - ground-truth ODE endpoint| against the start point.

    1000 start points are taken evenly on [-1.5, 1.5]. Both the learned
    (``use_gt`` default) and ground-truth (``use_gt=True``) ODEs are integrated
    for ``N`` steps. The figure is saved to ``{save_path}/{name}/ODE samples error``.
    """
    starts = torch.linspace(-1.5, 1.5, 1000)[:, None].to('cuda')
    learned_end = rectified_flow.sample_ode(z0=starts, N=N)[-1]
    reference_end = rectified_flow.sample_ode(z0=starts, N=N, use_gt=True)[-1]
    abs_error = torch.abs(learned_end - reference_end)

    plt.figure()
    plt.ylim(0, 2)
    plt.xlabel('z_0')
    plt.ylabel('ODE error')
    plt.plot(starts.to('cpu').numpy(), abs_error.to('cpu').numpy())
    out_root = rectified_flow.config['save_path']
    plt.savefig(f'{out_root}/{rectified_flow.name}/ODE samples error')
    plt.show()
    plt.close()

def ode_sensitivity(rectified_flow: Flow, N):
    """Plot a central finite-difference estimate of dx_1/dx_0 for the ground-truth ODE.

    The ground-truth ODE (``sample_ode(..., use_gt=True)``, ``N`` steps) is
    integrated from ``x_0 - eps`` and ``x_0 + eps`` for 10000 points ``x_0``
    on [-1.5, 1.5]. The estimate
    ``(x_1(x_0 + eps) - x_1(x_0 - eps)) / (2 * eps)`` is plotted against
    ``x_0`` and saved to ``{save_path}/{name}/sensitivity``.
    """
    eps = 1e-3
    x0 = torch.linspace(-1.5, 1.5, 10000)[:, None].to('cuda')
    # The evaluation points must be 2*eps apart to match the 2*eps divisor.
    # Previously they were only eps apart, which halved every estimate.
    x1_minus = rectified_flow.sample_ode(z0=x0 - eps, N=N, use_gt=True)[-1]
    x1_plus = rectified_flow.sample_ode(z0=x0 + eps, N=N, use_gt=True)[-1]
    sensitivity = (x1_plus - x1_minus) / (2 * eps)
    plt.figure()
    plt.xlabel('x_0')
    plt.ylim(0, 2)
    plt.ylabel('dx_1/dx_0')
    plt.plot(x0.to('cpu').numpy(), sensitivity.to('cpu').numpy())
    save_path = rectified_flow.config['save_path']
    plt.savefig(f'{save_path}/{rectified_flow.name}/sensitivity')
    plt.show()
    plt.close()

def emprirical_related(rectified_flow: Flow, z0, z1):
    """Draw empirical-risk / empirical-target diagnostics for ``rectified_flow``.

    Steps:
      1. If ``{save_path}/{name}/risk.png`` does not exist, draw a grid of the
         per-bin average of ``rectified_flow.empirical_risk``.
      2. Draw a grid of the per-bin average of ``rectified_flow.empirical_target``
         and have ``draw_grid`` save the data to ``empirical_target.pt``
         (``save_data`` argument). Later calls read that file back as a cache.
      3. Draw a grid of ``|pred - gt| * p``.
      4. For every value in ``data_config._slices`` and for a fixed list of
         times, plot predicted vs cached empirical target vs ground truth.

    Args:
        rectified_flow: exposes ``config['save_path']``, ``name``,
            ``empirical_risk``, ``empirical_target``, ``pred``, ``gt`` and ``p``.
        z0, z1: source / target samples forwarded to the empirical estimators.
    """
    save_path = rectified_flow.config['save_path']
    # Empirical risk estimation uses Monte Carlo and is slow; you may not want to run it.
    def emprical_risk(x, t):
        # Average the per-sample risk over bins whose left edges are the grid values x.
        z_t, risk = rectified_flow.empirical_risk(z0, z1, t[:1].repeat(z0.shape[0], 1))
        indices = torch.bucketize(z_t, x, right=True) - 1  # right=True, then -1: bin i holds x[i] <= z_t < x[i+1]
        risk_avg = torch.zeros_like(x)
        for idx in range(len(x)):
            mask = indices == idx
            if mask.any():
                risk_avg[idx] = risk[mask].mean()
        return risk_avg


    if not os.path.exists(f'{save_path}/{rectified_flow.name}/risk.png'):
        draw_grid(emprical_risk, 'Empirical Risk', f'{save_path}/{rectified_flow.name}/risk', vmin=0, vmax=8)

    def empirical_target(x, t):
        # If the cached tensor exists, return its row round(t * 100) and ignore x.
        # NOTE(review): assumes rows are indexed by t on a 0.01 grid — confirm against draw_grid.
        if os.path.exists(f'{save_path}/{rectified_flow.name}/empirical_target.pt'):
            target_avg = torch.load(f'{save_path}/{rectified_flow.name}/empirical_target.pt')
            target_avg = target_avg[round(t[0][0].to('cpu').numpy() * 100), :]
            return target_avg

        z_t, risk = rectified_flow.empirical_target(z0, z1, t)
        indices = torch.bucketize(z_t, x, right=True) - 1  # right=True, then -1: bin i holds x[i] <= z_t < x[i+1]
        target_avg = torch.zeros_like(x)
        for idx in range(len(x)):
            mask = indices == idx
            if mask.any():
                target_avg[idx] = risk[mask].mean()
        return target_avg
    draw_grid(empirical_target, 'Empirical_target',
              f'{save_path}/{rectified_flow.name}/empirical_target', vmin=-4, vmax=4,
              save_data = f'{save_path}/{rectified_flow.name}/empirical_target.pt')

    def empiricial_target_error(x_grid, t):
        # |pred - gt| weighted by rectified_flow.p at the same points.
        empricial_target = rectified_flow.pred(x_grid[:, None], t)
        true = rectified_flow.gt(x_grid[:, None], t)
        p = rectified_flow.p(x_grid[:, None], t)
        error = torch.abs(empricial_target - true) * p
        return error.squeeze(-1)
    draw_grid(empiricial_target_error, 'Empirical_target_error', f'{save_path}/{rectified_flow.name}/empirical_target_error', vmin=0, vmax=1)

    from data_config import _slices
    # The lambdas below close over the loop variable. This is only correct if
    # draw_line / draw_x_line call them before the next iteration (presumably
    # they do, since they are called immediately — confirm).
    # NOTE(review): the cached-column index round((_slice + 5) * 100) presumes
    # an x grid starting at -5 with step 0.01 — confirm.
    # torch.tensor(tensor) copy-constructs and may emit a UserWarning.
    for _slice in _slices:
        draw_line([
                   lambda t: rectified_flow.pred(torch.tensor(torch.ones_like(t) * _slice).to('cuda'), t),
                   lambda t: torch.load(f'{save_path}/{rectified_flow.name}/empirical_target.pt')[:, round((_slice + 5) * 100)],
                   lambda t: rectified_flow.gt(torch.tensor([[_slice]]).to('cuda'), t),
                   ],
                  f'Score or Flow Slice {_slice}',
                  f'{save_path}/{rectified_flow.name}/GT Score or Flow Slice {_slice}.png',
                  labels=[ 'Predicted Score', 'Empirical Target', 'GT'])

    for _t in [0.1, 0.3, 0.6, 0.8, 0.9, 0.95, 0.99]:
        draw_x_line([
                   lambda x: rectified_flow.pred(x, _t * torch.ones_like(x)),
                   lambda x: torch.load(f'{save_path}/{rectified_flow.name}/empirical_target.pt')[round(_t *100), :],
                   lambda x: rectified_flow.gt(x, _t * torch.ones_like(x)),
                   ],
                  f'Score or Flow x Slice {_t}',
                  f'{save_path}/{rectified_flow.name}/GT Score or Flow x Slice {_t}.png',
                  labels=['Predicted Score', 'Empirical Target', 'GT'])

def count_parameters(model):
    """Print how many trainable (``requires_grad``) scalar parameters ``model`` has."""
    trainable = sum(param.numel() for param in model.parameters() if param.requires_grad)
    print(f"Parameters num: {trainable}")


def marginal_accuracy_along_bins(samples_mu, samples_pi, min=-3, max=3, verbose=True):
    """Average ``estimate_marginal_accuracy`` over bin counts 10 * 2**e for e = 0..10.

    When ``verbose`` is True, the per-resolution accuracies are also plotted
    against ``e``, with the average as the title.

    Returns:
        The mean of the 11 accuracies (as numpy values).
    """
    exponents = range(0, 11, 1)
    accuracies = [
        estimate_marginal_accuracy(samples_mu, samples_pi, min, max, num_bins=10 * 2 ** e).detach().cpu().numpy()
        for e in exponents
    ]
    average = sum(accuracies) / len(accuracies)

    if verbose:
        plt.figure()
        plt.plot(exponents, accuracies)
        plt.xlabel('number of bins 10*2^x')
        plt.ylabel('marginal accuracy')
        plt.grid(True)
        plt.title(average)
        plt.show()
        plt.close()

    return average


def show_hist_error_spectrum(samples_mu, samples_pi, min=-3, max=3, num_bins=500):
    """Show the normalised-histogram difference of two sample sets and its FFT magnitude.

    Both sample tensors are histogrammed with ``num_bins`` bins on
    ``[min, max]`` and normalised to sum to 1. Two figures are shown: the
    magnitude spectrum of ``hist_mu - hist_pi`` (frequency axis in units of
    1 / bin width), and the difference itself plotted over the left bin edges.
    """
    lo, hi = min, max
    p_mu = torch.histc(samples_mu, bins=num_bins, min=lo, max=hi)
    p_pi = torch.histc(samples_pi, bins=num_bins, min=lo, max=hi)
    p_mu = p_mu / p_mu.sum()
    p_pi = p_pi / p_pi.sum()

    left_edges = torch.linspace(min, max, num_bins + 1)[:num_bins]
    diff = p_mu - p_pi
    spectrum = torch.fft.fft(diff)
    bin_width = (left_edges[1] - left_edges[0]).item()
    freq = torch.fft.fftfreq(diff.size(0), d=bin_width).cpu().numpy()
    magnitude = torch.abs(spectrum).cpu().numpy()

    plt.figure(figsize=(8, 4))
    plt.plot(freq, magnitude)
    plt.xlabel('Frequency')
    plt.ylabel('Magnitude')
    plt.ylim(0, 0.2)
    plt.title('Fourier Transform - Magnitude Spectrum')
    plt.grid(True)
    plt.show()
    plt.close()

    plt.figure(figsize=(8, 4))
    plt.plot(left_edges, diff.cpu().numpy())
    plt.xlabel('x')
    plt.ylabel('Distribution Error')
    plt.title('Distribution Error')
    plt.grid(True)
    plt.show()
    plt.close()

def ANSD(data, generated_data, save_path, thresholds=None):
    """Plot the average number of neighbours within a distance (ANSD) for real vs generated data.

    For each threshold r, ANSD(r) is the mean over points of how many points
    lie strictly within Euclidean distance r. Each point counts itself,
    because its self-distance is 0. At most the first 20000 rows of each
    input are used, and samples are flattened to ``(n, -1)``.

    Each pairwise-distance matrix is computed once, instead of once per
    threshold, and freed before the next one is built.

    Args:
        data: reference samples (first dimension indexes samples).
        generated_data: generated samples; must have the same length as
            ``data`` after truncation.
        save_path: file path the figure is written to.
        thresholds: distance thresholds. Defaults to
            ``np.linspace(0.001, 0.005, 10)``.

    Returns:
        ``(data_num, samples_num)``: mean neighbour counts per threshold
        (also printed).
    """
    if thresholds is None:
        thresholds = np.linspace(0.001, 0.005, 10, endpoint=True)
    data = data[:20000]
    generated_data = generated_data[:20000]
    assert len(data) == len(generated_data)

    def mean_neighbour_counts(points):
        # reshape (not view) so non-contiguous slices are accepted too.
        flat = points.reshape(points.shape[0], -1)
        distances = torch.cdist(flat, flat, p=2)
        return [torch.mean((distances < r).sum(dim=1).float()).cpu().numpy() for r in thresholds]

    data_num = mean_neighbour_counts(data)
    samples_num = mean_neighbour_counts(generated_data)

    plt.figure()
    plt.grid()
    plt.plot(thresholds, samples_num, label='Sampled Data')
    plt.plot(thresholds, data_num, label='True Data')
    plt.xticks(fontsize='15', rotation=45)
    plt.yticks(fontsize='15')
    plt.legend(fontsize='15')
    plt.xlabel('Distance Threshold', fontsize='20')
    plt.ylabel('ANSD', fontsize='20')
    # NOTE(review): integer-only ticks likely leave few or no x tick labels
    # for the default sub-unit thresholds.
    plt.gca().xaxis.set_major_locator(MaxNLocator(integer=True))
    plt.tight_layout()
    plt.savefig(save_path)
    plt.show()
    plt.close()

    print(np.array(data_num).tolist())
    print(np.array(samples_num).tolist())
    return data_num, samples_num

def draw_end_data_velocity(rectified_flow, return_values=False, vis=True, t=(0.0, 0.8)):
    """Plot predicted vs ground-truth velocity (first coordinate) along x at each time in ``t``.

    The 1-D grid values supplied by ``draw_x_line`` are embedded as the first
    coordinate of a zero vector of size ``config['d']`` before being passed
    to ``predicted_velocity`` and ``velocity``.

    Args:
        rectified_flow: exposes ``config['save_path']``, ``config['d']``,
            ``name``, ``predicted_velocity`` and ``velocity``.
        return_values: if True, return the values ``draw_x_line`` produced for
            the first two times. This requires ``len(t) >= 2``.
        vis: forwarded to ``draw_x_line``.
        t: sequence of times. The default is a tuple, which avoids a shared
            mutable default argument.

    Returns:
        ``(easy_part, gt_easy_part, hard_part, gt_hard_part)``: the predicted
        and ground-truth values at ``t[0]`` and at ``t[1]``, when
        ``return_values`` is True; otherwise None.
    """
    save_path = rectified_flow.config['save_path']

    def preprocess_x(x):
        # Place a (n, 1) column into the first coordinate of a zero (n, d) tensor.
        _x = torch.zeros(len(x), rectified_flow.config['d']).to('cuda')
        _x[:, 0:1] = x.clone()
        return _x

    all_fuc_values = []
    for _t in torch.tensor(t).to('cuda'):
        # The lambdas close over _t. draw_x_line is called within this
        # iteration, so each call sees the current time (assuming draw_x_line
        # evaluates them immediately).
        fuc_values = draw_x_line([
            lambda x: rectified_flow.predicted_velocity(preprocess_x(x.unsqueeze(-1)),
                                                        _t * torch.ones_like(x.unsqueeze(-1), requires_grad=True))[:, :1],
            lambda x: rectified_flow.velocity(preprocess_x(x.unsqueeze(-1)),
                                              _t * torch.ones_like(x.unsqueeze(-1), requires_grad=True))[:, :1],
        ],
            f'Velocity x Slice {_t:.2f}',
            f'{save_path}/{rectified_flow.name}/velocity x Slice {_t:.2f}.png',
            ylim=None,
            labels=['Predicted Velocity', 'GT Velocity'], vis=vis)
        all_fuc_values.append(fuc_values)
    if return_values:
        easy_part, gt_easy_part = all_fuc_values[0][0], all_fuc_values[0][1]
        hard_part, gt_hard_part = all_fuc_values[1][0], all_fuc_values[1][1]
        return easy_part, gt_easy_part, hard_part, gt_hard_part


def velocity_error(rectified_flow):
    """Plot the density-weighted velocity error along x at t = 0 and t = 0.8.

    At each grid point the error is ``w(x) * (pred_v(x) - v(x))`` (first
    coordinate), where ``w = p / sum(p)`` uses ``rectified_flow.p``. When the
    time tensor contains 0, the predicted velocity is first centred by its
    ``w``-weighted mean.
    """
    save_path = rectified_flow.config['save_path']

    def embed(column):
        # Put a (n, 1) column into the first coordinate of a zero (n, d) tensor.
        padded = torch.zeros(len(column), rectified_flow.config['d']).to('cuda')
        padded[:, 0:1] = column.clone()
        return padded

    def weighted_error(x_grid, t):
        weight = rectified_flow.p(embed(x_grid.unsqueeze(-1)), t)
        weight = weight / torch.sum(weight)
        predicted = rectified_flow.predicted_velocity(embed(x_grid[:, None]), t)[:, :1]
        if 0 in t:
            predicted = predicted - torch.sum(weight * predicted)
        target = rectified_flow.velocity(embed(x_grid[:, None]), t)[:, :1]
        return (weight * (predicted - target))[:, 0]

    for t_value in torch.tensor([0, 0.80]).to('cuda'):
        # NOTE(review): the 'core x Slice' file prefix looks like a truncated
        # 'score'; kept unchanged.
        draw_x_line(
            [lambda x: weighted_error(x, t_value * torch.ones_like(x.unsqueeze(-1), requires_grad=True))],
            f'Score x Slice {t_value:.2f}',
            f'{save_path}/{rectified_flow.name}/core x Slice {t_value:.2f}.png',
            ylim=None,
            mode='MAE',
            labels=['Predicted Velocity Error'],
        )


def hill_estimate(samples, k=None):
    """Hill estimator of the tail index alpha.

    With the samples sorted in descending order, x_(1) >= x_(2) >= ... >= x_(n),
    this returns ``1 / mean_{i=1..k} [log x_(i) - log x_(k+1)]``.

    Args:
        samples: 1-D torch tensor (any device; detached) or array-like of
            positive values.
        k: number of upper order statistics to use. Defaults to ``n - 1``,
            i.e. all samples.

    Returns:
        The estimate as a numpy float. This is ``inf`` (with a numpy warning)
        when the top ``k + 1`` values are all equal.
    """
    if isinstance(samples, torch.Tensor):
        samples = samples.detach().cpu().numpy()
    else:
        samples = np.asarray(samples)
    order_stats = np.sort(samples)[::-1]
    if k is None:
        k = order_stats.shape[0] - 1
    log_ratios = np.log(order_stats[:k]) - np.log(order_stats[k])
    return 1 / np.mean(log_ratios)

def tail_index_plot(data, generated_data):
    """Plot alpha_data - alpha_sample (Hill tail index of neighbour counts) vs distance.

    At most the first 20000 samples of each set are used. Pairwise Euclidean
    distances are divided by sqrt(feature count of the flattened ``data``).
    For each threshold in ``linspace(0.01, 0.1, 10)``, per-point neighbour
    counts are fed to ``hill_estimate``. The differences are printed and shown.
    """
    data = data[:20000]
    generated_data = generated_data[:20000]
    assert len(data) == len(generated_data)

    flat_data = data.view(len(data), -1)
    scale = flat_data.shape[1] ** 0.5
    distances_data = torch.cdist(flat_data, flat_data, p=2) / scale

    flat_generated = generated_data.view(len(generated_data), -1)
    distances_samples = torch.cdist(flat_generated, flat_generated, p=2) / scale

    thresholds = np.linspace(0.01, 0.1, 10)
    alpha_diff = []
    for radius in thresholds:
        alpha_sample = hill_estimate((distances_samples < radius).sum(dim=1))
        alpha_data = hill_estimate((distances_data < radius).sum(dim=1))
        alpha_diff.append(alpha_data - alpha_sample)

    plt.plot(thresholds, alpha_diff)
    print(alpha_diff)
    plt.grid()
    plt.show()
    plt.close()

def alpha_hill_estimate(data, eps):
    """Hill tail index of per-point neighbour counts within normalised distance ``eps``.

    Samples are flattened to ``(n, f)``. Pairwise Euclidean distances are
    divided by ``sqrt(f)``, each point's neighbours closer than ``eps`` are
    counted (itself included), and the counts go to ``hill_estimate``.
    """
    flat = data.view(len(data), -1)
    pairwise = torch.cdist(flat, flat, p=2) / (flat.shape[1] ** 0.5)
    return hill_estimate((pairwise < eps).sum(dim=1))


def num_neighbour(data, generated_data, distance_threshold=0.01):
    """Plot P(#neighbours > x) for real vs generated samples, with Hill power-law curves.

    At most the first 20000 samples of each set are used. They are flattened
    to ``(n, f)``, and pairwise Euclidean distances are divided by
    ``sqrt(f)``. For each point, this counts the points (itself included)
    closer than ``distance_threshold``. The empirical survival function
    P(count > x) for x = 0..10 is drawn as bars. The curves ``x ** -alpha`` use
    Hill estimates of both count distributions, which are also printed. The
    figure is saved to ``tail index with his.png`` in the working directory.
    """
    data = data[:20000]
    generated_data = generated_data[:20000]
    assert len(data) == len(generated_data)
    n = len(data)

    data = data.reshape(n, -1)
    scale = data.shape[1] ** 0.5
    distances_data = torch.cdist(data, data, p=2) / scale

    generated_data = generated_data.reshape(n, -1)
    distances_samples = torch.cdist(generated_data, generated_data, p=2) / scale

    counts_sample = (distances_samples < distance_threshold).sum(dim=1)
    counts_data = (distances_data < distance_threshold).sum(dim=1)

    # Normalise by the actual number of points. The old code divided by a
    # hard-coded 20000, which underestimated probabilities for smaller inputs.
    xs = range(11)
    sample_p = [(counts_sample > i).sum().detach().cpu().numpy() / n for i in xs]
    data_p = [(counts_data > i).sum().detach().cpu().numpy() / n for i in xs]

    plt.bar(xs, sample_p, label='Sampled Data', color='#80D462')
    plt.bar(xs, data_p, label='True Data', color='#64BEE6')

    alpha_sample = hill_estimate(counts_sample)
    alpha_data = hill_estimate(counts_data)
    print(alpha_data, alpha_sample)
    x = np.linspace(1, 10)
    plt.plot(x, x ** (-alpha_data), label=r'$x^{-\alpha_{TD}}$', color='red')
    plt.plot(x, x ** (-alpha_sample), label=r'$x^{-\alpha_{ODE}}$', color='green')

    plt.legend(fontsize=18)
    plt.yticks(fontsize=20)
    plt.xticks(fontsize=20)
    plt.xlabel(r'x (# of Neighbours)', fontsize=20)
    plt.ylabel(rf'P(# of Neighbours>x)', fontsize=20)
    plt.tight_layout()
    plt.savefig(f'tail index with his.png', dpi=300)
    plt.show()
    plt.close()
