import logging
import os
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import datasets
from PIL import Image
import glob
from sklearn.datasets import make_circles, make_moons, make_blobs, make_swiss_roll



def get_batch(num_samples, dataset="circle", noise=0.05, device="cpu", return_label=False):
    """Sample a batch of 2-D points from a named toy distribution.

    Args:
        num_samples: requested number of points. "pinwheel" returns
            ``5 * (num_samples // 5)`` points and "spiral2d" returns
            ``2 * (num_samples // 2)``; every other dataset returns exactly
            ``num_samples`` points.
        dataset: one of "circle", "moons", "blobs", "8blobs", "new8blobs",
            "pinwheel", "swiss-rolls", "checkerboard", "spiral2d", "olympics".
        noise: noise level forwarded to "swiss-rolls" and "olympics" only;
            the other datasets use their own fixed noise settings.
        device: torch device the returned tensor is moved to.
        return_label: if True, also return the per-point labels. It is forced
            to False for "swiss-rolls", "checkerboard", "spiral2d" and
            "olympics", which have no meaningful class labels here.

    Returns:
        A float32 tensor of shape (N, 2) on ``device``, or ``(tensor, labels)``
        with ``labels`` a numpy array aligned row-for-row with the tensor.

    Raises:
        ValueError: if ``dataset`` is not one of the names listed above.
    """
    if dataset == "circle":
        # NOTE: uses a fixed noise of 0.06, not the ``noise`` argument.
        points, labels = make_circles(n_samples=num_samples, noise=0.06, factor=0.5)
    elif dataset == "moons":
        points, labels = make_moons(n_samples=num_samples, noise=.05)
    elif dataset == "blobs":
        # sklearn's default layout: three centers.
        # NOTE(review): fixed random_state makes every call return the same batch.
        points, labels = make_blobs(n_samples=num_samples, random_state=8)
    elif dataset == "8blobs":
        centers = [[2, 0], [-2, 0], [0, -2], [0, 2], [1.41, 1.41], [-1.41, -1.41], [-1.41, 1.41], [1.41, -1.41]]
        points, labels = make_blobs(n_samples=num_samples, centers=centers, random_state=8, cluster_std=0.1)
    elif dataset == "new8blobs":
        # Eight Gaussians evenly spaced on a circle of radius ``scale``.
        scale = 4.
        inv_sqrt2 = 1. / np.sqrt(2)
        centers = [(1, 0), (-1, 0), (0, 1), (0, -1),
                   (inv_sqrt2, inv_sqrt2), (inv_sqrt2, -inv_sqrt2),
                   (-inv_sqrt2, inv_sqrt2), (-inv_sqrt2, -inv_sqrt2)]
        centers = [(scale * cx, scale * cy) for cx, cy in centers]
        # Honour the requested batch size instead of a hard-coded count.
        points, labels = make_blobs(n_samples=num_samples, centers=centers, random_state=8, cluster_std=0.5)
        points = points / 1.41

    elif dataset == "pinwheel":
        radial_std = 0.3
        tangential_std = 0.1
        num_classes = 5
        num_per_class = num_samples // num_classes
        rate = 0.25
        rads = np.linspace(0, 2 * np.pi, num_classes, endpoint=False)
        rng = np.random.RandomState()

        features = rng.randn(num_classes * num_per_class, 2) \
            * np.array([radial_std, tangential_std])
        features[:, 0] += 1.
        labels = np.repeat(np.arange(num_classes), num_per_class)

        # Each arm is twisted more the larger the radial feature is.
        angles = rads[labels] + rate * np.exp(features[:, 0])
        rotations = np.stack([np.cos(angles), -np.sin(angles), np.sin(angles), np.cos(angles)])
        rotations = np.reshape(rotations.T, (-1, 2, 2))

        # Shuffle points and labels with the SAME permutation so they stay paired.
        perm = rng.permutation(len(labels))
        points = 2 * np.einsum("ti,tij->tj", features, rotations)[perm]
        labels = labels[perm]

    elif dataset == "swiss-rolls":
        _points, labels = make_swiss_roll(n_samples=num_samples, noise=noise)
        # Project the 3-D roll onto its (x, z) plane.
        points = np.stack([_points[:, 0], _points[:, 2]], axis=1)
        return_label = False
    elif dataset == "checkerboard":
        x1 = np.random.rand(num_samples) * 4 - 2
        _x2 = np.random.rand(num_samples) - np.random.randint(0, 2, num_samples) * 2
        x2 = _x2 + (np.floor(x1) % 2)
        points = np.concatenate([x1[:, None], x2[:, None]], 1) * 2
        return_label = False
    elif dataset == "spiral2d":
        n = np.sqrt(np.random.rand(num_samples // 2, 1)) * 540 * (2 * np.pi) / 360
        d1x = -np.cos(n) * n + np.random.rand(num_samples // 2, 1) * 0.5
        d1y = np.sin(n) * n + np.random.rand(num_samples // 2, 1) * 0.5
        # Two point-symmetric arms.
        x = np.vstack((np.hstack((d1x, d1y)), np.hstack((-d1x, -d1y)))) / 3
        points = x + np.random.randn(*x.shape) * 0.1
        return_label = False
    elif dataset == "olympics":
        w = 3.5
        h = 1.5

        def circle_generate_sample(N, noise=0.25):
            """Draw N points on the unit circle plus scaled Gaussian noise."""
            angle = np.random.uniform(high=2 * np.pi, size=N)
            random_noise = np.random.normal(scale=np.sqrt(0.2), size=(N, 2))
            return np.stack([np.cos(angle), np.sin(angle)], axis=1) + noise * random_noise

        centers = np.array([[-w, h], [0.0, h], [w, h], [-w * 0.6, -h], [w * 0.6, -h]])
        # Over-sample each ring by one point, then trim to exactly num_samples.
        rings = [circle_generate_sample(num_samples // 5 + 1, noise) + centers[i : i + 1] / 2 for i in range(5)]
        points = np.concatenate(rings)[:num_samples]
        return_label = False
    else:
        raise ValueError(f"Unknown dataset: {dataset!r}")

    x = torch.tensor(points).type(torch.float32).to(device)

    if return_label:
        return (x, labels)
    return x

# Number of points drawn per batch when rendering frames in ``visualize_cnf``.
VIZ_SAMPLES = 30_000
# Number of evenly spaced time points, t0..t1 inclusive, rendered as frames.
VIZ_TIMESTEPS = 41


def visualize_cnf(model, results_dir, dataset, p_z0, t0, t1, device):
    """Render per-timestep frames of a CNF and assemble them into a GIF.

    For each of ``VIZ_TIMESTEPS`` times between ``t0`` and ``t1`` a JPEG frame
    named ``cnf-viz-XXXXX.jpg`` is written to ``results_dir``. Each frame has
    three panels:
      * a histogram of target samples;
      * a grid pushed forward by ``model.forward_simulate``, overlaid with
        data samples integrated by ``model.predict``;
      * the estimated density of the grid.
    The frames are then combined into ``cnf-viz.gif``.

    Args:
        model: object exposing ``forward_simulate`` and ``predict``; both are
            passed to ``torchdiffeq.odeint`` as the ODE function.
        results_dir: output directory; created if missing.
        dataset: toy dataset name understood by ``get_batch``.
        p_z0: base distribution exposing ``log_prob``.
        t0, t1: integration start and end times.
        device: torch device used for all tensors.
    """
    from torchdiffeq import odeint
    viz_samples = VIZ_SAMPLES
    viz_timesteps = VIZ_TIMESTEPS
    # Pass ``device`` by keyword: the third positional parameter of
    # ``get_batch`` is ``noise``, not ``device``.
    target_sample = get_batch(viz_samples, dataset, device=device)

    os.makedirs(results_dir, exist_ok=True)
    frame_paths = []
    with torch.no_grad():

        # Generate a 100x100 grid on the base density.
        x = np.linspace(-2.5, 2.5, 100)
        y = np.linspace(-2.5, 2.5, 100)
        points = np.vstack(np.meshgrid(x, y)).reshape([2, -1]).T

        x = torch.tensor(points).type(torch.float32).to(device)
        logp_est_t0 = p_z0.log_prob(x).to(device).view(x.shape[0], 1)
        times = torch.linspace(t0, t1, viz_timesteps).to(device)

        # Forward-simulate the grid together with its log-density.
        z_t_density, logp_est_t = odeint(
            model.forward_simulate,
            (x, logp_est_t0),
            times,
            atol=1e-5,
            rtol=1e-5,
            method='dopri5',
        )

        data_samples = get_batch(viz_samples, dataset, device=device)
        samples_t = odeint(
            model.predict,
            data_samples,
            times,
            atol=1e-5,
            rtol=1e-5,
            method='dopri5',
        )

        # Loop-invariant host copies.
        target_np = target_sample.detach().cpu().numpy()
        grid_np = x.detach().cpu().numpy()

        # Create plots for each timestep
        for (t, z_density, logp_est, samples) in zip(
                np.linspace(t0, t1, viz_timesteps),
                z_t_density, logp_est_t, samples_t):

            fig = plt.figure(figsize=(12, 4), dpi=200)
            plt.tight_layout()
            plt.axis('off')
            plt.margins(0, 0)
            fig.suptitle(f'{t:.2f}s')

            ax1 = fig.add_subplot(1, 3, 1)
            ax1.set_title('Target')
            ax1.get_xaxis().set_ticks([])
            ax1.get_yaxis().set_ticks([])
            ax2 = fig.add_subplot(1, 3, 2)
            ax2.set_title('Samples and field distortion')
            ax2.get_xaxis().set_ticks([])
            ax2.get_yaxis().set_ticks([])
            ax3 = fig.add_subplot(1, 3, 3)
            ax3.set_title('Probability')
            ax3.get_xaxis().set_ticks([])
            ax3.get_yaxis().set_ticks([])

            ax1.hist2d(*target_np.T, bins=300, density=True,
                       range=[[-2.5, 2.5], [-2.5, 2.5]])

            # z_density shows how the grid is distorted by the flow.
            z_np = z_density.detach().cpu().numpy()
            ax2.scatter(z_np[:, 0], z_np[:, 1], s=1)
            samples_np = samples.detach().cpu().numpy()
            ax2.scatter(samples_np[:, 0], samples_np[:, 1], s=1, marker='x')
            ax3.tricontourf(*grid_np.T,
                            np.exp(logp_est.view(-1).detach().cpu().numpy()), 200)

            frame_name = f"cnf-viz-{int(t*1000):05d}.jpg"
            frame_path = os.path.join(results_dir, frame_name)
            plt.savefig(frame_path, pad_inches=0.2, bbox_inches='tight')
            plt.close()
            frame_paths.append(frame_path)
            print('Save ', frame_name)

        # Build the GIF from this run's frames only, so stale frames left in
        # results_dir by earlier runs are excluded. Sort by file name, as the
        # previous glob-based listing did.
        frames = [Image.open(p) for p in sorted(set(frame_paths))]
        try:
            img, *imgs = frames
            img.save(fp=os.path.join(results_dir, "cnf-viz.gif"), format='GIF', append_images=imgs,
                     save_all=True, duration=250, loop=0)
        finally:
            # Release the file handles held by the opened frames.
            for frame in frames:
                frame.close()

    print('Saved visualization animation at {}'.format(os.path.join(results_dir, "cnf-viz.gif")))




def batch_iter(X, batch_size, shuffle=False):
    """
    X: feature tensor (shape: num_instances x num_features)
    """
    if shuffle:
        idxs = torch.randperm(X.shape[0])
    else:
        idxs = torch.arange(X.shape[0])
    if X.is_cuda:
        idxs = idxs.cuda()
    for batch_idxs in idxs.split(batch_size):
        yield X[batch_idxs]

import datasets
def load_data(name):
    """Instantiate a tabular dataset class from the ``datasets`` module by name.

    Args:
        name: one of 'bsds300', 'power', 'gas', 'hepmass', 'miniboone'
            (exact, lowercase match).

    Returns:
        A new instance of the matching ``datasets`` class.

    Raises:
        ValueError: if ``name`` matches none of the supported names.
    """
    # (CLI name, class name exported by ``datasets``), checked in order.
    known = (
        ('bsds300', 'BSDS300'),
        ('power', 'POWER'),
        ('gas', 'GAS'),
        ('hepmass', 'HEPMASS'),
        ('miniboone', 'MINIBOONE'),
    )
    for key, class_name in known:
        if name == key:
            return getattr(datasets, class_name)()
    raise ValueError('Unknown dataset')

@torch.no_grad()
def _ascent_monotonically(x):
    """Return True iff the elements of ``x`` are strictly increasing.

    Sequences of length 0 or 1 count as increasing. The check stops at the
    first non-increasing pair. It runs under ``torch.no_grad()`` so tensor
    comparisons build no autograd graph.
    """
    return all(x[pos - 1] < x[pos] for pos in range(1, len(x)))

def count_parameters(model):
    """Return the total number of trainable scalar values in ``model``.

    Only parameters with ``requires_grad=True`` are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


def count_nfe(model):
    """Return the Python scalar in the buffer registered as ``_num_evals``.

    Raises:
        KeyError: if ``model`` has no buffer named exactly ``_num_evals``.
    """
    buffers_by_name = dict(model.named_buffers())
    num_evals = buffers_by_name['_num_evals']
    return num_evals.item()


def get_logger(logpath='', filepath='', package_files=None,
               displaying=True, saving=True, debug=False):
    """Configure the root logger and return it.

    Args:
        logpath: file to append log records to (used when ``saving``).
        filepath: name logged in the start-up "running" message.
        package_files: iterable of file paths. Each path and its full text
            content are written to the log. Defaults to no files.
        displaying: if True, attach a console (stderr) handler.
        saving: if True, attach a file handler on ``logpath`` in append mode.
        debug: if True, use DEBUG level; otherwise INFO.

    Returns:
        The root ``logging.Logger``.

    NOTE: every call adds new handlers to the root logger, so calling this
    more than once duplicates each log line.
    """
    if package_files is None:
        package_files = []
    logger = logging.getLogger()
    level = logging.DEBUG if debug else logging.INFO
    logger.setLevel(level)
    # create formatter
    formatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s', datefmt='%m-%d %H:%M')
    if saving:
        # Append mode ("a"): keep writing to an existing log instead of replacing it.
        log_file_handler = logging.FileHandler(logpath, mode='a')
        log_file_handler.setLevel(level)
        log_file_handler.setFormatter(formatter)
        logger.addHandler(log_file_handler)
    if displaying:
        console_handler = logging.StreamHandler()
        console_handler.setLevel(level)
        console_handler.setFormatter(formatter)
        logger.addHandler(console_handler)
    logger.info("Code {} running with following command input".format(filepath))

    for f in package_files:
        logger.info(f)
        with open(f, 'r') as package_f:
            logger.info(package_f.read())

    return logger


def collect_input_command(argv):
    """Join command-line tokens into one string, dropping a single ``--load X`` pair.

    If ``--load`` appears exactly once, it and the token after it are removed.
    If it appears zero times or more than once, every token is kept.
    """
    load_positions = [pos for pos, token in enumerate(argv) if token == "--load"]
    tokens = argv
    if len(load_positions) == 1:
        (pos,) = load_positions
        tokens = argv[:pos] + argv[pos + 2:]
    return " ".join(tokens)

class AverageMeter(object):
    """Track the latest value and the weighted running mean of a metric.

    Attributes:
        val: the most recent value passed to ``update``.
        avg: ``sum / count`` after the latest update.
        sum: accumulated ``val * n`` over all updates.
        count: accumulated weight ``n`` over all updates.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and recompute the mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

class RunningAverageMeter(object):
    """Track the latest value and an exponential moving average of a metric.

    Attributes:
        momentum: weight given to the previous average on each update.
        val: the most recent value, or None before the first update.
        avg: the current moving average (0 before the first update).
    """

    def __init__(self, momentum=0.99):
        self.momentum = momentum
        self.reset()

    def reset(self):
        """Forget all history."""
        self.val = None
        self.avg = 0

    def update(self, val):
        """Fold ``val`` into the moving average."""
        # The first observation seeds the average directly; later ones are
        # blended with ``momentum`` weight on the history.
        first_update = self.val is None
        self.avg = val if first_update else self.avg * self.momentum + val * (1 - self.momentum)
        self.val = val


