import os
from abc import ABC, abstractmethod
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
import shutil
from tensorboardX import SummaryWriter
from .networks import get_network
from utils.diff_ops import jacobian, divergence, curl2d_fdiff
from utils.model_utils import sample_uniform_2D, sample_random_2D
from utils.vis_utils import draw_scalar_field2D, draw_vector_field2D


class NeuralFluidABC(ABC):
    """Abstract base for fluid solvers whose fields are neural implicit networks.

    Builds a velocity network (and, when ``cfg.use_density`` is set, a density
    network) with ``get_network`` on the GPU. Each also gets a frozen ``*_prev``
    copy that holds the previous state. The class provides optimizer/scheduler
    setup, a training-loop decorator, checkpointing, grid sampling of the fields
    and drawing helpers. Subclasses must implement ``step``.
    """

    def __init__(self, cfg):
        self.cfg = cfg
        self.dt = cfg.dt
        self.visc = cfg.visc
        self.diff = cfg.diff
        self.max_n_iters = cfg.max_n_iters
        self.sample_resolution = cfg.sample_resolution
        self.vis_resolution = cfg.vis_resolution
        self.timestep = 0
        self.boundary_cond = cfg.boundary_cond
        self.mode = cfg.mode
        self.tb = None
        self.sample_pattern = cfg.sample
        self.use_density = cfg.use_density

        # [best 'main' loss so far, iterations since it last improved]; used for early stopping
        self.loss_record = [10000, 0]

        # neural implicit networks for density (optional) and velocity.
        # The *_prev copies hold the previous state and are evaluated only, never optimized.
        if self.use_density:
            self.density_field = get_network(cfg, 2, 1).cuda()
            self.density_field_prev = get_network(cfg, 2, 1).cuda()
            self._set_require_grads(self.density_field_prev, False)
        self.velocity_field = get_network(cfg, 2, 2).cuda()
        self.velocity_field_prev = get_network(cfg, 2, 2).cuda()
        self._set_require_grads(self.velocity_field_prev, False)
        self.device = torch.device("cuda:0")

    @property
    def _trainable_networks(self):
        """Networks optimized by ``create_optimizer`` and saved/loaded by the checkpoint helpers.

        The density network is included when ``use_density`` is set. Otherwise no
        optimizer would ever update it, and density-fitting losses would have no effect.
        """
        nets = {'velocity': self.velocity_field}
        if self.use_density:
            nets['density'] = self.density_field
        return nets

    def create_optimizer(self, use_scheduler=True, gamma=0.1, patience=500, min_lr=1e-8):
        """Create one Adam optimizer over all trainable networks (lr = ``cfg.lr``).

        Also resets the early-stopping record. When ``use_scheduler`` is True, a
        ReduceLROnPlateau scheduler (factor ``gamma``, given ``patience`` and
        ``min_lr``) is attached. Otherwise ``self.scheduler`` is None.
        """
        self.loss_record = [10000, 0]
        # one parameter group per network, all driven by a single optimizer
        param_list = [{"params": net.parameters(), "lr": self.cfg.lr}
                      for net in self._trainable_networks.values()]
        self.optimizer = torch.optim.Adam(param_list)
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, factor=gamma,
            min_lr=min_lr, patience=patience, verbose=True) if use_scheduler else None
        # self.scheduler = torch.optim.lr_scheduler.CyclicLR(self.optimizer, base_lr=1e-7, max_lr=self.cfg.lr, 
        #     step_size_up=1000, step_size_down=1000, mode='triangular2', cycle_momentum=False) if use_scheduler else None

    def create_tb(self, name, overwrite=True):
        """Create a tensorboard SummaryWriter logging to ``cfg.log_dir/name``.

        When ``overwrite`` is True, an existing directory at that path is deleted
        first. The path is stored in ``self.log_path``.
        """
        self.log_path = os.path.join(self.cfg.log_dir, name)
        if os.path.exists(self.log_path) and overwrite:
            shutil.rmtree(self.log_path, ignore_errors=True)
        return SummaryWriter(self.log_path)

    @abstractmethod
    def step(self):
        raise NotImplementedError

    def update_network(self, loss_dict):
        """Back-propagate the summed losses and take one optimizer (and scheduler) step.

        Args:
            loss_dict: mapping of loss name -> scalar tensor. It must contain 'main',
                which drives the plateau scheduler and the early-stopping record.
        """
        loss = sum(loss_dict.values())
        self.optimizer.zero_grad()
        loss.backward()
        if self.cfg.grad_clip > 0:
            param_list = [p for net in self._trainable_networks.values() for p in net.parameters()]
            # clip the total gradient norm to the configured value
            torch.nn.utils.clip_grad_norm_(param_list, self.cfg.grad_clip)
        self.optimizer.step()
        if self.scheduler is not None:
            self.scheduler.step(loss_dict['main'])
        self.update_loss_record(loss_dict['main'].item())

    def update_loss_record(self, loss_val):
        """Record a new best loss, or count one more iteration without improvement."""
        if loss_val < self.loss_record[0]:
            self.loss_record = [loss_val, 0]
        else:
            self.loss_record[1] += 1

    def _set_require_grads(self, model, require_grad):
        """Set ``requires_grad`` on every parameter of ``model``."""
        for p in model.parameters():
            p.requires_grad_(require_grad)

    def save_ckpt(self, name=None):
        """Save the trainable networks' weights and the current timestep.

        Writes ``cfg.model_dir/ckpt_step_t{timestep:03d}.pth`` when ``name`` is None,
        otherwise ``cfg.model_dir/ckpt_{name}.pth``. Tensors are copied to CPU; the
        networks themselves are left on their current device.
        """
        if name is None:
            save_path = os.path.join(self.cfg.model_dir, f"ckpt_step_t{self.timestep:03d}.pth")
        else:
            save_path = os.path.join(self.cfg.model_dir, f"ckpt_{name}.pth")

        save_dict = {}
        for net_name, net in self._trainable_networks.items():
            state = net.state_dict()
            # replace values in place so the state dict keeps its `_metadata` attribute
            for key, tensor in list(state.items()):
                state[key] = tensor.cpu()
            save_dict[f'net_{net_name}'] = state
        save_dict['timestep'] = self.timestep

        torch.save(save_dict, save_path)

    def load_ckpt(self, name):
        """Load a checkpoint written by ``save_ckpt``.

        Args:
            name: an int selects ``ckpt_step_t{name:03d}.pth``; anything else selects
                ``ckpt_{name}.pth`` in ``cfg.model_dir``.
        """
        if isinstance(name, int):
            load_path = os.path.join(self.cfg.model_dir, f"ckpt_step_t{name:03d}.pth")
        else:
            load_path = os.path.join(self.cfg.model_dir, f"ckpt_{name}.pth")
        checkpoint = torch.load(load_path)

        for net_name, net in self._trainable_networks.items():
            key = f'net_{net_name}'
            if key not in checkpoint and net_name == 'density':
                # checkpoints saved before the density network was tracked have no entry for it
                print(f"'{key}' not found in {load_path}; keeping current density weights.")
                continue
            net.load_state_dict(checkpoint[key])
        self.timestep = checkpoint['timestep']

    @classmethod
    def _training_loop(cls, func):
        """Decorator that runs a loss function inside an optimization loop.

        The wrapped method is called up to ``self.max_n_iters`` times, and each call
        must return a loss dict (see ``update_network``). Losses are logged to
        ``self.tb`` under the function's name and shown on a tqdm bar.
        ``self.tb`` must already exist.
        """
        from functools import wraps

        tag = func.__name__

        @wraps(func)
        def loop(self, *args, **kwargs):
            pbar = tqdm(range(self.max_n_iters))
            self.tb.train_iter = 0
            for i in pbar:
                loss_dict = func(self, *args, **kwargs)
                self.update_network(loss_dict)

                loss_value = {k: v.item() for k, v in loss_dict.items()}

                self.tb.add_scalars(tag, loss_value, global_step=i)
                self.tb.train_iter += 1
                pbar.set_description(f"{tag}[{self.timestep}]")
                pbar.set_postfix(loss_value)

                # Stop once the first group's lr is at ~1e-8 (the default min_lr)
                # and 'main' has not improved for 500 iterations.
                if self.cfg.early_stop and self.optimizer.param_groups[0]['lr'] <= 1.1e-8 and self.loss_record[1] >= 500:
                    tqdm.write(f"early stopping at iteration {i}")
                    break
        return loop

    def sample_in_training(self):
        """Return training sample points (with requires_grad=True) according to ``cfg.sample``.

        The patterns are:
          - 'random': ``sample_resolution**2`` random points;
          - 'uniform': a ``sample_resolution`` uniform grid;
          - 'random+uniform': both, with the grid flattened to (-1, 2).

        Raises:
            NotImplementedError: for any other pattern.
        """
        if self.sample_pattern == 'random':
            samples = sample_random_2D(self.sample_resolution ** 2, device=self.device).requires_grad_(True)
        elif self.sample_pattern == 'uniform':
            samples = sample_uniform_2D(self.sample_resolution, device=self.device).requires_grad_(True)
        elif self.sample_pattern == 'random+uniform':
            samples = torch.cat([sample_random_2D(self.sample_resolution ** 2, device=self.device),
                        sample_uniform_2D(self.sample_resolution, device=self.device).view(-1, 2)], dim=0).requires_grad_(True)
        else:
            raise NotImplementedError
        return samples

    def _sample_field(self, field, resolution, to_numpy, with_boundary, return_samples, require_grad):
        """Evaluate ``field`` on a uniform 2D grid. Shared by the ``sample_*_field`` methods."""
        grid_samples = sample_uniform_2D(resolution, with_boundary=with_boundary, device=self.device)
        if require_grad:
            grid_samples = grid_samples.requires_grad_(True)

        out = field(grid_samples)
        if to_numpy:
            out = out.detach().cpu().numpy()
            grid_samples = grid_samples.detach().cpu().numpy()
        if return_samples:
            return out, grid_samples
        return out

    def sample_velocity_field(self, resolution, to_numpy=True, with_boundary=False, return_samples=False, require_grad=False):
        """Evaluate the velocity network on a uniform grid.

        Returns the values (as numpy when ``to_numpy``), plus the grid points when
        ``return_samples`` is set.
        """
        return self._sample_field(self.velocity_field, resolution, to_numpy, with_boundary,
                                  return_samples, require_grad)

    def sample_density_field(self, resolution, to_numpy=True, with_boundary=False, return_samples=False, require_grad=False):
        """Evaluate the density network on a uniform grid.

        Behaves like ``sample_velocity_field``; requires ``use_density``.
        """
        return self._sample_field(self.density_field, resolution, to_numpy, with_boundary,
                                  return_samples, require_grad)

    def draw(self, tag, resolution, **kwargs):
        """Call ``draw_<tag>(resolution, **kwargs)`` and return its figure.

        This is best-effort. If no such method exists, or if it raises, a message is
        printed and None is returned so that a visualization problem does not abort a run.
        """
        func_str = f'draw_{tag}'
        draw_func = getattr(self, func_str, None)
        if draw_func is None:
            print(f"no method named '{func_str}'.")
            return None
        try:
            return draw_func(resolution, **kwargs)
        except Exception as e:
            # report the real error instead of misreporting it as a missing method
            print(f"failed to run '{func_str}': {e!r}")
            return None

    def draw_velocity(self, resolution):
        """Draw the velocity field as a 2D vector field on a grid that includes the boundary."""
        grid_values, grid_samples = self.sample_velocity_field(resolution, to_numpy=True, with_boundary=True, return_samples=True)
        x, y = grid_samples[..., 0], grid_samples[..., 1]
        fig = draw_vector_field2D(grid_values[..., 0], grid_values[..., 1], x, y)
        return fig

    def draw_vorticity(self, resolution, vmin=0, vmax=10):
        """Draw |curl u|, computed with finite differences using spacing ``2.0 / resolution``.

        NOTE(review): the 2.0 presumably reflects a domain of width 2; confirm
        against sample_uniform_2D.
        """
        grid_values = self.sample_velocity_field(resolution, to_numpy=True)
        grid_values = np.transpose(grid_values, (1, 0, 2))
        curl = np.abs(curl2d_fdiff(grid_values, 2.0 / resolution)) # NOTE: use finite difference
        fig = draw_scalar_field2D(curl, vmin=vmin, vmax=vmax)
        # x, y = grid_samples[..., 0].detach().cpu().numpy(), grid_samples[..., 1].detach().cpu().numpy()
        # fig = draw_vorticity_field2D(curl, x, y)
        return fig

    def draw_density(self, resolution):
        """Draw channel 0 of the density field as a scalar field."""
        p = self.sample_density_field(resolution, to_numpy=True)[..., 0]
        fig = draw_scalar_field2D(p)
        return fig

    # def compute_kinetic_energy(self, resolution):
    #     grid_values = self.sample_velocity_field(resolution, to_numpy=True, with_boundary=True)
    #     Ek = 0.5 * np.sum(grid_values ** 2)
    #     return Ek


class NeuralFluidBase(NeuralFluidABC):
    """Neural fluid solver that can fit additive source terms into its fields.

    ``add_source`` and ``add_source_density`` fit the velocity and density networks,
    respectively, to ``source_func``. When ``is_init`` is False the source is added
    on top of the previous field. Progress is logged to tensorboard, and a checkpoint
    is saved afterwards. ``step`` is a no-op here.
    """

    def __init__(self, cfg):
        super(NeuralFluidBase, self).__init__(cfg)

    def _should_visualize(self):
        """Return True on the first training iteration and whenever ``train_iter + 1`` is a multiple of ``cfg.vis_frequency``."""
        it = self.tb.train_iter
        return it == 0 or (it + 1) % self.cfg.vis_frequency == 0

    @NeuralFluidABC._training_loop
    def _add_source(self, source_func, is_init=True):
        """One training iteration's loss for fitting the velocity field to a source.

        The target is ``source_func(samples)``, plus the frozen previous velocity when
        ``is_init`` is False. With ``cfg.grad_sup`` an extra 'grad_mse' term compares
        entries [0, 0] and [1, 1] of the last two dims of the output and target
        Jacobians. These are presumably du_x/dx and du_y/dy, matching the
        visualization labels; confirm against ``jacobian``.

        Returns:
            dict with 'main' (MSE on values) and, optionally, 'grad_mse'.
        """
        if self._should_visualize():
            self._vis_add_source_v(source_func, is_init)

        samples = self.sample_in_training()

        out_rand = self.velocity_field(samples)
        if is_init:
            target_rand_val = source_func(samples)
        else:
            target_rand_val = source_func(samples) + self.velocity_field_prev(samples).detach()

        loss_dict = {'main': F.mse_loss(out_rand, target_rand_val)}
        if self.cfg.grad_sup:
            target_rand_grad = jacobian(target_rand_val, samples)[0][..., [0, 1], [0, 1]]
            out_grad = jacobian(out_rand, samples)[0][..., [0, 1], [0, 1]]
            loss_dict.update({"grad_mse": F.mse_loss(out_grad, target_rand_grad)})

        return loss_dict

    def add_source(self, attr: str, source_func, is_init=True):
        """Fit the velocity field to ``source_func`` (plus the previous velocity unless ``is_init``).

        ``attr`` is not used by this implementation. The current velocity network is
        copied into ``velocity_field_prev`` before training, and a checkpoint named
        'add_source' is saved afterwards.
        """
        self.tb = self.create_tb("add_source")
        self.create_optimizer()
        self.source_func = source_func
        self.velocity_field_prev.load_state_dict(self.velocity_field.state_dict())
        self._add_source(source_func, is_init)
        self.save_ckpt('add_source')

    @NeuralFluidABC._training_loop
    def _add_source_density(self, source_func, is_init=True):
        """One training iteration's loss for fitting the density field to a source.

        The target is ``source_func(samples)``, plus the frozen previous density when
        ``is_init`` is False. The last (channel) dimension of the network output is
        squeezed before the MSE.

        Returns:
            dict with 'main'.
        """
        if self._should_visualize():
            self._vis_add_source_d(source_func)

        samples = self.sample_in_training()

        out = self.density_field(samples).squeeze(-1)
        if is_init:
            target_val = source_func(samples)
        else:
            target_val = source_func(samples) + self.density_field_prev(samples).detach()

        return {'main': F.mse_loss(out, target_val)}

    def add_source_density(self, attr: str, source_func, is_init=True):
        """Fit the density field to ``source_func`` (plus the previous density unless ``is_init``).

        ``attr`` is not used by this implementation. The current density network is
        copied into ``density_field_prev`` before training, and a checkpoint named
        'add_source_density' is saved afterwards. Requires ``use_density``.
        """
        self.tb = self.create_tb("add_source_density")
        self.create_optimizer()
        self.density_field_prev.load_state_dict(self.density_field.state_dict())
        self._add_source_density(source_func, is_init)
        self.save_ckpt('add_source_density')

    def step(self):
        pass

    ################# visualization during training #####################
    def _vis_add_source_v(self, source_func, is_init):
        """Log ground-truth vs. predicted velocity, its per-point MSE, and diagonal Jacobian entries.

        Everything is evaluated on a ``vis_resolution`` grid.
        """
        # sample_uniform_2D already places the grid on self.device
        grid_samples = sample_uniform_2D(self.vis_resolution, device=self.device).requires_grad_(True)
        if is_init:
            gt_u = source_func(grid_samples)
        else:
            gt_u = source_func(grid_samples) + self.velocity_field_prev(grid_samples)
        out_u = self.velocity_field(grid_samples)
        # gt_div_u = divergence(gt_u, grid_samples).detach()
        gt_grad = jacobian(gt_u, grid_samples)[0][..., [0, 1], [0, 1]]
        out_grad = jacobian(out_u, grid_samples)[0][..., [0, 1], [0, 1]]
        loss_grid = torch.mean((out_u - gt_u) ** 2, dim=-1).detach().cpu().numpy()
        self.tb.add_figure('u_mse', draw_scalar_field2D(loss_grid), global_step=self.tb.train_iter)

        gt_u = gt_u.detach().cpu().numpy()
        out_u = out_u.detach().cpu().numpy()
        gt_grad = gt_grad.detach().cpu().numpy()
        out_grad = out_grad.detach().cpu().numpy()
        grid_samples = grid_samples.detach().cpu().numpy()
        x, y = grid_samples[..., 0], grid_samples[..., 1]
        self.tb.add_figure('u_gt', draw_vector_field2D(gt_u[..., 0], gt_u[..., 1], x, y), global_step=self.tb.train_iter)
        self.tb.add_figure('u_pred', draw_vector_field2D(out_u[..., 0], out_u[..., 1], x, y), global_step=self.tb.train_iter)
        self.tb.add_figure('duxdx_gt', draw_scalar_field2D(gt_grad[..., 0]), global_step=self.tb.train_iter)
        self.tb.add_figure('duydy_gt', draw_scalar_field2D(gt_grad[..., 1]), global_step=self.tb.train_iter)
        # self.tb.add_figure('div_u_gt', draw_scalar_field2D(gt_div_u[..., 0].detach().cpu().numpy()), global_step=self.tb.train_iter)
        # self.tb.add_figure('div_u_gt', draw_scalar_field2D(np.sum(gt_grad, axis=-1)), global_step=self.tb.train_iter)
        self.tb.add_figure('duxdx_pred', draw_scalar_field2D(out_grad[..., 0]), global_step=self.tb.train_iter)
        self.tb.add_figure('duydy_pred', draw_scalar_field2D(out_grad[..., 1]), global_step=self.tb.train_iter)
        # self.tb.add_figure('div_u_pred', draw_scalar_field2D(np.sum(out_grad, axis=-1)), global_step=self.tb.train_iter)

    def _vis_add_source_d(self, source_func):
        """Log the current density field, drawn at ``vis_resolution`` like the velocity visualization."""
        fig = self.draw_density(self.vis_resolution)
        self.tb.add_figure('density', fig, global_step=self.tb.train_iter)
