import os
from abc import ABC, abstractmethod
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
import shutil
from tensorboardX import SummaryWriter
from models.networks import get_network
from utils import sample_random, sample_uniform

class NeuralPDEABC(ABC):
    """Abstract base for neural PDE solvers trained one time step at a time.

    Provides shared training machinery: optimizer/scheduler creation,
    TensorBoard log creation, checkpoint save/load, an early-stopping counter
    and a decorator (``_training_loop``) that turns a per-iteration loss
    function into a full training loop.

    Subclasses must implement ``step`` and override the ``_trainable_networks``
    property to return a ``{name: torch.nn.Module}`` dict.
    ``sample_in_training`` additionally reads ``self.dim``, which is not set
    here; subclasses are presumably expected to define it.
    """

    # Number of consecutive non-improving iterations (of the 'main' loss)
    # required, together with a fully decayed learning rate, to early-stop.
    EARLY_STOP_PATIENCE = 500

    def __init__(self, cfg):
        # cfg is an attribute-style config object. Besides the fields read
        # here, other methods read cfg.lr, cfg.log_dir, cfg.model_dir and
        # cfg.early_stop.
        self.cfg = cfg
        self.dt = cfg.dt
        self.max_n_iters = cfg.max_n_iters
        self.sample_resolution = cfg.sample_resolution
        self.vis_resolution = cfg.vis_resolution
        self.timestep = 0
        self.tb = None
        self.sample_pattern = cfg.sample

        # [best 'main' loss seen so far, iterations since it last improved];
        # drives the early-stopping condition.
        self.loss_record = [10000, 0]

        self.device = torch.device("cuda:0")

    @property
    def _trainable_networks(self):
        """Return a ``{name: torch.nn.Module}`` dict; subclasses must override."""
        raise NotImplementedError

    def create_optimizer(self, use_scheduler=True, gamma=0.1, patience=500, min_lr=1e-8):
        """Create one Adam optimizer over all trainable networks.

        Every network gets its own parameter group with learning rate
        ``cfg.lr``. Also resets the early-stopping record.

        Args:
            use_scheduler: if True, attach a ``ReduceLROnPlateau`` scheduler
                driven by the 'main' loss; otherwise ``self.scheduler`` is None.
            gamma: factor by which the scheduler multiplies the learning rate.
            patience: scheduler patience, in iterations.
            min_lr: lower bound for the scheduled learning rate. Also used by
                ``_training_loop`` as the "fully decayed" threshold.
        """
        self.loss_record = [10000, 0]
        # Remembered so the early-stopping check matches the actual floor.
        self._min_lr = min_lr
        param_list = [
            {"params": net.parameters(), "lr": self.cfg.lr}
            for net in self._trainable_networks.values()
        ]
        self.optimizer = torch.optim.Adam(param_list)
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, factor=gamma, min_lr=min_lr, patience=patience,
            verbose=True) if use_scheduler else None

    def create_tb(self, name, overwrite=True):
        """Create and return a TensorBoard writer logging to ``cfg.log_dir/name``.

        Sets ``self.log_path``. If ``overwrite`` is True, any existing
        directory at that path is deleted first. Note: the writer is returned,
        not stored on ``self.tb``.
        """
        self.log_path = os.path.join(self.cfg.log_dir, name)
        if os.path.exists(self.log_path) and overwrite:
            shutil.rmtree(self.log_path, ignore_errors=True)
        return SummaryWriter(self.log_path)

    @abstractmethod
    def step(self):
        """Advance the solver by one time step; implemented by subclasses."""
        raise NotImplementedError

    def update_network(self, loss_dict):
        """Backpropagate the sum of all losses and take one optimizer step.

        Args:
            loss_dict: mapping of loss name -> scalar loss tensor. It must
                contain a 'main' entry, which alone drives the LR scheduler
                and the early-stopping record.
        """
        loss = sum(loss_dict.values())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        main_loss = loss_dict['main'].item()
        if self.scheduler is not None:
            self.scheduler.step(main_loss)
        self.update_loss_record(main_loss)

    def update_loss_record(self, loss_val):
        """Track the best loss so far and count iterations without improvement."""
        if loss_val < self.loss_record[0]:
            self.loss_record = [loss_val, 0]
        else:
            self.loss_record[1] += 1

    def _set_require_grads(self, model, require_grad):
        """Set ``requires_grad`` on every parameter of ``model``."""
        for p in model.parameters():
            p.requires_grad_(require_grad)

    def save_ckpt(self, name=None):
        """Save every trainable network's weights plus the current timestep.

        Args:
            name: checkpoint tag. ``None`` writes
                ``ckpt_step_t{timestep:03d}.pth``; otherwise ``ckpt_{name}.pth``.
                Files are written into ``cfg.model_dir``.
        """
        if name is None:
            filename = f"ckpt_step_t{self.timestep:03d}.pth"
        else:
            filename = f"ckpt_{name}.pth"
        save_path = os.path.join(self.cfg.model_dir, filename)

        save_dict = {}
        for net_name, net in self._trainable_networks.items():
            # Copy the tensors to CPU rather than moving the module itself:
            # a net.cpu()/net.cuda() round trip ignores self.device and
            # fails outright when CUDA is unavailable.
            state = net.state_dict()
            for key in list(state):
                value = state[key]
                if torch.is_tensor(value):
                    state[key] = value.detach().cpu()
            save_dict[f'net_{net_name}'] = state
        save_dict['timestep'] = self.timestep

        torch.save(save_dict, save_path)

    def load_ckpt(self, name):
        """Load a checkpoint written by ``save_ckpt`` into the trainable networks.

        Args:
            name: an integer timestep (loads ``ckpt_step_t{name:03d}.pth``) or
                any other tag (loads ``ckpt_{name}.pth``) from ``cfg.model_dir``.
                Also restores ``self.timestep``.
        """
        # bool is an int subclass; keep treating it as a plain tag.
        if isinstance(name, (int, np.integer)) and not isinstance(name, bool):
            filename = f"ckpt_step_t{name:03d}.pth"
        else:
            filename = f"ckpt_{name}.pth"
        load_path = os.path.join(self.cfg.model_dir, filename)
        # Map to CPU so checkpoints load regardless of the device they were
        # written from; load_state_dict copies into each net's own device.
        checkpoint = torch.load(load_path, map_location="cpu")

        for net_name, net in self._trainable_networks.items():
            net.load_state_dict(checkpoint[f'net_{net_name}'])
        self.timestep = checkpoint['timestep']

    @classmethod
    def _training_loop(cls, func):
        """Decorator that runs ``func`` as the body of a training loop.

        ``func(self, *args, **kwargs)`` must return a dict of scalar loss
        tensors with a 'main' entry. Each iteration calls ``update_network``,
        logs all losses to ``self.tb`` under ``func``'s name, and updates a
        progress bar.

        The loop runs for at most ``self.max_n_iters`` iterations. When
        ``cfg.early_stop`` is set, it stops early once both hold:
        - the learning rate of the first parameter group is within a small
          tolerance of the ``min_lr`` passed to ``create_optimizer``;
        - the 'main' loss has not improved for ``EARLY_STOP_PATIENCE``
          iterations.

        Requires ``self.tb`` (a SummaryWriter) and ``self.optimizer`` to exist.
        """
        from functools import wraps

        tag = func.__name__

        @wraps(func)
        def loop(self, *args, **kwargs):
            pbar = tqdm(range(self.max_n_iters))
            self.tb.train_iter = 0
            # Compare against the configured min_lr instead of a hard-coded
            # 1e-8 so a custom floor can still trigger early stopping; the
            # 1.1 factor leaves a small margin above the floor.
            lr_floor = getattr(self, "_min_lr", 1e-8) * 1.1
            for i in pbar:
                loss_dict = func(self, *args, **kwargs)
                self.update_network(loss_dict)

                loss_value = {k: v.item() for k, v in loss_dict.items()}

                self.tb.add_scalars(tag, loss_value, global_step=i)
                self.tb.train_iter += 1
                pbar.set_description(f"{tag}[{self.timestep}]")
                pbar.set_postfix(loss_value)

                if (self.cfg.early_stop
                        and self.optimizer.param_groups[0]['lr'] <= lr_floor
                        and self.loss_record[1] >= self.EARLY_STOP_PATIENCE):
                    tqdm.write(f"early stopping at iteration {i}")
                    break
        return loop

    def sample_in_training(self, resolution):
        """Draw one batch of training points per entry of ``self.sample_pattern``.

        'random' yields ``resolution ** self.dim`` points from
        ``sample_random``; 'uniform' yields points from ``sample_uniform``
        at ``resolution``. All samples live on ``self.device`` and require
        gradients (presumably so PDE residuals can be differentiated w.r.t.
        the coordinates — confirm with callers).

        Raises:
            NotImplementedError: for an unknown sampling pattern.
        """
        samples = []
        for s in self.sample_pattern:
            if s == 'random':
                random_samples = sample_random(resolution ** self.dim, self.dim, device=self.device).requires_grad_(True)
                samples.append(random_samples)
            elif s == 'uniform':
                uniform_samples = sample_uniform(resolution, self.dim, device=self.device).requires_grad_(True)
                samples.append(uniform_samples)
            else:
                raise NotImplementedError
        return samples