import os
import numpy as np
import torch
import torch.nn.functional as F
from .base import NeuralFluidBase
from .networks import get_network
from .laplacian_solver import factorized_laplacian_solver
from utils.diff_ops import curl2d_fdiff, laplace, divergence, jacobian, gradient, curl2d
from utils.model_utils import sample_random_2D, sample_boundary_separate, sample_uniform_2D
from utils.vis_utils import draw_scalar_field2D, draw_vector_field2D, save_figure


class NeuralFluidSplit(NeuralFluidBase):
    """Operator-splitting neural fluid solver.

    ``step`` runs velocity advection, a (currently disabled) diffusion stage,
    a pressure solve and a velocity projection, each fitted as a separate
    optimization of the corresponding neural field. When density is enabled
    it is advected last.
    """

    def __init__(self, cfg):
        super().__init__(cfg)

        # Pressure network; the (2, 1) arguments presumably mean 2D input ->
        # scalar output (the field is indexed with [..., 0] elsewhere).
        self.pressure_field = get_network(cfg, 2, 1).cuda()

        # self.use_discrete_pressure = cfg.use_disc_p

        if self.cfg.debug:
            self.debug_dir = os.path.join(self.cfg.results_dir, 'debug')
            # exist_ok avoids the check-then-create race of exists() + makedirs()
            os.makedirs(self.debug_dir, exist_ok=True)

    @property
    def _trainable_networks(self):
        """Mapping name -> network for every trainable field of this solver."""
        if self.use_density:
            return {'velocity': self.velocity_field, 'pressure': self.pressure_field, 'density': self.density_field}
        return {'velocity': self.velocity_field, 'pressure': self.pressure_field}

    def step(self):
        """Advance the simulation by one time step (one optimizer per stage)."""
        self.tb = self.create_tb(f"t{self.timestep:03d}")
        self.create_optimizer()
        self.advect_velocity()

        self.create_optimizer()
        self.diffuse_velocity()  # currently a no-op

        self.create_optimizer()
        self.solve_pressure()

        self.create_optimizer()
        self.project_velocity()

        if self.use_density:
            self.create_optimizer()
            self.advect_density()

    ################# shared helpers #####################
    def _is_vis_iter(self):
        """Return True on the first training iteration and every `vis_frequency`-th one."""
        it = self.tb.train_iter
        return it == 0 or (it + 1) % self.cfg.vis_frequency == 0

    @staticmethod
    def _num_boundary_samples(n_interior):
        """Number of boundary samples to draw: 1% of `n_interior`, but at least 1.

        Without the lower bound, fewer than 100 interior samples would request
        zero boundary points; the mean over such an (presumably empty) batch is
        NaN and would poison the whole loss.
        """
        return max(1, n_interior // 100)

    def _zero_velocity_bc_loss(self, n_interior):
        """Boundary loss driving one velocity component to zero on the boundary.

        Penalizes the x-velocity at 'horizontal' boundary samples and the
        y-velocity at 'vertical' boundary samples (presumably the wall-normal
        components).

        Args:
            n_interior: number of interior training samples; 1% of it (min 1)
                is drawn on each boundary side.

        Returns:
            Scalar tensor with the boundary loss.
        """
        # FIXME: hard-coded 1% boundary sampling and fixed factor 1.0 for boundary loss
        n_bc = self._num_boundary_samples(n_interior)
        bc_sample_x = sample_boundary_separate(n_bc, side='horizontal', device=self.device).requires_grad_(True)
        bc_sample_y = sample_boundary_separate(n_bc, side='vertical', device=self.device).requires_grad_(True)
        vel_x = self.velocity_field(bc_sample_x)[..., 0]
        vel_y = self.velocity_field(bc_sample_y)[..., 1]
        return (torch.mean(vel_x ** 2) + torch.mean(vel_y ** 2)) * 1.0

    ################# density advection #####################
    @NeuralFluidBase._training_loop
    def _advect_density(self, attr='density'):
        """Density advection: dd/dt = -(u . grad) d, via semi-Lagrangian backtracking.

        Fits `density_field` to `density_field_prev` evaluated at the
        backtracked positions `x - u(x) * dt`.

        Returns:
            dict with key 'main' holding the MSE loss.

        Raises:
            NotImplementedError: if `cfg.time_integration` is not 'semi_lag'.
        """
        if self._is_vis_iter():
            self._vis_advect_density()

        samples = self.sample_in_training()

        with torch.no_grad():
            vel = self.velocity_field(samples).detach()
        curr_d = self.density_field(samples)

        if self.cfg.time_integration == 'semi_lag':
            # backtracking
            # backtracked_position = samples - prev_u[:, [1,0]] * self.cfg.dt
            backtracked_position = samples - vel * self.cfg.dt
            # FIXME: this looks not right, points that are backtracked outside should be zero
            backtracked_position = torch.clamp(backtracked_position, min=-1.0, max=1.0)

            with torch.no_grad():
                advected_d = self.density_field_prev(backtracked_position).detach()

            loss = torch.mean((curr_d - advected_d) ** 2)
            loss_dict = {'main': loss}

        else:
            raise NotImplementedError

        return loss_dict

    def advect_density(self):
        """Snapshot the density field into `density_field_prev`, then advect it."""
        self.density_field_prev.load_state_dict(self.density_field.state_dict())
        self._advect_density(attr='density')

    ################# velocity advection #####################
    @NeuralFluidBase._training_loop
    def _advect_velocity(self, attr='velocity'):
        """Velocity advection: du/dt = -(u . grad) u.

        'implicit' minimizes the residual of the PDE at the current velocity;
        'semi_lag' fits the velocity to the previous field sampled at the
        backtracked positions `x - u_prev(x) * dt`. With
        `boundary_cond == 'zero'` a boundary loss is added under key 'bc'.

        Returns:
            dict with key 'main' (and optionally 'bc') holding loss tensors.

        Raises:
            NotImplementedError: for any other `cfg.time_integration`.
        """
        if self._is_vis_iter():
            self._vis_advect_velocity()

        # samples = sample_random_2D(self.sample_resolution ** 2, device=self.device).requires_grad_(True) # FIXME: only random samples?
        samples = self.sample_in_training()

        with torch.no_grad():
            prev_u = self.velocity_field_prev(samples).detach()
        curr_u = self.velocity_field(samples)

        if self.cfg.time_integration == 'implicit':
            # solve (u . grad) u
            # NOTE(review): this branch uses self.dt while semi_lag uses self.cfg.dt;
            # confirm the base class defines self.dt consistently.
            dudt = (curr_u - prev_u) / self.dt
            jac_u, _ = jacobian(curr_u, samples)  # (N, 2, 2)
            u_gradu = torch.sum(curr_u.unsqueeze(1) * jac_u, dim=-1)  # (..., 2)

            loss = torch.mean((dudt + u_gradu) ** 2)
            loss_dict = {'main': loss}

        elif self.cfg.time_integration == 'semi_lag':
            # backtracking
            # backtracked_position = samples - prev_u[:, [1,0]] * self.cfg.dt
            backtracked_position = samples - prev_u * self.cfg.dt
            # FIXME: this looks not right, points that are backtracked outside should be zero
            backtracked_position = torch.clamp(backtracked_position, min=-1.0, max=1.0)

            with torch.no_grad():
                advected_u = self.velocity_field_prev(backtracked_position).detach()

            loss = torch.mean((curr_u - advected_u) ** 2)
            loss_dict = {'main': loss}

        else:
            raise NotImplementedError

        if self.boundary_cond == 'zero':
            loss_dict['bc'] = self._zero_velocity_bc_loss(samples.shape[0])

        return loss_dict

    def advect_velocity(self):
        """Snapshot the velocity into `velocity_field_prev`, advect it, optionally save a debug plot."""
        self.velocity_field_prev.load_state_dict(self.velocity_field.state_dict())
        self._advect_velocity(attr='velocity')
        if self.cfg.debug:
            fig = self.draw_velocity(self.cfg.vis_resolution)
            save_path = os.path.join(self.debug_dir, f't{self.timestep:03d}_advect.png')
            save_figure(fig, save_path)

    ################# velocity diffusion (disabled) #####################
    @NeuralFluidBase._training_loop
    def _diffuse_velocity(self, attr='velocity'):
        raise NotImplementedError

    def diffuse_velocity(self):
        """Diffusion stage; intentionally a no-op for now."""
        pass
        # self._diffuse_velocity(attr='velocity')

    ################# projection #####################
    @NeuralFluidBase._training_loop
    def _project_velocity(self, attr='velocity'):
        """Projection step for velocity: u <- u - grad(p).

        Fits `velocity_field` to `u_prev - grad(p)`, where the pressure
        gradient is detached so only the velocity is trained. With
        `boundary_cond == 'zero'` a boundary loss is added under key 'bc'.

        Returns:
            dict with key 'main' (and optionally 'bc') holding loss tensors.
        """
        if self._is_vis_iter():
            self._vis_project_velocity()

        # samples = sample_random_2D(self.sample_resolution ** 2, device=self.device).requires_grad_(True) # FIXME: only random samples?
        samples = self.sample_in_training()

        with torch.no_grad():
            prev_u = self.velocity_field_prev(samples).detach()

        # if self.use_discrete_pressure:
        #     grad_p = F.grid_sample(self.grad_p_grid.permute(2, 0, 1).unsqueeze(0),
        #         samples.flip(-1).unsqueeze(0).unsqueeze(0), align_corners=False) # FIXME: 1) if flip ok? 2) border values might not be exact 0
        #     grad_p = grad_p.squeeze(0).permute(1, 2, 0).squeeze(0) # (N, 2)
        # else:
        p = self.pressure_field(samples)
        # gradient w.r.t. sample positions needs the graph; detach afterwards
        grad_p = gradient(p, samples).detach()

        target_u = prev_u - grad_p
        curr_u = self.velocity_field(samples)
        loss = torch.mean((curr_u - target_u) ** 2)
        loss_dict = {'main': loss}

        if self.boundary_cond == 'zero':
            loss_dict['bc'] = self._zero_velocity_bc_loss(samples.shape[0])
        return loss_dict

    def project_velocity(self):
        """Snapshot the velocity, project it with the current pressure, optionally save a debug plot."""
        self.velocity_field_prev.load_state_dict(self.velocity_field.state_dict())
        self._project_velocity(attr='velocity')
        if self.cfg.debug:
            fig = self.draw_velocity(self.cfg.vis_resolution)
            save_path = os.path.join(self.debug_dir, f't{self.timestep:03d}_project.png')
            save_figure(fig, save_path)

    ################# pressure solve #####################
    @NeuralFluidBase._training_loop
    def _solve_pressure(self, attr='pressure'):
        """Forward computation for the pressure solve: div u = lap p.

        The velocity divergence is detached so only the pressure is trained.
        Unless `boundary_cond == 'none'`, a Neumann boundary loss is added
        under key 'bc'.

        Returns:
            dict with key 'main' (and optionally 'bc') holding loss tensors.
        """
        # FIXME: consider directly model grad_p
        if self._is_vis_iter():
            self._vis_solve_pressure()

        # samples = sample_random_2D(self.sample_resolution ** 2, device=self.device).requires_grad_(True) # FIXME: only random samples?
        samples = self.sample_in_training()

        out_u = self.velocity_field(samples)
        div_u = divergence(out_u, samples).detach()
        out_p = self.pressure_field(samples)
        lap_p = laplace(out_p, samples)

        loss = torch.mean((div_u - lap_p) ** 2)  # FIXME: assume rho=1 here
        loss_dict = {'main': loss}

        if self.boundary_cond != 'none':
            # NOTE: Neumann boundary condition, grad(p) . n = 0 with n the boundary normal
            n_bc = self._num_boundary_samples(self.sample_resolution ** 2)
            bc_sample_x = sample_boundary_separate(n_bc, side='horizontal', device=self.device).requires_grad_(True)
            bc_sample_y = sample_boundary_separate(n_bc, side='vertical', device=self.device).requires_grad_(True)
            grad_px = gradient(self.pressure_field(bc_sample_x), bc_sample_x)[..., 0]
            grad_py = gradient(self.pressure_field(bc_sample_y), bc_sample_y)[..., 1]

            bc_loss = torch.mean(grad_px ** 2) + torch.mean(grad_py ** 2)
            loss_dict['bc'] = bc_loss

        return loss_dict

    def solve_pressure(self):
        """Solve pressure from the velocity field (div u = lap p), optionally save a debug plot."""
        self._solve_pressure(attr='pressure')
        if self.cfg.debug:
            fig = self.draw_pressure(self.cfg.vis_resolution)
            save_path = os.path.join(self.debug_dir, f't{self.timestep:03d}_pressure.png')
            save_figure(fig, save_path)

    def sample_pressure_field(self, resolution, to_numpy=True, with_boundary=False, return_samples=False, require_grad=False):
        """Evaluate the pressure network on a uniform 2D grid.

        Args:
            resolution: grid resolution passed to `sample_uniform_2D`.
            to_numpy: convert outputs (and samples) to detached numpy arrays.
            with_boundary: forwarded to `sample_uniform_2D`.
            return_samples: also return the grid sample positions.
            require_grad: mark the grid samples as requiring grad.

        Returns:
            Pressure values, or `(pressure, samples)` if `return_samples`.
        """
        grid_samples = sample_uniform_2D(resolution, with_boundary=with_boundary, device=self.device)
        if require_grad:
            grid_samples = grid_samples.requires_grad_(True)

        out = self.pressure_field(grid_samples)
        if to_numpy:
            out = out.detach().cpu().numpy()
            grid_samples = grid_samples.detach().cpu().numpy()
        if return_samples:
            return out, grid_samples
        return out

    def draw_pressure(self, resolution):
        """Return a figure of the pressure field sampled at `resolution`."""
        p = self.sample_pressure_field(resolution, to_numpy=True)
        fig = draw_scalar_field2D(p)
        return fig

    ################# visualization during training #####################
    def _vis_advect_density(self):
        """Log the current density field to tensorboard."""
        fig = self.draw_density(self.sample_resolution)
        self.tb.add_figure('density', fig, global_step=self.tb.train_iter)

    def _vis_advect_velocity(self):
        """Log the per-point semi-Lagrangian residual and current velocity to tensorboard."""
        # Pure visualization: no derivatives are needed, so skip building a graph.
        with torch.no_grad():
            grid_samples = sample_uniform_2D(self.vis_resolution, device=self.device)
            prev_u_grid = self.velocity_field_prev(grid_samples)
            curr_u_grid = self.velocity_field(grid_samples)

            backtracked_position = grid_samples - prev_u_grid * self.cfg.dt
            # FIXME: this looks not right, points that are backtracked outside should be zero
            backtracked_position = torch.clamp(backtracked_position, min=-1.0, max=1.0)
            advected_u = self.velocity_field_prev(backtracked_position)

            loss_grid = torch.mean((curr_u_grid - advected_u) ** 2, dim=-1).detach().cpu().numpy()

        curr_u_grid = curr_u_grid.detach().cpu().numpy()
        grid_samples = grid_samples.detach().cpu().numpy()
        x, y = grid_samples[..., 0], grid_samples[..., 1]
        self.tb.add_figure('adv_mse', draw_scalar_field2D(loss_grid), global_step=self.tb.train_iter)
        self.tb.add_figure('adv_curr_u', draw_vector_field2D(curr_u_grid[..., 0], curr_u_grid[..., 1], x, y), global_step=self.tb.train_iter)

    def _vis_solve_pressure(self):
        """Log velocity, divergence, pressure, its derivatives and the residual to tensorboard."""
        # grid needs grad: divergence/laplace/gradient differentiate w.r.t. positions
        grid_samples = sample_uniform_2D(self.vis_resolution, device=self.device).requires_grad_(True)
        out_u = self.velocity_field(grid_samples)
        div_u = divergence(out_u, grid_samples).detach()
        out_p = self.pressure_field(grid_samples)
        lap_p = laplace(out_p, grid_samples)
        grad_p = gradient(out_p, grid_samples)
        mse = (div_u - lap_p) ** 2
        # grid_p = self.p_grid

        out_u = out_u.detach().cpu().numpy()
        grid_samples = grid_samples.detach().cpu().numpy()
        x, y = grid_samples[..., 0], grid_samples[..., 1]
        self.tb.add_figure('pre_out_u', draw_vector_field2D(out_u[..., 0], out_u[..., 1], x, y), global_step=self.tb.train_iter)
        self.tb.add_figure('pre_div_u', draw_scalar_field2D(div_u[..., 0].detach().cpu().numpy()), global_step=self.tb.train_iter)
        self.tb.add_figure('pre_p_lap', draw_scalar_field2D(lap_p[..., 0].detach().cpu().numpy()), global_step=self.tb.train_iter)
        # self.tb.add_figure('p_grid', draw_scalar_field2D(grid_p[..., 0].detach().cpu().numpy()), global_step=self.tb.train_iter)
        self.tb.add_figure('pre_p', draw_scalar_field2D(out_p[..., 0].detach().cpu().numpy()), global_step=self.tb.train_iter)
        self.tb.add_figure('pre_p_gradx', draw_scalar_field2D(grad_p[..., 0].detach().cpu().numpy()), global_step=self.tb.train_iter)
        self.tb.add_figure('pre_p_grady', draw_scalar_field2D(grad_p[..., 1].detach().cpu().numpy()), global_step=self.tb.train_iter)
        self.tb.add_figure('pre_mse', draw_scalar_field2D(mse[..., 0].detach().cpu().numpy()), global_step=self.tb.train_iter)

    def _vis_project_velocity(self):
        """Log grad(p), the projection target, the current velocity and the residual to tensorboard."""
        # grid needs grad: gradient(p) differentiates w.r.t. positions
        grid_samples = sample_uniform_2D(self.vis_resolution, device=self.device).requires_grad_(True)

        with torch.no_grad():
            prev_u = self.velocity_field_prev(grid_samples).detach()

        # if self.use_discrete_pressure:
        #     grad_p = F.grid_sample(self.grad_p_grid.permute(2, 0, 1).unsqueeze(0),
        #         grid_samples.flip(-1).unsqueeze(0), align_corners=False) # FIXME: 1) if flip ok? 2) border values might not be exact 0
        #     grad_p = grad_p.squeeze(0).permute(1, 2, 0)
        # else:
        p = self.pressure_field(grid_samples)
        grad_p = gradient(p, grid_samples).detach()
        target_u = prev_u - grad_p
        curr_u = self.velocity_field(grid_samples)
        loss_grid = torch.sum((curr_u - target_u) ** 2, dim=-1).detach().cpu().numpy()
        # grid_u = self.u_grid

        grad_p = grad_p.detach().cpu().numpy()
        target_u = target_u.detach().cpu().numpy()
        curr_u = curr_u.detach().cpu().numpy()
        grid_samples = grid_samples.detach().cpu().numpy()
        x, y = grid_samples[..., 0], grid_samples[..., 1]
        self.tb.add_figure('proj_grad_p', draw_vector_field2D(grad_p[..., 0], grad_p[..., 1], x, y), global_step=self.tb.train_iter)
        self.tb.add_figure('proj_target_u', draw_vector_field2D(target_u[..., 0], target_u[..., 1], x, y), global_step=self.tb.train_iter)
        self.tb.add_figure('proj_out_u', draw_vector_field2D(curr_u[..., 0], curr_u[..., 1], x, y), global_step=self.tb.train_iter)
        self.tb.add_figure('proj_mse', draw_scalar_field2D(loss_grid), global_step=self.tb.train_iter)