import os

import numpy as np
import torch
import torchvision
from matplotlib import pyplot as plt
from tqdm import tqdm

from cifar.fid_util import evaluate_fid_score, get_activations
from cifar.inception import InceptionV3
# from cifar.fid_util import evaluate_fid_score, get_activations
# from cifar.inception import InceptionV3
# from cifar_model import models
from sde_lib import VPSDE
from visualize import draw_hist

# GPU ids handed to torch.nn.DataParallel in the evaluation loop; the first
# one is also the device every tensor in this script is placed on.
devices = [7]
device_id = devices[0]
device = torch.device(f'cuda:{device_id}')  # change this if you don't have a gpu

# Forward SDE whose reverse-time dynamics are integrated during sampling.
sde = VPSDE()
# Alternative configuration (probability-flow sampling with fewer steps):
# flow = True
# num_steps = 100
# `flow` is passed as `probability_flow` to sde.reverse and also selects the
# cache directory ('./saved/ode' when True, './saved/sde' when False).
flow = False
# Number of points on the time grid from t=1 to t=0, i.e. num_steps - 1
# integration steps.
num_steps = 1000
# Number of images to generate; the same count is read from the dataset for
# comparison.
num_samples = 2000
# Upper bound on how many samples are integrated at once.
batch_size = 4096
# Feature dimension requested from InceptionV3 (passed as `dim` to get_act) and
# the number of per-feature histograms drawn.
channel=64

def tvd(p, q):
    """Return the total variation distance between two empirical distributions.

    Each input tensor is binned into 100 equal-width bins over [0, 1] with
    ``torch.histc`` (values outside that range are not counted), the two
    histograms are normalized to sum to one, and half of their L1 distance is
    returned.

    The computation runs on the device ``p`` already lives on instead of a
    module-level device, so the helper also works on machines without the GPU
    configured for the rest of the script.

    Args:
        p: floating-point tensor of samples from the first distribution.
        q: floating-point tensor of samples from the second distribution.

    Returns:
        A 0-d NumPy array holding the distance, in [0, 1]. The result is NaN
        when either input has no value inside [0, 1] (empty histogram).
    """
    p_hist = torch.histc(p, bins=100, min=0.0, max=1.0)
    # Bring the second histogram next to the first so mixed-device inputs work.
    q_hist = torch.histc(q, bins=100, min=0.0, max=1.0).to(p_hist.device)
    p_hist = p_hist / torch.sum(p_hist)
    q_hist = q_hist / torch.sum(q_hist)
    distance = 0.5 * torch.sum(torch.abs(p_hist - q_hist))
    return distance.detach().cpu().numpy()


def nandivide(a, b):
    """Divide ``a`` by ``b`` element-wise, yielding NaN wherever ``b`` is 0 or NaN.

    Inputs may be arrays, nested lists or scalars; they are converted with
    ``np.asarray`` and broadcast against each other, so a scalar divisor or a
    row/column divisor works as well as two arrays of identical shape.

    Args:
        a: numerator (array-like).
        b: denominator (array-like, broadcastable with ``a``).

    Returns:
        A float64 array of the broadcast shape. Entries whose denominator is
        zero or NaN are NaN; the division is only evaluated on the remaining
        entries, so no divide-by-zero warning is emitted.
    """
    a, b = np.broadcast_arrays(np.asarray(a), np.asarray(b))
    result = np.full(a.shape, np.nan, dtype=np.float64)
    # Only positions with a usable denominator are ever divided.
    mask = (b != 0) & (~np.isnan(b))
    result[mask] = a[mask] / b[mask]
    return result


def hill_estimate(samples, k=None):
    """Return the Hill estimate of the tail index of ``samples``.

    The samples are sorted in descending order and the estimate is the
    reciprocal of the mean log-ratio between the ``k`` largest values and the
    ``(k + 1)``-th largest value.

    Args:
        samples: torch tensor of observations. The sort/reverse logic assumes
            a 1-D tensor, and the logarithms require strictly positive values
            (non-positive entries produce NaN or infinite terms).
        k: number of upper order statistics to use. Defaults to ``n - 1``,
            i.e. every sample is compared against the smallest one.

    Returns:
        The estimate as a NumPy floating-point scalar.
    """
    values = samples.detach().cpu().numpy()
    descending = np.sort(values)[::-1]
    if k is None:
        k = descending.shape[0] - 1
    log_ratios = np.log(descending[:k]) - np.log(descending[k])
    return 1 / np.mean(log_ratios)

# Accumulator for per-run results; nothing in the visible script appends to it.
results = []

# Model identifiers to evaluate. Each entry is used both to look up the model
# and dataset and to build the checkpoint file name in the loop below.
#
# Sweeps that were used before (swap one in to reproduce it):
#   [f'dim-{i}-ds-40000-nr-4' for i in [8, 12, 16, 20, 24, 28, 32]]
#   ['dim-16-ds-20000-nr-4']
#   [f'dim-16-ds-30000-nr-4-nf-{i}' for i in [32, 64]] + ['dim-16-ds-30000-nr-4']
#   [f'dim-16-ds-40000-nr-{i}' for i in [2, 4, 6]]
#   [f'dim-16-ds-10000-nr-4-nf-{i}' for i in [32, 64]] + ['dim-16-ds-10000-nr-4']
#   ['dim-16-ds-10000-nr-4-nf-32']
prefixs = ['dim-16-ds-10000-nr-4']

# Evaluation sweep: one pass per (model prefix, epoch, early_sde) combination.
for model in prefixs:
    # Earlier sweep over differently named models:
    # for model in ['smalldatasetsingledim',  'smalldatasetcrop', 'smalldataset', 'crop', 'singledim']:
    # NOTE(review): `ep` is not referenced anywhere in the visible loop body;
    # the checkpoint name below does not depend on it.
    for ep in [3000]:
        # for ep in [2990, 2940, 2890, 2840]:
        # `early_sde` is the threshold used by the sampler: steps with
        # t > 1 - early_sde use the stochastic reverse SDE, the rest use the
        # sampler selected by `flow`. With 0 the test is t > 1, which no point
        # of the 1 -> 0 time grid satisfies.
        for early_sde in [0]:
            #     for early_sde in [0,0.15,0.25, 0.5]:
            # Per-model directory; it is created here but not written to by
            # the visible code.
            work_path = f'./{model}'
            os.makedirs(work_path, exist_ok=True)
            # The checkpoint is read from the current directory, not work_path.
            chkpt_name = f'cifar{model}DataParallel_lastep.pth'

            # Imported here rather than at the top of the file.
            from cifar_model import get_model_and_dataset
            score_network, cifar_dset = get_model_and_dataset(model)
            # Wrap before loading — presumably the checkpoint was saved from a
            # DataParallel model (see its file name); confirm the key prefix.
            score_network = torch.nn.DataParallel(score_network, device_ids=devices)
            stat_dict = torch.load(f'./{chkpt_name}', map_location=device)
            # NOTE(review): strict=False does not raise on missing or
            # unexpected keys, so a mismatched checkpoint loads silently.
            score_network.load_state_dict(stat_dict, strict=False)
            score_network = score_network.to(device)
            def score_fn(score_network):
                """Turn ``score_network`` into a callable that returns a rescaled, negated output.

                The returned function maps ``(x, t)`` to
                ``-score_network(x, t * 999) / std(t)``, where
                ``std(t) = sqrt(1 - exp(2 * log_mean_coeff(t)))``.

                Args:
                    score_network: callable taking ``(x, t * 999)``; presumably
                        a noise-prediction network conditioned on a time scaled
                        to [0, 999] — confirm against the training code.

                Returns:
                    A function ``(x, t)`` where ``t`` is a 1-D tensor with one
                    time per sample and ``x`` is a 4-D batch.
                """
                # NOTE(review): these look like the VP-SDE beta_max / beta_min
                # values; verify they agree with sde_lib.VPSDE's settings.
                beta_max, beta_min = 20, 0.1

                def _score_fn(x, t):
                    log_mean_coeff = -0.25 * t ** 2 * (beta_max - beta_min) - 0.5 * t * beta_min
                    marginal_std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))
                    network_out = score_network(x, t * 999)
                    # One std per sample, broadcast over the remaining three axes.
                    return (-network_out) / marginal_std[:, None, None, None]

                return _score_fn
            # Rebind the name to the closure returned by score_fn: from here on
            # `score_network` is a plain function, not an nn.Module.
            score_network = score_fn(score_network)

            # Report dataset length and the size of the first item's tensor.
            print(len(cifar_dset), cifar_dset[0][0].size())
            # Axis 0 of the first item is taken as the channel count and axis 1
            # as the image size; the sampler uses that size for both spatial
            # axes — assumes square (C, H, W) items, TODO confirm.
            channel_num = cifar_dset[0][0].size()[0]
            image_size = cifar_dset[0][0].size()[1]


            def generate_samples_1(score_network, num_samples, batch_size=None):
                """Generate images by Euler-Maruyama integration of the reverse SDE.

                Starting from standard Gaussian noise, the state is stepped
                along a uniform time grid from t=1 to t=0. Steps with
                ``t > 1 - early_sde`` always use the stochastic reverse SDE;
                the remaining steps use the reverse process configured by the
                module-level ``flow`` flag.

                Reads the enclosing-scope names ``sde``, ``flow``,
                ``num_steps``, ``device``, ``channel_num``, ``image_size`` and
                ``early_sde``.

                Args:
                    score_network: callable ``(x, t)`` with ``t`` holding one
                        time value per sample.
                    num_samples: total number of samples to generate.
                    batch_size: maximum number of samples integrated at once;
                        ``None`` means a single batch of ``num_samples``.

                Returns:
                    The final states of all batches concatenated along dim 0.
                """
                def batched_score(x, t):
                    # The integrator passes a scalar time; repeat it once per sample.
                    return score_network(x, t.unsqueeze(-1).expand(x.shape[0]))

                rsde = sde.reverse(batched_score, probability_flow=flow)
                srsde = sde.reverse(batched_score, probability_flow=False)
                time_pts = torch.linspace(1, 0, num_steps, device=device)  # (ntime_pts,)
                n_steps = len(time_pts) - 1

                if batch_size is None:
                    batch_size = num_samples
                n_batches = (num_samples + batch_size - 1) // batch_size

                generated_samples = []
                with torch.no_grad():
                    for batch_idx in range(n_batches):
                        # The last batch may be smaller than batch_size.
                        n_current = min(batch_size, num_samples - batch_idx * batch_size)
                        x_t = torch.randn((n_current, channel_num, image_size, image_size), device=device)
                        for i in tqdm(range(n_steps), leave=True):
                            t = time_pts[i]
                            dt = time_pts[i + 1] - t  # negative: time runs backwards
                            reverse_sde = srsde if t > (1 - early_sde) else rsde
                            drift, diffusion = reverse_sde.sde(x_t, t)
                            noise = torch.randn_like(x_t)
                            # Euler-Maruyama update.
                            x_t = x_t + drift * dt + diffusion * noise * torch.abs(dt) ** 0.5
                        generated_samples.append(x_t)
                return torch.cat(generated_samples, dim=0)

            # Generated samples are cached on disk, in a directory chosen by
            # the sampler type.
            save_path = './saved/ode' if flow else './saved/sde'
            # torch.save does not create missing parent directories; without
            # this a fresh checkout would fail only after the whole sampling
            # run has finished and the samples would be lost.
            os.makedirs(save_path, exist_ok=True)
            # NOTE(review): `score_network` was rebound to the score_fn closure
            # above, so the class-name component is always 'function'. It is
            # left unchanged so existing cache files keep being found.
            if early_sde <= 0:
                path = f'{save_path}/{score_network.__class__.__name__}_ep{chkpt_name}_samples_ns{num_steps}_n{num_samples}_f{flow}.pt'
            else:
                path = f'{save_path}/{score_network.__class__.__name__}_ep{chkpt_name}_samples_ns{num_steps}_n{num_samples}_f{flow}_es{early_sde}.pt'
            print(path)
            if not os.path.exists(path):
                # Replace the condition with `if True:` to force regeneration.
                samples = generate_samples_1(score_network, num_samples, batch_size)
                torch.save(samples, path)
            else:
                samples = torch.load(path, map_location=device)
            print(samples.shape)

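            # Optional visual check (disabled): show the first 21 samples in a 3x7 grid.
            # images = samples.cpu().detach()
            # images = images.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C) for imshow
            # images = (images - images.min()) / (images.max() - images.min())  # rescale to [0, 1]
            # fig, axes = plt.subplots(3, 7, figsize=(14, 6))  # 3x7 grid of axes
            # for i, ax in enumerate(axes.flat):
            #     if i < len(images):
            #         ax.imshow(images[i])  # draw the i-th sample
            #         ax.axis('off')  # hide ticks and frame
            #     else:
            #         ax.axis('off')  # leave unused cells blank
            # plt.tight_layout()
            # plt.show()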

            # Reference batch: up to `num_samples` dataset items in stored
            # order (shuffle=False), moved to the evaluation device. The loader
            # is built on `cifar_dset` directly; no extra transform is applied.
            data_loader = torch.utils.data.DataLoader(cifar_dset, batch_size=num_samples, shuffle=False)
            data, _ = next(iter(data_loader))
            data = data.to(device)
            print(data.shape)

            # FID between the reference batch and the generated samples.
            # NOTE(review): the third positional argument is a hard-coded 64,
            # which currently equals the module-level `channel` — confirm
            # whether it should be tied to it.
            fid = evaluate_fid_score(data, samples, 64, batch_size=100, device=device)
            print(fid)
            # quit()
            def get_act(samples, batch_size=1000, dim=2048):
                """Return InceptionV3 activations of ``samples`` at the ``dim``-feature block.

                A new InceptionV3 restricted to the block registered for
                ``dim`` in ``InceptionV3.BLOCK_INDEX_BY_DIM`` is built on every
                call, moved to the module-level ``device`` and handed to
                ``get_activations``.

                Args:
                    samples: batch of images to embed.
                    batch_size: batch size forwarded to ``get_activations``.
                    dim: feature dimensionality selecting the Inception block.

                Returns:
                    Whatever ``get_activations`` returns for these inputs.
                """
                inception = InceptionV3(
                    [InceptionV3.BLOCK_INDEX_BY_DIM[dim]],
                    resize_input=True,
                    normalize_input=True,
                )
                inception = inception.to(device)
                return get_activations(samples, inception, batch_size, dims=dim, device=device)


            # Inception activations with `channel` features for the reference
            # batch and for the generated samples.
            data_act = get_act(data, batch_size=1000, dim=channel)
            samples_act = get_act(samples, batch_size=1000, dim=channel)

            # One entry per feature: 1 when the training-data variance is
            # strictly smaller than the variance of the generated samples,
            # otherwise 0. (Earlier runs restricted the loop to range(10) or
            # to the features [1, 19, 23, 25, 61].)
            smaller_rate = []
            for feature_index in range(channel):
                _data_act = torch.tensor(data_act[:, feature_index]).to(device)
                _samples_act = torch.tensor(samples_act[:, feature_index]).to(device)
                data_var, sample_var = torch.var(_data_act), torch.var(_samples_act)
                smaller_rate.append(1 if data_var < sample_var else 0)

                # Histogram of this feature for both sets, labelled with the
                # two variances. `range=None` is passed through to draw_hist
                # (a fixed range=(0, 2) was used before).
                title = f'feature density_{feature_index} datav: {data_var:.4f} samplev: {sample_var:.4f}'
                draw_hist(
                    [_samples_act, _data_act],
                    ['SDE Sampled', 'Train Data'],
                    f'{save_path}/feature density_{feature_index}.png',
                    title,
                    range=None,
                )

            # Fraction of features for which the data variance is the smaller one.
            print(np.mean(smaller_rate))
