"""Differentiable particle filtering."""

import torch
import numpy as np
from isbmodel.dpf import SDEModelFilter
from isbmodel.dpf.optimal_transport import transport_resample
from accelerate import Accelerator


class DeterministicParticleFilter():
    """Filtering with discrete observations, deterministic steps.

    Particles are propagated along a fixed, linearly spaced time grid using
    the drift and diffusion supplied by ``sde_filter_model``.  At grid times
    that coincide with an observation time the particles are reweighted
    against the observations and resampled through entropy-regularized
    optimal transport.
    """

    def __init__(self, config, sde_filter_model: "SDEModelFilter", time_forward=True, device='cpu'):
        """Set up the time grid and store the model.

        Args:
            config: Configuration object; reads ``config.dataset.n_dim``,
                ``config.filter.time_diff``, ``config.filter.n_steps`` and
                ``config.filter.transport_eps``.
            sde_filter_model: Model providing ``model_type``,
                ``obs_noise_level``, ``eval_drift``, ``eval_diffusion``,
                ``eval_obs_noise`` and ``generate_init_samples``.
            time_forward: If False, the time grid is reversed so that the
                filter runs from the largest time stamp down to zero.
            device: Device on which the time grid and outputs are placed.

        Raises:
            AssertionError: If ``'image'`` occurs in the model type.
        """
        self.n_dim = config.dataset.n_dim
        self.time_diff = config.filter.time_diff
        self.n_steps = config.filter.n_steps
        self.config = config
        # Stored as a 0-dim tensor living on `device`.
        self.transport_eps = torch.Tensor([config.filter.transport_eps]).to(device)[0]

        # n_steps grid points spaced time_diff apart, starting at 0.
        time_max = self.n_steps*self.time_diff - self.time_diff
        self.time_stamps = torch.linspace(0, time_max, self.n_steps).to(device)   # fixed linear time, to be expanded later
        if not time_forward:
            self.time_stamps = torch.flip(self.time_stamps, dims=[0])
        self.time_forward = time_forward
        self.sde_filter_model = sde_filter_model
        self.model_type = self.sde_filter_model.model_type

        assert 'image' not in self.model_type, 'The ISB model for images is not supported'
        self.device = device
        print(f'Obs noise level: {self.sde_filter_model.obs_noise_level}')
        # Not used by the methods of this class visible here; kept as an
        # attribute in case other code reads it.
        self.accelerator = Accelerator()

    def reweight(self, t, particles, obs):
        """Reweight the uncontrolled particles against observations and resample.

        For every (observation, particle) pair a Gaussian log-weight is
        computed from the squared error summed over the last axis.  Each
        particle's score is the mean of its ``k`` largest log-weights over
        the observations, with ``k = max(floor(sqrt(n_obs)), 1)``.

        Args:
            t: Time passed to ``eval_obs_noise``.
            particles: Particle tensor; presumably ``(n_particles, *n_dim)``
                -- TODO confirm against callers.
            obs: Observation tensor whose first axis indexes observations.

        Returns:
            Tuple ``(resampled_particles, ot_matrix, log_weights_unnorm)``
            where ``log_weights_unnorm`` are the scores before normalization.
        """
        # Broadcast to an (n_obs, n_particles, ...) error tensor.
        particle_error = particles.unsqueeze(0) - obs.unsqueeze(1)
        obs_noise = self.sde_filter_model.eval_obs_noise(t)
        obs_noise_mult = -0.5/(obs_noise**2)
        pairwise_log_weights = obs_noise_mult*torch.sum(particle_error**2, dim=-1)

        # Sort along the observation axis so the best matches come first.
        sorted_log_weights, _ = torch.sort(pairwise_log_weights, dim=0, descending=True)
        k_nearest = max(int(np.floor(np.sqrt(obs.shape[0]))), 1)
        log_weights_unnorm = torch.mean(sorted_log_weights[:k_nearest], dim=0)

        # Normalize in log space before handing the weights to the resampler.
        log_weights = log_weights_unnorm - torch.logsumexp(log_weights_unnorm, dim=0)
        particles, ot_matrix = self.ot_resample(particles, log_weights)
        return particles, ot_matrix, log_weights_unnorm

    def sde_flow_step(self, t,  x, dt=0.01, obs=None):
        """Stochastic step.

        Performs one Euler-Maruyama update ``x + dt*drift + sqrt(dt)*diffusion*noise``
        and, when ``obs`` is given, reweights and resamples the result.

        Args:
            t: Current time, forwarded to the model's drift/diffusion.
            x: Current particles; first axis indexes particles.
            dt: Step size.  May be a Python number or a tensor; a non-tensor
                value is promoted to a 0-dim tensor with ``x``'s dtype and
                device (previously a plain float, including the default,
                crashed in ``torch.sqrt``).
            obs: Observations for this time, or None for a plain step.

        Returns:
            Tuple ``(particles, diff_particles, dt_stacked)``:
            the new particles, the difference between the mean reached by
            this step and the mean of a further drift step taken from the
            new particles at the same time ``t``, and ``dt`` repeated to
            ``(x.shape[0], *n_dim)``.
        """
        if not torch.is_tensor(dt):
            # torch.sqrt and .unsqueeze below require a tensor.
            dt = torch.as_tensor(dt, dtype=x.dtype, device=x.device)

        drift = self.sde_filter_model.eval_drift(t, x)
        diffusions = self.sde_filter_model.eval_diffusion(t)
        rand = torch.sqrt(dt)*diffusions*torch.randn(*x.shape, device=self.device)

        new_mean = x + dt*drift
        new_y = new_mean + rand
        dt_stacked = dt.unsqueeze(0).repeat(x.shape[0], *self.n_dim)

        if obs is None:
            next_mean = new_y + self.sde_filter_model.eval_drift(t, new_y)*dt
            return new_y, new_mean - next_mean, dt_stacked

        # Observation available: resample, and push the pre-noise means
        # through the same transport matrix so they stay paired with the
        # resampled particles.
        # NOTE(review): the drift at `new_y` is no longer evaluated on this
        # path because its result was discarded; this assumes `eval_drift`
        # has no side effects -- confirm.
        particles, ot_matrix, _ = self.reweight(t, new_y, obs)
        new_mean_flat = torch.flatten(new_mean, start_dim=-(len(self.n_dim)))
        new_mean = torch.bmm(ot_matrix, new_mean_flat.unsqueeze(0)).squeeze(0)
        new_mean = new_mean.reshape(-1, *self.n_dim)
        next_mean = particles + self.sde_filter_model.eval_drift(t, particles)*dt  # notice time
        return particles, new_mean - next_mean, dt_stacked

    def ot_resample(self, particles, log_weights, stable=True):
        """Perform a resampling step.

        Uses entropy-regularized OT to transform weights and particles
        to uniformly-weighted particles.

        Args:
            particles: Particles to resample; a leading batch axis of size
                one is added before calling ``transport_resample``.
            log_weights: Log-weights of the particles (batched the same way).
            stable: Forwarded to ``transport_resample``.

        Returns:
            Tuple ``(new_particles, transport_matrix)``.  The particles have
            the batch axis removed again; the transport matrix is returned
            exactly as produced by ``transport_resample``.
        """
        transported_particles, _, transport_matrix = transport_resample(
            particles.unsqueeze(0),
            log_weights.unsqueeze(0),
            eps=self.transport_eps,
            stable=stable,
        )
        return transported_particles.squeeze(0), transport_matrix

    def generate_particles(self, obs_times, obs_samples, n_particles, init_points=None):
        """Generate particles.

        Assumption: obs_times is a subset of self.time_stamps.

        Observation times are consumed in order: the next pending entry of
        ``obs_times`` is compared with the current grid time and, on a match,
        ``obs_samples[:, idx]`` is used for that step.  The last grid time is
        never stepped from, so an observation located there is not used.

        Args:
            obs_times: Observation times, ordered like the time grid.  When
                non-empty, ``self.time_stamps`` is cast to its dtype (this
                cast persists on the instance).
            obs_samples: Observations indexed as ``obs_samples[:, idx]``;
                not accessed when no grid time matches an observation time.
            n_particles: Number of particles.
            init_points: Optional initial particles; if None they are drawn
                from ``sde_filter_model.generate_init_samples``.

        Returns:
            Tuple ``(particle_output, diff_output, times_batched, dt_output)``
            with shapes ``(n_times, n_particles, *n_dim)``,
            ``(n_times - 1, n_particles, *n_dim)``,
            ``(n_times, n_particles)`` and
            ``(n_times - 1, n_particles, *n_dim)`` respectively.
        """
        if init_points is None:
            particles = self.sde_filter_model.generate_init_samples(n_samples=n_particles)
        else:
            particles = init_points

        n_times = len(self.time_stamps)
        step_shape = (n_times - 1, n_particles, *self.n_dim)
        particle_output = torch.empty((n_times, n_particles, *self.n_dim), device=self.device, dtype=particles.dtype)
        diff_output = torch.empty(step_shape, device=self.device, dtype=particles.dtype)
        dt_output = torch.empty(step_shape, device=self.device, dtype=particles.dtype)

        n_obs = len(obs_times)
        if n_obs != 0:
            # Match dtypes so that torch.isclose below can compare them.
            self.time_stamps = self.time_stamps.to(obs_times.dtype)
        obs_idx = 0
        particle_output[0] = particles

        for i, t in enumerate(self.time_stamps[:-1]):
            # dt is kept positive regardless of the direction of time.
            if self.time_forward:
                dt = self.time_stamps[i+1] - t
            else:
                dt = t - self.time_stamps[i+1]

            # Only the next pending observation time is checked.
            is_obs_time = obs_idx < n_obs and bool(torch.isclose(t, obs_times[obs_idx]))
            obs = obs_samples[:, obs_idx] if is_obs_time else None

            particles, diff_particle, dt_stacked = self.sde_flow_step(t, particles, dt=dt, obs=obs)
            if is_obs_time:
                obs_idx += 1
            diff_output[i] = diff_particle
            dt_output[i] = dt_stacked
            particle_output[i+1] = particles

        times_batched = self.time_stamps.unsqueeze(1).repeat(1, n_particles)
        return particle_output, diff_output, times_batched, dt_output


            