import os
import time
import torch
import torch.nn.functional as F
from tqdm import tqdm
import shutil
from tensorboardX import SummaryWriter
from networks import get_network
from utils import *
from diff_ops import gradient


class NeuralAdvection(object):
    """advection equation with constant velocity"""
    def __init__(self, cfg):
        """Build the solver: copy settings from ``cfg``, create networks and optimizer.

        Args:
            cfg: experiment configuration object. The attributes read here are
                ``exp_dir``, ``sdim``, ``vel``, ``length``, ``dt``,
                ``max_n_iters``, ``sample_resolution``, ``time_integrator``
                and ``boundary_cond``; ``cfg`` itself is also forwarded to
                ``get_network`` and kept on the instance for later use.
        """
        # run-level bookkeeping
        self.device = torch.device("cuda:0")
        self.log_file = os.path.join(cfg.exp_dir, "time_log.txt")

        # settings mirrored from the configuration
        self.cfg = cfg
        self.sdim = cfg.sdim
        self.vel = cfg.vel
        self.length = cfg.length
        self.dt = cfg.dt
        self.max_n_iters = cfg.max_n_iters
        self.sample_resolution = cfg.sample_resolution
        self.time_integrator = cfg.time_integrator
        self.boundary_cond = cfg.boundary_cond

        # mutable solver state: current time step index and training stage name
        self.timestep = 0
        self.stage = None

        # two networks with identical construction arguments: `field` is the
        # one being optimized, `field_prev` receives a copy of its weights at
        # the start of every advection step
        self.field = get_network(cfg, cfg.sdim, 1).cuda()
        self.field_prev = get_network(cfg, cfg.sdim, 1).cuda()

        self.create_optimizer()
        # [best 'main' loss so far, iterations since it last improved];
        # consulted by the early stopping condition
        self.loss_record = [10000, 0]

        size = calculate_params_size(self.field)
        print(f"model size: {size / 1024:.3f}KB")

    def create_optimizer(self, use_scheduler=True, gamma=0.1):
        """create optimizer"""
        self.optimizer = torch.optim.Adam(self.field.parameters(), lr=self.cfg.lr)
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, factor=gamma, 
            min_lr=1e-12, patience=500, verbose=True) if use_scheduler else None
        self.loss_record = [10000, 0]
    
    def create_tb(self, name, overwrite=True):
        """Open a tensorboard writer under ``cfg.log_dir/name``.

        Args:
            name: sub-directory name of the log inside ``cfg.log_dir``.
            overwrite: when true, an already existing log directory of that
                name is deleted before the writer is created.

        Returns:
            A ``SummaryWriter`` writing to the log directory.
        """
        log_path = os.path.join(self.cfg.log_dir, name)
        already_there = os.path.exists(log_path)
        if overwrite and already_there:
            # drop the stale log so old and new curves are not mixed
            shutil.rmtree(log_path)
        writer = SummaryWriter(log_path)
        return writer

    def save_ckpt(self, name=None):
        """save checkpoint during training for future restore"""
        if name is None:
            save_path = os.path.join(self.cfg.model_dir, f"ckpt_step_t{self.timestep:03d}.pth")
        else:
            save_path = os.path.join(self.cfg.model_dir, f"ckpt_{name}.pth")

        torch.save({
            'field_state_dict': self.field.cpu().state_dict(),
        }, save_path)

        self.field.cuda()
    
    def load_ckpt(self, name):
        """save saved checkpoint"""
        if type(name) is int:
            load_path = os.path.join(self.cfg.model_dir, f"ckpt_step_t{name:03d}.pth")
        else:
            load_path = os.path.join(self.cfg.model_dir, f"ckpt_{name}.pth")
        checkpoint = torch.load(load_path)

        self.field.load_state_dict(checkpoint['field_state_dict'])

    def update_network(self, loss_dict):
        """update network by back propagation"""
        loss = sum(loss_dict.values())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        if self.scheduler is not None:
            self.scheduler.step(loss_dict['main'])

        self.update_loss_record(loss_dict['main'].item())

    def update_loss_record(self, loss_val):
        """Track the best loss seen and how long ago it was reached.

        ``self.loss_record`` is ``[best loss, iterations since improvement]``.

        Args:
            loss_val: loss value of the latest iteration.
        """
        if not loss_val < self.loss_record[0]:
            # no strict improvement: one more stalled iteration
            self.loss_record[1] += 1
            return
        # new best value: restart the stall counter
        self.loss_record = [loss_val, 0]

    def _training_loop(func):
        """a decorator function that warps a function inside a training loop

        Args:
            func ([type]): a function does forward computation 
                and returns a dict of losses
        """
        def loop(self, *args, **kwargs):
            pbar = tqdm(range(self.max_n_iters))
            self.tb = self.create_tb(self.stage + f"_t{self.timestep}")
            since = time.time()
            for i in pbar:
                loss_dict = func(self, *args, **kwargs)
                self.update_network(loss_dict)

                loss_value = {k: v.item() for k, v in loss_dict.items()}
                self.tb.add_scalars(self.stage, loss_value, global_step=i)
                pbar.set_description(f"{self.stage}[{self.timestep}]")
                pbar.set_postfix(loss_value)

                if i == 0 or (i + 1) % self.cfg.vis_frequency == 0:
                    fig = self.draw_field(self.sample_resolution)
                    self.tb.add_figure(self.stage, fig, global_step=i)
                
                if self.cfg.early_stop and self.optimizer.param_groups[0]['lr'] <= 1.1e-8 and self.loss_record[1] >= 500:
                    pbar.write(f"early stopping at iteration {i}")
                    break
            time_cost = time.time() - since
            time_cost = round(time_cost, 5)
            with open(self.log_file, "a") as fp:
                print(time_cost, file=fp)
                print(f"t{self.timestep:03d}:{time_cost}")
        return loop
    
    @_training_loop
    def _add_source(self, source_func):
        """Forward computation for fitting the field to a source function.

        Args:
            source_func: callable evaluated on the sampled points; its output
                is the regression target for the field network.

        Returns:
            dict with a single entry ``'main'``: the MSE between the field
            output and the target values at the sampled points.
        """
        n_points = self.sample_resolution ** self.sdim
        points = sample_random(n_points, self.sdim, self.length).cuda()

        target = source_func(points)
        prediction = self.field(points)

        return {'main': F.mse_loss(prediction, target)}
    
    def add_source(self, source_func):
        """Train the field network to reproduce ``source_func``.

        A fresh optimizer is created, the fitting loop is run under the stage
        name ``"add_source"`` and the resulting weights are saved as the
        ``add_source`` checkpoint.

        Args:
            source_func: callable giving the target values at sample points.
        """
        stage_name = "add_source"
        self.stage = stage_name
        self.create_optimizer()
        self._add_source(source_func)
        # keep the fitted initial state so it can be restored later
        self.save_ckpt(stage_name)

    @_training_loop
    def _advect(self):
        """Forward computation for one optimization step of the advection update.

        The residual of ``du/dt + vel . grad(u) = 0`` is evaluated at freshly
        drawn random sample points, with ``du/dt`` approximated by
        ``(field - field_prev) / dt``. ``self.time_integrator`` selects where
        the spatial gradient is taken:

        * ``'explicit'``: gradient of the previous field,
        * ``'implicit'``: gradient of the current field,
        * ``'midpoint'``: average of both.

        Returns:
            dict of scalar losses: ``'main'`` is the mean squared residual;
            ``'bc'`` (only when ``self.boundary_cond == 'zero'``) is the mean
            squared field value at sampled boundary points.

        Raises:
            NotImplementedError: for an unknown time integrator or boundary
                condition.
        """
        n_samples = self.sample_resolution ** self.sdim
        samples = sample_random(n_samples, self.sdim, self.length, device=self.device).requires_grad_(True)

        prev_u = self.field_prev(samples)
        curr_u = self.field(samples)
        # Only `field` is optimized, so the previous field enters the time
        # difference as a constant; detaching it keeps backward from running
        # through `field_prev`. `prev_u` itself stays attached because its
        # spatial gradient may still be needed below.
        dudt = (curr_u - prev_u.detach()) / self.dt  # presumably (N, 1): the networks are built with output size 1

        if self.time_integrator == 'explicit':
            grad_term = gradient(prev_u, samples).detach()
        elif self.time_integrator == 'implicit':
            grad_term = gradient(curr_u, samples)  # (N, sdim)
        elif self.time_integrator == 'midpoint':
            grad_u = gradient(curr_u, samples)
            grad_u0 = gradient(prev_u, samples).detach()
            grad_term = (grad_u + grad_u0) / 2.
        else:
            raise NotImplementedError

        # vel . grad(u): reduce over the spatial axis so the residual has the
        # same shape as dudt. Without the sum, dudt would broadcast against
        # every gradient component separately when sdim > 1; for sdim == 1
        # the sum is a no-op.
        advection = (self.vel * grad_term).sum(dim=-1, keepdim=True)
        loss = torch.mean((dudt + advection) ** 2)
        loss_dict = {'main': loss}

        # boundary constraint
        if self.boundary_cond == 'zero':
            # FIXME: hard-coded zero boundary condition to sample 1% points near boundary
            #        and fixed factor 1.0 for boundary loss
            n_boundary = max(n_samples // 100, 10)
            boundary_samples = sample_boundary(n_boundary, self.sdim, self.length, device=self.device)
            bound_u = self.field(boundary_samples)
            loss_dict['bc'] = torch.mean(bound_u ** 2) * 1.
        elif self.boundary_cond != 'none':
            raise NotImplementedError

        return loss_dict

    def advect(self):
        """advection: dudt = -(vel \cdot grad)u"""
        self.stage = "advect"
        self.field_prev.load_state_dict(self.field.state_dict())
        self.create_optimizer()
        self._advect()

    def sample_field(self, resolution, return_grad=False, to_numpy=True):
        """Evaluate the current field on uniform grid points.

        Args:
            resolution: grid resolution forwarded to ``sample_uniform``.
            return_grad: when true, also return component 0 (last axis) of the
                gradient of the field values with respect to the grid points.
            to_numpy: when true, every returned tensor is detached, moved to
                the CPU and converted to a numpy array.

        Returns:
            ``(values, grid_points)`` or, with ``return_grad``,
            ``(values, grid_points, grad_x)``.
        """
        points = sample_uniform(resolution, self.sdim, self.length).cuda().requires_grad_(True)
        values = self.field(points).squeeze(-1)

        results = [values, points]
        if return_grad:
            # differentiate before anything is detached from the graph
            results.append(gradient(values, points)[..., 0])

        if to_numpy:
            results = [item.detach().cpu().numpy() for item in results]

        return tuple(results)

    def draw_field(self, resolution, y_max=None):
        """Plot the current field sampled on a uniform grid.

        Args:
            resolution: grid resolution forwarded to ``sample_field``.
            y_max: forwarded to the 1D plotting helper; not used for 2D.

        Returns:
            The figure produced by the 1D or 2D plotting helper.

        Raises:
            NotImplementedError: when ``self.sdim`` is neither 1 nor 2.
        """
        values, points = self.sample_field(resolution, to_numpy=True)

        if self.sdim == 1:
            return draw_scalar_field1D(points[..., 0], values, y_max=y_max)
        if self.sdim == 2:
            return draw_scalar_field2D(values)
        raise NotImplementedError
