import os

import numpy as np
import torch
import torchvision
from matplotlib import pyplot as plt
from tqdm import tqdm

# from cifar.fid_util import evaluate_fid_score, get_activations
# from cifar.inception import InceptionV3
# from cifar_model import models
from sde_lib import VPSDE

# Device placement: `devices` is handed to DataParallel, and every tensor in
# this script is created on the first entry of that list.
devices = [7]
device_id = devices[0]
device = torch.device('cuda', device_id)  # change this if you don't have a gpu

# Sampler configuration.
sde = VPSDE()
flow = True          # passed as `probability_flow` to sde.reverse; also picks ./saved/ode vs ./saved/sde
num_steps = 100      # points on the time grid from 1 to 0 (num_steps - 1 Euler steps)
num_samples = 2000   # images generated, and real images loaded for comparison
batch_size = 4096    # max samples per generation batch

def tvd(p, q, bins=100, value_range=(0.0, 1.0)):
    """Total variation distance between the value histograms of two tensors.

    Both inputs are binned with ``torch.histc`` over ``value_range``. Values
    outside the range are ignored. Each histogram is normalised to sum to one
    and ``0.5 * sum(|p_hist - q_hist|)`` is returned.

    Args:
        p, q: float tensors of any shape (histc bins all elements); they may
            live on different devices.
        bins: number of equal-width bins (default 100).
        value_range: ``(low, high)`` histogram range (default ``(0.0, 1.0)``).

    Returns:
        0-d ``numpy.ndarray`` in [0, 1]. It is NaN if either tensor has no
        values inside ``value_range``, because that histogram cannot be
        normalised.
    """
    low, high = value_range
    p_hist = torch.histc(p, bins=bins, min=low, max=high)
    # Only co-location matters for the subtraction below, so bring q's
    # histogram to p's device instead of relying on a global device.
    q_hist = torch.histc(q, bins=bins, min=low, max=high).to(p_hist.device)
    p_hist = p_hist / torch.sum(p_hist)
    q_hist = q_hist / torch.sum(q_hist)
    distance = 0.5 * torch.sum(torch.abs(p_hist - q_hist))
    return distance.detach().cpu().numpy()


def nandivide(a, b):
    """Element-wise ``a / b`` that leaves NaN wherever ``b`` is 0 or NaN.

    Args:
        a: numerator array.
        b: denominator array, same shape as ``a``.

    Returns:
        float64 array shaped like ``a``. Positions with a zero or NaN
        denominator hold NaN; every other position holds ``a / b``.
    """
    divisible = np.logical_and(b != 0, np.logical_not(np.isnan(b)))
    quotient = np.full_like(a, np.nan, dtype=np.float64)
    quotient[divisible] = a[divisible] / b[divisible]
    return quotient


def hill_estimate(samples, k=None):
    """Hill estimator of the tail index of a positive sample.

    The values are sorted in descending order, ``X(1) >= X(2) >= ... >= X(n)``.
    The function returns ``1 / mean(log X(i) - log X(k+1) for i = 1..k)``.

    Args:
        samples: ``torch.Tensor`` (detached and copied to CPU) or any
            array-like of numbers; multi-dimensional input is flattened.
        k: number of upper order statistics to use. ``None`` uses all values
            but the smallest (``k = n - 1``).

    Returns:
        The estimate as a NumPy float64. If the top ``k + 1`` values are all
        equal, the mean log-ratio is 0 and the result is ``inf`` (NumPy
        warns). A zero threshold value ``X(k+1)`` makes the log undefined and
        yields 0 or NaN.

    Raises:
        ValueError: if fewer than two values are given or ``k`` is outside
            ``[1, n - 1]``.
    """
    if isinstance(samples, torch.Tensor):
        samples = samples.detach().cpu().numpy()
    values = np.sort(np.asarray(samples, dtype=np.float64).ravel())[::-1]
    n = values.shape[0]
    if n < 2:
        raise ValueError(f'hill_estimate needs at least 2 samples, got {n}')
    if k is None:
        k = n - 1
    if not 1 <= k <= n - 1:
        raise ValueError(f'k must be in [1, {n - 1}], got {k}')
    # values[k] is the (k+1)-th largest value and serves as the threshold.
    log_ratios = np.log(values[:k]) - np.log(values[k])
    return 1.0 / np.mean(log_ratios)

# One entry appended per (model, ep, early_sde) combination evaluated below.
results = []

# Model prefixes to evaluate. Each prefix is passed to get_model_and_dataset
# and selects the checkpoint file cifar<prefix>DataParallel_lastep.pth.
# Earlier sweeps, kept for reference:
#   [f'dim-{i}-ds-{40000}-nr-{4}' for i in [8, 12, 16, 20, 24, 28, 32]]
#   [f'dim-{16}-ds-{20000}-nr-{4}']
#   [f'dim-{16}-ds-{30000}-nr-{4}-nf-{i}' for i in [32, 64]] + [f'dim-{16}-ds-{30000}-nr-{4}']
#   [f'dim-{16}-ds-{40000}-nr-{i}' for i in [2, 4, 6]]
#   [f'dim-{16}-ds-{10000}-nr-{4}-nf-{i}' for i in [32, 64]] + [f'dim-{16}-ds-{10000}-nr-{4}']
#   [f'dim-{16}-ds-{10000}-nr-{4}-nf-32']
prefixs = ['dim-32-ds-50000-nr-4']


for model in prefixs:
# for model in ['smalldatasetsingledim',  'smalldatasetcrop', 'smalldataset', 'crop', 'singledim']:
    for ep in [3000]:
    # for ep in [2990, 2940, 2890, 2840]:
        for early_sde in [0]:
    #     for early_sde in [0,0.15,0.25, 0.5]:
            # Per-model output directory for figures.
            work_path = f'./{model}'
            os.makedirs(work_path, exist_ok=True)
            # NOTE(review): `ep` only appears in output file names; the checkpoint
            # loaded is always the "_lastep" file for this model.
            chkpt_name = f'cifar{model}DataParallel_lastep.pth'

            from cifar_model import get_model_and_dataset
            score_network, cifar_dset = get_model_and_dataset(model)
            score_network = torch.nn.DataParallel(score_network, device_ids=devices)
            stat_dict = torch.load(f'./{chkpt_name}', map_location=device)
            # strict=False tolerates key mismatches. Report them, because a checkpoint
            # whose keys do not line up (e.g. with/without the DataParallel `module.`
            # prefix) would otherwise leave the network at its initial weights silently.
            incompatible = score_network.load_state_dict(stat_dict, strict=False)
            if incompatible.missing_keys or incompatible.unexpected_keys:
                print(f'WARNING: checkpoint {chkpt_name} does not match the model: '
                      f'{len(incompatible.missing_keys)} missing keys, '
                      f'{len(incompatible.unexpected_keys)} unexpected keys')
            score_network = score_network.to(device)
            # The network is only used for sampling below. Switch to inference mode so
            # that layers such as dropout or batch norm (if the model has any) behave
            # deterministically.
            score_network.eval()
            def score_fn(score_network, beta_min=0.1, beta_max=20.0, time_scale=999):
                """Wrap a network as a score function of a linear-schedule VP SDE.

                The returned ``_score_fn(x, t)`` computes
                ``-score_network(x, t * time_scale) / std(t)``, where
                ``std(t) = sqrt(1 - exp(2 * log_mean_coeff(t)))`` and
                ``log_mean_coeff(t) = -0.25 * t**2 * (beta_max - beta_min) - 0.5 * t * beta_min``.

                Args:
                    score_network: callable ``(x, t_scaled) -> tensor shaped like x``;
                        presumably predicts the added noise. TODO confirm.
                    beta_min, beta_max: noise-schedule endpoints (defaults 0.1 / 20.0).
                        NOTE(review): these should match the schedule of the `VPSDE`
                        used for sampling; confirm.
                    time_scale: factor mapping ``t`` to the network's time input
                        (default 999).

                Returns:
                    ``_score_fn(x, t)`` with ``t`` holding one time per sample in ``x``.
                    ``std`` is broadcast over all non-batch dimensions of ``x``.
                """
                def _score_fn(x, t):
                    log_mean_coeff = -0.25 * t ** 2 * (beta_max - beta_min) - 0.5 * t * beta_min
                    std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))
                    # Shape (batch, 1, ..., 1) so std broadcasts against x of any rank.
                    std = std.reshape(-1, *([1] * (x.dim() - 1)))
                    return (-score_network(x, t * time_scale)) / std

                return _score_fn
            score_network = score_fn(score_network)

            # Dataset length and the shape of its first image. Dim 0 of that shape is
            # used as the channel count and dim 1 as the (square) image side when
            # drawing the initial noise for sampling.
            print(len(cifar_dset), cifar_dset[0][0].size())
            channel_num = cifar_dset[0][0].size(0)
            image_size = cifar_dset[0][0].size(1)


            def generate_samples_1(score_network, num_samples, batch_size=None):
                """Generate `num_samples` images by integrating the reverse process from t=1 to t=0.

                The time grid has `num_steps` evenly spaced points from 1 to 0. Each of the
                `num_steps - 1` intervals takes one Euler-Maruyama update
                ``x += drift * dt + diffusion * z * sqrt(|dt|)`` with fresh Gaussian `z`.
                Steps whose start time satisfies ``t > 1 - early_sde`` use the reverse
                process built with ``probability_flow=False``; all other steps use the one
                built with ``probability_flow=flow``.

                Work is split into batches of at most `batch_size` samples (one batch when
                it is None). Each batch starts from standard Gaussian noise of shape
                (batch, channel_num, image_size, image_size) on `device`.

                Returns:
                    All batches concatenated along dim 0.
                """
                def time_batched_score(x, t):
                    # Expand the scalar step time to one entry per sample in the batch.
                    return score_network(x, t.unsqueeze(-1).expand(x.shape[0]))

                reverse_main = sde.reverse(time_batched_score, probability_flow=flow)
                reverse_stochastic = sde.reverse(time_batched_score, probability_flow=False)
                grid = torch.linspace(1, 0, num_steps, device=device)  # (num_steps,)

                per_batch = num_samples if batch_size is None else batch_size
                n_batches = -(-num_samples // per_batch)  # ceiling division
                batches = []
                with torch.no_grad():
                    for b in range(n_batches):
                        n_here = min(per_batch, num_samples - b * per_batch)
                        x_t = torch.randn((n_here, channel_num, image_size, image_size), device=device)
                        with tqdm(total=len(grid) - 1, leave=True) as pbar:
                            for t, t_next in zip(grid[:-1], grid[1:]):
                                dt = t_next - t
                                process = reverse_stochastic if t > (1 - early_sde) else reverse_main
                                drift, diffusion = process.sde(x_t, t)
                                noise = torch.randn_like(x_t)
                                x_t = x_t + drift * dt + diffusion * noise * torch.abs(dt) ** 0.5
                                pbar.update(1)
                        batches.append(x_t)
                return torch.cat(batches, dim=0)

            # Generated samples are cached on disk, keyed by the sampler settings, so
            # re-runs reuse them instead of sampling again.
            save_path = './saved/ode' if flow else './saved/sde'
            # torch.save does not create missing parent directories.
            os.makedirs(save_path, exist_ok=True)
            # NOTE: `score_network` is the `_score_fn` closure here, so its
            # `__class__.__name__` is 'function'. The name is kept as-is so previously
            # cached files are still found.
            stem = (f'{save_path}/{score_network.__class__.__name__}_ep{chkpt_name}'
                    f'_samples_ns{num_steps}_n{num_samples}_f{flow}')
            path = f'{stem}.pt' if early_sde <= 0 else f'{stem}_es{early_sde}.pt'
            print(path)
            if not os.path.exists(path):
                samples = generate_samples_1(score_network, num_samples, batch_size)
                torch.save(samples, path)
            else:
                # Load onto the evaluation device regardless of where the file was saved.
                samples = torch.load(path, map_location=device)
            print(samples.shape)

            # Reference data: the first `num_samples` (image, label) items of the
            # dataset, in order; labels are discarded.
            data_loader = torch.utils.data.DataLoader(cifar_dset, batch_size=num_samples, shuffle=False)
            data, _ = next(iter(data_loader))
            data = data.to(device)
            print(data.shape)

            def ANSD(data, generated_data, vis=False):
                """Compare neighbourhood statistics of real and generated samples.

                Each sample is flattened. All pairwise Euclidean distances within the
                real set, and within the generated set, are divided by
                sqrt(number of features), i.e. they become root-mean-square
                per-coordinate differences. From these distance matrices the function:

                * computes, for 40 thresholds in [0, 1], the mean number of points
                  closer than the threshold to each point (the ANSD curves);
                * at distance 0.2, draws (then closes without saving) a histogram of
                  per-point neighbour counts and prints Hill estimates for both sets;
                * for 10 distances in [0.2, 0.28], plots and shows the difference of
                  Hill estimates (real minus generated);
                * if ``vis``, plots, saves under ``work_path`` and shows the ANSD
                  ratio, the ANSD curves and the distance histograms;
                * prints the histogram TVD of the two distance distributions, the
                  maximum ANSD ratio, the mean distance shift and the mean absolute
                  ANSD difference.

                Args:
                    data: real samples; dim 0 indexes samples.
                    generated_data: generated samples; dim 0 indexes samples.
                    vis: also produce the ANSD / distance-histogram figures.

                Returns:
                    ``alpha_data - alpha_sample`` at distance 0.2, where each alpha is
                    ``hill_estimate`` of that set's per-point neighbour counts.
                """
                # Other grids tried: np.linspace(1, 5|10|20, 40), np.linspace(0.1, 0.5, 10).
                thresholds = np.linspace(0, 1, 40, endpoint=True)

                # reshape (not view) so non-contiguous inputs are accepted too.
                data = data.reshape(len(data), -1)
                distances_data = torch.cdist(data, data, p=2) / (data.shape[1] ** 0.5)
                generated_data = generated_data.reshape(len(generated_data), -1)
                # Normalise by the generated set's own feature count (identical to the
                # real one whenever the per-sample shapes match).
                distances_samples = torch.cdist(generated_data, generated_data, p=2) / (
                    generated_data.shape[1] ** 0.5)

                def neighbour_counts(distances, radius):
                    # Per point: number of points (its ~0 self-distance included) closer than `radius`.
                    return (distances < radius).sum(dim=1)

                def mean_count(distances, radius):
                    return torch.mean(neighbour_counts(distances, radius).float()).cpu().numpy()

                data_num = [mean_count(distances_data, th) for th in thresholds]
                samples_num = [mean_count(distances_samples, th) for th in thresholds]

                for distance in [0.2]:  # e.g. [0.15, 0.2, 0.25]
                    counts_samples = neighbour_counts(distances_samples, distance)
                    counts_data = neighbour_counts(distances_data, distance)
                    # np.histogram rejects a range whose max is below its min, so keep
                    # the upper edge at least at the lower edge 2.
                    hist_max = max(int(torch.max(counts_samples)), int(torch.max(counts_data)), 2)
                    plt.hist(counts_samples.float().detach().cpu().numpy(), bins=50, alpha=0.3,
                             range=(2, hist_max), label='sample')
                    plt.hist(counts_data.float().detach().cpu().numpy(), bins=50, alpha=0.3,
                             range=(2, hist_max), label='data')
                    plt.legend()
                    alpha_sample = hill_estimate(counts_samples)
                    alpha_data = hill_estimate(counts_data)
                    alpha_diff = alpha_data - alpha_sample
                    plt.title(f'alpha_sample={alpha_sample:.3f}, alpha_data={alpha_data:.3f}, alpha_diff={alpha_diff:.3f}')
                    print('data', alpha_data, 'samples', alpha_sample)
                    # plt.savefig(f'{work_path}/Neighbours His_{ep}_dis{distance}_es{early_sde}.png')
                    plt.close()

                alpha_diffs = []
                distances = np.linspace(0.2, 0.28, 10)  # wider sweep tried: np.linspace(0.1, 0.4, 10)
                for distance in distances:
                    alpha_sample = hill_estimate(neighbour_counts(distances_samples, distance))
                    alpha_data = hill_estimate(neighbour_counts(distances_data, distance))
                    alpha_diffs.append(alpha_data - alpha_sample)
                print(alpha_diffs)
                plt.plot(distances, alpha_diffs)
                plt.grid()
                plt.show()
                plt.close()

                if vis:
                    plt.figure()
                    plt.grid()
                    # Same NaN-on-zero-denominator ratio as the max-ratio metric below.
                    plt.plot(thresholds, nandivide(np.array(samples_num), np.array(data_num)),
                             label='Sampled/True Data')
                    plt.legend(fontsize='15')
                    plt.xlabel('Distance Threshold', fontsize='20')
                    plt.ylabel('ANSD Ratio', fontsize='20')
                    plt.xticks(fontsize='15', rotation=45)
                    plt.yticks(fontsize='15')
                    plt.tight_layout()
                    plt.savefig(f'{work_path}/ANSD_ratio_{ep}_{early_sde}.png')
                    plt.show()
                    plt.close()

                    plt.figure()
                    plt.grid()
                    plt.plot(thresholds, samples_num, label='Sampled Data')
                    plt.plot(thresholds, data_num, label='True Data')
                    plt.legend(fontsize='15')
                    plt.xlabel('Distance Threshold', fontsize='20')
                    plt.ylabel('ANSD', fontsize='20')
                    plt.xticks(fontsize='15', rotation=45)
                    plt.yticks(fontsize='15')
                    plt.tight_layout()
                    plt.savefig(f'{work_path}/ANSD_{ep}_{early_sde}.png')
                    plt.show()
                    plt.close()

                    hist_range = (0, 1)
                    plt.hist(distances_data.reshape(-1).detach().cpu().numpy(), bins=100, label='data',
                             alpha=0.3, range=hist_range)
                    plt.hist(distances_samples.reshape(-1).detach().cpu().numpy(), bins=100, label='sampled',
                             alpha=0.3, range=hist_range)
                    plt.legend()
                    plt.title('Distance Histogram')
                    plt.savefig(f'{work_path}/Distance_histogram_{ep}_{early_sde}.png')
                    plt.show()
                    plt.close()

                print('tvd:')
                _tvd = tvd(distances_data, distances_samples)
                print(_tvd)
                print('max ansd ratio:')
                max_ansd_ratio = np.nanmax(nandivide(np.array(samples_num), np.array(data_num)))
                print(max_ansd_ratio)
                print('mean distance shift')
                mean_distance_shift = (torch.mean(distances_data.to(device))
                                       - torch.mean(distances_samples.to(device))).detach().cpu().numpy()
                print(mean_distance_shift)
                print('ANSD shift')
                ansd_shift = np.mean(np.abs(np.array(data_num) - np.array(samples_num)))
                print(ansd_shift)
                # Only alpha_diff is returned; the other metrics above are printed.
                return alpha_diff

            result = ANSD(data, samples)
            results.append(result)

print(results)
# results = np.array(results).T
#
# for result in results:
#     plt.figure()
#     plt.grid()
#     plt.plot(result)
#     plt.legend(fontsize='15')
#     # plt.xlabel('Distance Threshold', fontsize='20')
#     # plt.ylabel('ANSD Ratio', fontsize='20')
#     # plt.xticks(fontsize='15', rotation=45)
#     # plt.yticks(fontsize='15')
#     plt.tight_layout()
#     # plt.savefig(f'{work_path}/ANSD_ratio_{ep}_{early_sde}.png')
#     plt.show()
#     plt.close()
#
