import numpy as np
import scipy
import tqdm
from scipy.integrate import solve_ivp
from scipy.spatial import KDTree
from scipy.stats import binned_statistic_dd, binned_statistic

from pdPINN.util.data_containers import VectorfieldFuns

import pandas as pd
from typing import Tuple, Optional, Callable, List
from scipy.stats import multivariate_t
from tqdm import tqdm


def sample_particle_clusters(number_of_swarms=20,
                             avg_number_of_flocks=30,
                             birds_per_flock=10,
                             xmin: int = -4, xmax: int = 4, ymin: int = -4, ymax: int = 4):
    """
    Hierarchical sampling of 2-D particle positions (swarms -> flocks -> birds).

    * Swarm centres: uniform on [-0.8, 0.8]^2 plus N(0, 0.3^2) jitter, then shifted by +1.
    * Flocks per swarm: Poisson-distributed, with a rate drawn from
      N(avg_number_of_flocks, avg_number_of_flocks^2) and clipped to [1, inf).
      Flock centres are N(swarm_centre, 0.1^2) per coordinate.
    * Birds per flock: Poisson(birds_per_flock) draws from a bivariate Student-t
      (df=5) centred on the flock centre.
    * A fixed 400 x 400 grid over [-13, 13]^2 is appended as background particles.

    Args:
        number_of_swarms: number of swarm centres.
        avg_number_of_flocks: mean (and standard deviation) of the normal draw that sets
            the Poisson rate for the number of flocks in each swarm.
        birds_per_flock: Poisson mean of the number of birds per flock.
        xmin, xmax, ymin, ymax: only the widths ``xmax - xmin`` and ``ymax - ymin`` are used.
            They scale the diagonal shape matrix of the per-flock Student-t and do not
            bound the sampled positions.

    Returns:
        Array of shape ``(n_particles, 2)``: the sampled birds followed by the background grid.
    """
    x_diff = xmax - xmin
    y_diff = ymax - ymin

    swarm_noise = np.random.normal(scale=0.3, size=(number_of_swarms, 2))
    swarm_centers = np.random.uniform(0.1, 0.9, (number_of_swarms, 2)) * 2 - 1 + swarm_noise
    swarm_centers += 1

    num_birds_sampler = scipy.stats.poisson(birds_per_flock)

    # Seed with an empty (0, 2) array so that zero swarms still concatenates cleanly.
    flock_centers = [np.empty((0, 2))]
    for swarm_center in swarm_centers:
        num_flocks_sampler = scipy.stats.poisson(
            np.clip(np.random.normal(avg_number_of_flocks, avg_number_of_flocks), 1, np.inf))
        flock_center_dist = scipy.stats.norm(swarm_center, 0.1 * np.ones_like(swarm_center))
        flock_centers.append(flock_center_dist.rvs((num_flocks_sampler.rvs(), 2)))
    flock_centers = np.concatenate(flock_centers, 0)

    var = [[0.15 ** 2 * x_diff / 500, 0.], [0., .1 ** 2 * y_diff / 500]]

    distributions = [
        multivariate_t(flock_center, var, df=5.)
        for flock_center in flock_centers]

    # multivariate_t.rvs squeezes its output: a single draw comes back with shape (2,)
    # instead of (1, 2). Reshape every flock to (-1, 2) so the concatenation below
    # never fails on a one-bird flock.
    bird_samples = [np.empty((0, 2))]
    bird_samples += [np.reshape(dist.rvs(num_birds_sampler.rvs()), (-1, 2)) for dist in distributions]
    samples = np.concatenate(bird_samples, 0)

    background = np.mgrid[-13:13:400j, -13:13:400j].reshape(2, -1).T
    samples = np.concatenate([samples, background], 0)
    return samples


def generate_radar_positions(altitude_layers, altitudes, num_radars, extent, uniform_samples=False, noise=0.04):
    """Place radars in the xy-plane and replicate each one over every altitude layer.

    :param altitude_layers: number of altitude levels per radar.
    :param altitudes: altitude values; must broadcast to ``(n_radars, 1, altitude_layers)``
        (e.g. a 1-D array of length ``altitude_layers``).
    :param num_radars: requested number of radars. On the regular grid
        (``uniform_samples=False``) only ``int(sqrt(num_radars)) ** 2`` radars are placed.
    :param extent: ``(xmin, xmax, ymin, ymax)`` of the domain.
    :param uniform_samples: if True, draw radar xy positions uniformly over the extent;
        otherwise lay them on a regular grid inset by 0.1 from each border.
    :param noise: standard deviation of the Gaussian jitter added to every xy position.
        Positions are clipped back into the extent afterwards.
    :return: array of shape ``(n_radars * altitude_layers, 3)`` with rows ``(x, y, z)``;
        all altitude layers of one radar are contiguous.
    """
    xmin, xmax, ymin, ymax = extent

    if not uniform_samples:
        per_side = int(np.sqrt(num_radars))
        grid_x, grid_y = np.meshgrid(np.linspace(xmin + .1, xmax - .1, per_side),
                                     np.linspace(ymin + .1, ymax - .1, per_side))
        radar_xy = np.stack((grid_x, grid_y), -1).reshape(-1, 2)
    else:
        radar_xy = np.random.uniform([xmin, ymin], [xmax, ymax], size=[num_radars, 2])

    # Jitter in place (keeps the dtype of radar_xy), then clip each axis to the domain.
    radar_xy += np.random.normal(0, noise, radar_xy.shape)
    for axis, (low, high) in enumerate(((xmin, xmax), (ymin, ymax))):
        radar_xy[:, axis] = np.clip(radar_xy[:, axis], low, high)

    n_radars = radar_xy.shape[0]
    xy_layers = np.broadcast_to(radar_xy[:, :, np.newaxis], (n_radars, 2, altitude_layers))
    z_layers = np.broadcast_to(altitudes, (n_radars, 1, altitude_layers))
    xyz = np.concatenate((xy_layers, z_layers), axis=1)  # (n_radars, 3, altitude_layers)
    return xyz.transpose(0, 2, 1).reshape(-1, 3)


class ParticleCollection:
    """Particles advected by a 3-D velocity field, with helpers to aggregate them.

    Initial xy positions come from ``sample_particle_clusters``. Initial altitudes
    are ``|N(0.1, 0.03^2)| + 0.05``. The whole trajectory is integrated in the
    constructor on the time grid ``self.tsteps`` with Heun's method and kept in memory:

    * ``particle_positions[k]`` holds the positions at ``tsteps[k]``, shape ``(n_particles, 3)``.
    * ``particle_velocity[k]`` holds the velocity field evaluated at ``particle_positions[k]``
      and ``tsteps[k]``.

    Values at arbitrary times are obtained by linear interpolation via ``self[t]``
    (positions) and ``self.vel(t)`` (velocities).
    """

    def __init__(self, number_of_flocks, number_of_swarms, birds_per_flock,
                 vf_funs: "VectorfieldFuns", tmax, dt=1e-3):
        """
        Args:
            number_of_flocks: forwarded as ``avg_number_of_flocks`` to ``sample_particle_clusters``.
            number_of_swarms: forwarded to ``sample_particle_clusters``.
            birds_per_flock: forwarded to ``sample_particle_clusters``.
            vf_funs: object whose ``vf_3d(x, y, z, t)`` returns the velocity at the
                given points, expected as an ``(N, 3)`` array.
            tmax: final simulation time.
            dt: integration time step.
        """
        self.id_counter = 0
        # Nominal particle count derived from the arguments; the actual number of
        # simulated particles is ``self.pos_xyz_0.shape[0]``.
        self.num_birds = number_of_flocks * number_of_swarms * birds_per_flock

        self.dt = dt
        self.tmax = tmax
        self.tsteps = np.arange(0., tmax + self.dt, self.dt)
        self.step = 0

        self.vf_funs = vf_funs

        self.pos_xy_0 = sample_particle_clusters(avg_number_of_flocks=number_of_flocks,
                                                 number_of_swarms=number_of_swarms,
                                                 birds_per_flock=birds_per_flock,
                                                 xmin=-1, xmax=1, ymin=-1, ymax=1)
        self.pos_z_0 = np.abs(np.random.normal(0.1, 0.03, (self.pos_xy_0.shape[0], 1))) + 0.05

        self.pos_xyz_0 = np.concatenate([self.pos_xy_0, self.pos_z_0], -1)

        self.particle_positions = np.zeros((len(self.tsteps), *self.pos_xyz_0.shape))
        self.particle_velocity = np.zeros((len(self.tsteps), *self.pos_xyz_0.shape))
        self.particle_positions[0, ...] = self.pos_xyz_0
        self.particle_velocity[0, ...] = self.calc_velocity(*self.pos_xyz_0.T, self.tsteps[0])

        self.simulate()

    def calc_velocity(self, x, y, z, t):
        """Evaluate the velocity field at coordinates ``x, y, z`` and time ``t``."""
        return self.vf_funs.vf_3d(x, y, z, t)

    def move(self):
        """Advance the simulation by one step with Heun's method (explicit trapezoidal rule).

        The velocity at the current positions is read from ``particle_velocity[step]``,
        which always holds the field at the stored positions and time. It was written
        by the constructor for step 0 and by the previous call otherwise, so each step
        costs two field evaluations.

        Returns:
            ``(next_pos, next_vel)``: the new positions and the velocity field at them
            at the new time. Both are also stored at index ``step + 1``.
        """
        t_next = self.tsteps[self.step + 1]
        cur_pos = self.particle_positions[self.step]
        cur_vel = self.particle_velocity[self.step]

        # Predictor: explicit Euler step; its slope is evaluated at the end of the interval.
        predicted_pos = cur_pos + self.dt * cur_vel
        predicted_vel = self.calc_velocity(*predicted_pos.T, t_next)

        # Corrector: average of the two slopes (a single factor of 1/2).
        next_pos = cur_pos + self.dt * 0.5 * (cur_vel + predicted_vel)
        # Store the velocity at the corrected positions, so positions and velocities stay consistent.
        next_vel = self.calc_velocity(*next_pos.T, t_next)

        self.particle_positions[self.step + 1, ...] = next_pos
        self.particle_velocity[self.step + 1, ...] = next_vel
        self.step += 1
        return next_pos, next_vel

    def simulate(self):
        """Integrate the trajectory over every interval of ``self.tsteps``."""
        for _ in tqdm(self.tsteps[:-1]):
            self.move()

    def _interpolate_in_time(self, series, queried_t):
        """Linearly interpolate the per-time-step array ``series`` at time ``queried_t``.

        Times at or after the last simulated time return the final entry. Times
        before the first simulated time raise IndexError.
        """
        t_first, t_last = self.tsteps[0], self.tsteps[-1]
        if queried_t < t_first:
            raise IndexError(f"queried time {queried_t} precedes the first simulated time {t_first}")
        if queried_t >= t_last:
            return series[-1]
        # Grid index at or before queried_t; the clamp guards against float round-off.
        step = min(int((queried_t - t_first) // self.dt), len(self.tsteps) - 2)
        weight = (queried_t - self.tsteps[step]) / self.dt
        return series[step] * (1.0 - weight) + series[step + 1] * weight

    def __getitem__(self, queried_t):
        """Particle positions at time ``queried_t``, linearly interpolated between steps."""
        return self._interpolate_in_time(self.particle_positions, queried_t)

    def vel(self, queried_t):
        """Particle velocities at time ``queried_t``, linearly interpolated between steps."""
        return self._interpolate_in_time(self.particle_velocity, queried_t)

    def density_at_radars(self, time, radar_positions, radius_xy, altitude_bins) -> pd.DataFrame:
        """Aggregate particles inside vertical cylinders around each radar at a given time.

        Args:
            time: query time; positions and velocities are interpolated in time.
            radar_positions: array whose first two columns are radar xy positions.
                Rows with identical xy (e.g. one per altitude layer) are merged.
            radius_xy: horizontal radius of each radar cylinder.
            altitude_bins: bin edges along z. The bin volume and centres use the width
                of the first bin, so the edges are assumed to be equally spaced.

        Returns:
            DataFrame with one row per (radar, altitude bin), containing:

            * ``x, y``: the radar position; ``z``: the bin centre; ``t``: the query time.
            * ``count`` and ``mass``: the particle count in the bin.
            * ``density``: count per cylinder-slice volume.
            * ``u, v, w``: mean particle velocity, NaN for empty bins.
            * ``u_atradar, v_atradar, w_atradar``: the velocity field evaluated at the bin centre.
        """
        radars_xy = np.unique(radar_positions[:, :2], axis=0)

        points = self[time]
        velocity = self.vel(time)

        volume = np.pi * radius_xy ** 2 * np.round(altitude_bins[1] - altitude_bins[0], 8)

        tree = KDTree(points[:, :2])
        points_within_xy_idx = tree.query_ball_point(radars_xy, radius_xy)

        df_list = []
        for radar_xy, idx in zip(radars_xy, points_within_xy_idx):
            points_within_radar = points[idx]
            velocity_within_radar = velocity[idx]
            if len(points_within_radar) > 0:
                vel_u, vel_v, vel_w = binned_statistic(x=points_within_radar[:, -1],
                                                       values=[velocity_within_radar[:, 0],
                                                               velocity_within_radar[:, 1],
                                                               velocity_within_radar[:, 2]],
                                                       bins=altitude_bins, statistic="mean")[0]
            else:
                vel_u, vel_v, vel_w = np.full([3, len(altitude_bins) - 1], np.nan)
            binned_num_particles, bins = np.histogram(points_within_radar[:, -1], bins=altitude_bins)
            centroid_bins = bins[:-1] + (bins[1] - bins[0]) / 2

            true_velocity_at_radar = self.calc_velocity(np.broadcast_to(radar_xy[0], centroid_bins.shape),
                                                        np.broadcast_to(radar_xy[1], centroid_bins.shape),
                                                        centroid_bins, time)

            df_list.append(pd.DataFrame(dict(x=np.broadcast_to(radar_xy[0], centroid_bins.shape),
                                             y=np.broadcast_to(radar_xy[1], centroid_bins.shape),
                                             z=centroid_bins,
                                             t=np.broadcast_to(time, centroid_bins.shape),
                                             count=binned_num_particles,
                                             mass=binned_num_particles,
                                             density=binned_num_particles / volume,
                                             u=vel_u, v=vel_v, w=vel_w,
                                             u_atradar=true_velocity_at_radar[:, 0],
                                             v_atradar=true_velocity_at_radar[:, 1],
                                             w_atradar=true_velocity_at_radar[:, 2]))
                           )

        radar_df = pd.concat(df_list)
        return radar_df

    def density_on_grid(self, time: float, grid_list: List[np.ndarray]) -> pd.DataFrame:
        """Bin particles on a regular 3-D grid at a given time.

        Args:
            time: query time; positions and velocities are interpolated in time.
            grid_list: bin edges ``[x_edges, y_edges, z_edges]``. The cell volume uses
                the first width along each axis, so edges are assumed equally spaced.

        Returns:
            DataFrame with one row per grid cell, holding:

            * ``x, y, z``: the cell centre; ``t``: the query time.
            * ``mass``: the particle count; ``density``: count per cell volume.
            * ``u, v, w``: mean particle velocity, NaN for empty cells.
            * ``u_atradar, v_atradar, w_atradar``: the velocity field at the cell centre.

            Rows follow ``np.meshgrid`` default ('xy') order, i.e. axes (y, x, z)
            with z varying fastest.
        """
        points = self[time]
        velocities = self.vel(time)

        hist, edges = np.histogramdd(points, bins=grid_list)
        # histogramdd is (x, y, z)-indexed; swap to match meshgrid's 'xy' (y, x, z) layout below.
        hist = np.swapaxes(hist, 0, 1)
        dx, dy, dz = [np.round(edge[1] - edge[0], 8) for edge in edges]
        volume = dx * dy * dz

        vel_u, vel_v, vel_w = binned_statistic_dd(sample=points,
                                                  values=[velocities[:, 0], velocities[:, 1], velocities[:, 2]],
                                                  bins=grid_list, statistic="mean")[0]
        vel_u, vel_v, vel_w = [np.swapaxes(vel, 0, 1) for vel in [vel_u, vel_v, vel_w]]

        edge_means = [np.mean(np.vstack([edge[0:-1], edge[1:]]), axis=0) for edge in edges]
        centroid_mesh = np.meshgrid(*edge_means)
        centroid_mesh_xyz = np.stack(centroid_mesh, -1)
        true_velocity_at_pos = self.calc_velocity(centroid_mesh_xyz[..., 0].flatten(),
                                                  centroid_mesh_xyz[..., 1].flatten(),
                                                  centroid_mesh_xyz[..., 2].flatten(), time)

        grid_df = pd.DataFrame(dict(x=centroid_mesh_xyz[..., 0].flatten(),
                                    y=centroid_mesh_xyz[..., 1].flatten(),
                                    z=centroid_mesh_xyz[..., 2].flatten(),
                                    t=np.broadcast_to(time, hist.flatten().shape),
                                    mass=hist.flatten(),
                                    density=hist.flatten() / volume,
                                    u=vel_u.flatten(), v=vel_v.flatten(), w=vel_w.flatten(),
                                    u_atradar=true_velocity_at_pos[:, 0],
                                    v_atradar=true_velocity_at_pos[:, 1],
                                    w_atradar=true_velocity_at_pos[:, 2]
                                    ))
        return grid_df
