from src.models.utils              import density_map_generator, base_discriminator, build_rank_head, SumLayer
from src.models.augment            import make_augmentor
from src.losses                    import gen_adversarial_loss, disc_adversarial_loss
from src.experiment.utils          import CheckPointTracker
from tensorflow.keras              import Model, Sequential
from tensorflow.keras.layers       import *
from datetime                      import datetime as dt
import tensorflow                  as tf
import numpy                       as np

class GAN(object):
    """Base training harness shared by the GAN variants in this file.

    The base class owns the bookkeeping (compile state, checkpoint tracker,
    per-epoch loss history) and the epoch loop in ``fit``. Subclasses must
    override ``train_step`` and ``get_sparsity`` and are expected to replace
    ``self.training_loss`` with one list per value returned by their
    ``train_step``, in the same order.
    """

    def __init__(self, weights, experiment_manager, **kwargs):
        """Initialise bookkeeping state.

        Args:
            weights: Not used by the base class itself; accepted so that all
                variants share one constructor signature.
            experiment_manager: Object providing the ``on_experiment_start``,
                ``on_epoch_start``, ``on_epoch_end`` and ``on_experiment_end``
                hooks that ``fit`` calls.
            **kwargs: Stored verbatim on ``self.kwargs`` for subclasses.
        """
        self.compiled = False
        # Placeholder history; subclasses overwrite it with their own keys.
        self.training_loss = {"dummy": []}
        self.experiment_manager = experiment_manager
        self.checkpoint = None
        self.epochs_completed = 0
        self.kwargs = kwargs

    def compile(self, opt_dict, loss_dict, loss_w_dict):
        """Store optimizers, loss functions and loss weights; mark as compiled.

        Args:
            opt_dict: Optimizers, looked up by name by the subclasses.
            loss_dict: Loss functions, looked up by name by the subclasses.
            loss_w_dict: Loss weights, looked up by name by the subclasses.
        """
        self.opt_dict = opt_dict
        self.loss_dict = loss_dict
        self.loss_w_dict = loss_w_dict
        self.compiled = True

    def get_sparsity(self, batch):
        """Return the sparsity figure reported at the end of each epoch.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError("Must implement get_sparsity method")

    def register_checkpoint(self, path, save_freq):
        """Create the ``CheckPointTracker`` used by ``modify_checkpoint``.

        Args:
            path: Forwarded to ``CheckPointTracker``.
            save_freq: Forwarded to ``CheckPointTracker``.

        Raises:
            NotImplementedError: If ``compile`` has not been called yet.
        """
        if not self.compiled:
            raise NotImplementedError("Model needs to be compiled before a checkpoint can be restored.")
        self.checkpoint = CheckPointTracker(path, save_freq)

    def train_step(self, batch, extra_batch=None):
        """Run one optimisation step and return the per-batch losses.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError("Must implement train_step method")

    def modify_checkpoint(self, mod_str):
        """Restore from, or update, the registered checkpoint.

        Args:
            mod_str: ``"restore"`` sets ``self.epochs_completed`` from
                ``checkpoint.restore()``; ``"update"`` calls
                ``checkpoint.update()``.

        Raises:
            NotImplementedError: If no checkpoint is registered or the model
                is not compiled.
            ValueError: If ``mod_str`` is neither ``"restore"`` nor ``"update"``.
        """
        if self.checkpoint is None:
            raise NotImplementedError("No registered checkpoint. Set checkpoint using the method register_checkpoint(path)")
        if not self.compiled:
            raise NotImplementedError("Model needs to be compiled before a checkpoint can be restored.")

        if mod_str == "restore":
            self.epochs_completed = self.checkpoint.restore()
        elif mod_str == "update":
            self.checkpoint.update()
        else:
            raise ValueError("Unknown modification {}".format(mod_str))

    def print_epoch(self, epoch, epoch_loss, sparsity, diff_t):
        """Print an epoch summary and append the mean losses to the history.

        Args:
            epoch: Epoch index to display.
            epoch_loss: Mapping from loss name to the list of per-batch values.
            sparsity: Value returned by ``get_sparsity`` for this epoch.
            diff_t: ``timedelta`` covering the epoch's batch loop.
        """
        print("Elapsed: {}, Epoch: {}, ".format(diff_t.total_seconds(), epoch), end="")
        for key, val in epoch_loss.items():
            # Compute the mean once; the history stores a plain float while
            # the printed value keeps numpy's own formatting.
            mean_val = np.mean(val)
            self.training_loss[key].append(float(mean_val))
            print("{}: {}, ".format(key, mean_val), end="", flush=True)
        print("{}: {}, ".format("sparsity", sparsity), flush=True)

    def fit(self, train_ds, epochs, extra_ds=None):
        """Train from ``self.epochs_completed`` up to ``epochs``.

        Args:
            train_ds: Iterable of batches; iterated once per epoch.
            epochs: Exclusive upper bound on the epoch index.
            extra_ds: Optional second iterable. One element is drawn per
                training batch and passed to ``train_step`` as ``extra_batch``.
                When it runs out it is restarted with ``iter(extra_ds)``, so
                it does not need to be as long as the whole training run.

        Returns:
            ``self.training_loss``: per-key lists of per-epoch mean losses.

        Raises:
            ValueError: If the model is not compiled, or if ``train_ds`` never
                yields a batch.
            NotImplementedError: Via ``modify_checkpoint`` when no checkpoint
                has been registered.
        """
        print("Fitting!")
        if not self.compiled:
            raise ValueError('You must first compile your model before calling fit!')

        self.experiment_manager.on_experiment_start()
        extra_ds_iter = iter(extra_ds) if extra_ds is not None else None
        seen_any_batch = False
        for epoch in range(self.epochs_completed, epochs):
            self.experiment_manager.on_epoch_start()
            epoch_loss = {key: [] for key in self.training_loss}
            epoch_start_t = dt.now()
            for batch in train_ds:
                seen_any_batch = True
                if extra_ds_iter is None:
                    batch_loss = self.train_step(batch)
                else:
                    try:
                        extra_batch = next(extra_ds_iter)
                    except StopIteration:
                        # The auxiliary stream ran dry before training ended:
                        # start it over instead of aborting the run.
                        extra_ds_iter = iter(extra_ds)
                        extra_batch = next(extra_ds_iter)
                    batch_loss = self.train_step(batch, extra_batch)
                # Losses are matched to history keys purely by position.
                for key, loss in zip(epoch_loss, batch_loss):
                    epoch_loss[key].append(loss)
            epoch_end_t = dt.now()
            if not seen_any_batch:
                # Without this guard the line below fails on an unbound name.
                raise ValueError("train_ds did not yield any batches")
            # Sparsity is measured on the last batch that was iterated.
            sparsity = self.get_sparsity(batch)

            self.modify_checkpoint("update")
            self.print_epoch(epoch, epoch_loss, sparsity, epoch_end_t - epoch_start_t)
            # A truthy return from the manager requests early stopping.
            if self.experiment_manager.on_epoch_end():
                break
        self.experiment_manager.on_experiment_end()
        return self.training_loss

class DMapCountGAN(GAN):
    """GAN variant whose generator is also supervised by a count loss.

    The generator output is summed by ``self.Suml`` and compared against the
    ground-truth count, in addition to the adversarial loss.
    """

    def __init__(self, weights, experiment_manager, **kwargs):
        """Build the networks and set up the loss history.

        Args:
            weights: Forwarded to ``density_map_generator`` via ``build``.
            experiment_manager: Forwarded to ``GAN.__init__``.
            **kwargs: Stored on ``self.kwargs`` and forwarded to
                ``density_map_generator``.
        """
        super().__init__(weights, experiment_manager, **kwargs)
        self.build(weights)
        # Key order must match the tuple returned by train_step.
        self.training_loss = {"advr_G_L": [], "count_G_L": [], "advr_D_L": []}

    def build(self, weights):
        """Create the discriminator, generator, augmentor and summing head."""
        self.Disc = base_discriminator()
        self.Genr = density_map_generator(weights, **self.kwargs)
        self.Augm = make_augmentor(brightness = 0.10, contrast=0.3,
                                        hue=0.5, zoomout=((0, 0.40), None, "constant"),
                                        rotate=0.05, flip="horizontal", shotnoise=0.07)
        self.Suml = Sequential([Flatten(),
                                SumLayer(1, True)])

    def register_checkpoint(self, path, save_freq):
        """Register the checkpoint tracker and attach models and optimizers.

        Raises:
            NotImplementedError: If the model has not been compiled
                (raised by ``GAN.register_checkpoint``).
        """
        super().register_checkpoint(path, save_freq)
        for model, name in zip([self.Genr, self.Disc, self.Suml], ["G-model", "D-model", "S-model"]):
            self.checkpoint.register(model, name)

        for name, opt in self.opt_dict.items():
            self.checkpoint.register(opt, "{}-opt".format(name))

    def basic_step(self, dmap_real, img_i, count_i):
        """Run the forward pass and return the three unweighted losses.

        Args:
            dmap_real: Real density maps, fed to the discriminator.
            img_i: Input images; augmented before reaching the generator.
            count_i: Ground-truth counts compared against the summed output.

        Returns:
            Tuple ``(advr_G_L, count_G_L, advr_D_L)``.
        """
        dmap_fake       = self.Genr(self.Augm(img_i, training=True), training=True)
        count_pred_i    = self.Suml(dmap_fake)

        adv_D_pred_real = self.Disc(dmap_real, training=True)
        adv_D_pred_fake = self.Disc(dmap_fake, training=True)

        advr_G_L        = gen_adversarial_loss(self.loss_dict["G"]["advr"], adv_D_pred_fake)
        count_G_L       = self.loss_dict["G"]["count"](count_i, count_pred_i)

        advr_D_L        = disc_adversarial_loss(self.loss_dict["D"]["advr"],
                                                adv_D_pred_real, adv_D_pred_fake)
        return advr_G_L, count_G_L, advr_D_L

    @tf.function
    def train_step(self, batch, extra_batch=None):
        """Apply one generator update and one discriminator update.

        Args:
            batch: Tuple ``(img_i, count_i, dmap_real)``.
            extra_batch: Accepted for interface compatibility; not used.

        Returns:
            The unweighted losses ``(advr_G_L, count_G_L, advr_D_L)``.
        """
        img_i, count_i, dmap_real = batch

        # One tape per network so each gradient is taken independently.
        with tf.GradientTape() as G_tape, tf.GradientTape() as D_tape:
            advr_G_L, count_G_L, advr_D_L = self.basic_step(dmap_real, img_i, count_i)
            tot_D_L = self.loss_w_dict["D"]["advr"] * advr_D_L
            tot_G_L = self.loss_w_dict["G"]["advr"] * advr_G_L + self.loss_w_dict["G"]["count"] * count_G_L

        G_grads = G_tape.gradient(tot_G_L, self.Genr.trainable_variables)
        # Differentiate the weighted total so the configured discriminator
        # weight takes effect, mirroring the generator update above.
        D_grads = D_tape.gradient(tot_D_L, self.Disc.trainable_variables)
        self.opt_dict["G"].apply_gradients(zip(G_grads, self.Genr.trainable_variables))
        self.opt_dict["D"].apply_gradients(zip(D_grads, self.Disc.trainable_variables))

        return advr_G_L, count_G_L, advr_D_L

    def get_sparsity(self, batch):
        """Sparsity is not measured for this variant; always returns 0."""
        return 0

class DMapGAN(GAN):
    """GAN variant whose generator is also supervised by a pairwise rank loss.

    Two images are passed through the generator and ``self.Rank`` predicts
    their ordering from the two generated maps.
    """

    def __init__(self, weights, experiment_manager, **kwargs):
        """Build the networks and set up the loss history.

        Args:
            weights: Forwarded to ``density_map_generator`` via ``build``.
            experiment_manager: Forwarded to ``GAN.__init__``.
            **kwargs: Stored on ``self.kwargs`` and forwarded to
                ``density_map_generator``.
        """
        super().__init__(weights, experiment_manager, **kwargs)
        self.build(weights)
        # Key order must match the tuple returned by train_step.
        self.training_loss = {"advr_G_L": [], "rank_G_L": [], "advr_D_L": []}

    def build(self, weights):
        """Create the discriminator, generator, rank head and augmentor."""
        self.Disc = base_discriminator()
        self.Genr = density_map_generator(weights, **self.kwargs)
        self.Rank = build_rank_head()
        self.Augm = make_augmentor(brightness = 0.10, contrast=0.3,
                                        hue=0.5, zoomout=((0, 0.40), None, "constant"),
                                        rotate=0.05, flip="horizontal", shotnoise=0.07)

    def register_checkpoint(self, path, save_freq):
        """Register the checkpoint tracker and attach models and optimizers.

        Raises:
            NotImplementedError: If the model has not been compiled
                (raised by ``GAN.register_checkpoint``).
        """
        super().register_checkpoint(path, save_freq)
        for model, name in zip([self.Genr, self.Disc, self.Rank], ["G-model", "D-model", "R-model"]):
            self.checkpoint.register(model, name)

        for name, opt in self.opt_dict.items():
            self.checkpoint.register(opt, "{}-opt".format(name))

    def basic_step(self, dmap_real, img_i, img_j, rank_ij):
        """Run the forward pass and return the three unweighted losses.

        Args:
            dmap_real: Real density maps, fed to the discriminator.
            img_i: First image of each pair; augmented before the generator.
            img_j: Second image of each pair; augmented before the generator.
            rank_ij: Target passed to the rank loss alongside the prediction.

        Returns:
            Tuple ``(advr_G_L, rank_G_L, advr_D_L)``.
        """
        dmap_fake_i     = self.Genr(self.Augm(img_i, training=True), training=True)
        dmap_fake_j     = self.Genr(self.Augm(img_j, training=True), training=True)
        rank_pred_ij    = self.Rank([dmap_fake_i, dmap_fake_j], training=True)

        adv_D_pred_real = self.Disc(dmap_real, training=True)
        # Discriminator scores for both generated maps are stacked on axis 0.
        adv_D_pred_fake = concatenate([self.Disc(dmap_fake_i, training=True),
                                       self.Disc(dmap_fake_j, training=True)], 0)

        advr_G_L        = gen_adversarial_loss(self.loss_dict["G"]["advr"], adv_D_pred_fake)
        rank_G_L        = self.loss_dict["G"]["rank"](rank_ij, rank_pred_ij)

        advr_D_L        = disc_adversarial_loss(self.loss_dict["D"]["advr"], adv_D_pred_real, adv_D_pred_fake)
        return advr_G_L, rank_G_L, advr_D_L

    @tf.function
    def train_step(self, batch, extra_batch=None):
        """Apply one generator update and one discriminator update.

        Args:
            batch: Tuple ``(dmap_real, img_i, img_j, rank_ij)``.
            extra_batch: Accepted for interface compatibility; not used.

        Returns:
            The unweighted losses ``(advr_G_L, rank_G_L, advr_D_L)``.
        """
        dmap_real, img_i, img_j, rank_ij = batch

        # One tape per network so each gradient is taken independently.
        with tf.GradientTape() as G_tape, tf.GradientTape() as D_tape:
            advr_G_L, rank_G_L, advr_D_L = self.basic_step(dmap_real, img_i, img_j, rank_ij)
            tot_D_L = self.loss_w_dict["D"]["advr"] * advr_D_L
            tot_G_L = self.loss_w_dict["G"]["advr"] * advr_G_L + self.loss_w_dict["G"]["rank"] * rank_G_L

        G_grads = G_tape.gradient(tot_G_L, self.Genr.trainable_variables)
        # Differentiate the weighted total so the configured discriminator
        # weight takes effect, mirroring the generator update above.
        D_grads = D_tape.gradient(tot_D_L, self.Disc.trainable_variables)
        self.opt_dict["G"].apply_gradients(zip(G_grads, self.Genr.trainable_variables))
        self.opt_dict["D"].apply_gradients(zip(D_grads, self.Disc.trainable_variables))

        return advr_G_L, rank_G_L, advr_D_L

    def get_sparsity(self, batch):
        """Count how often the smallest rounded generator output value occurs.

        Runs the generator on ``batch[1]`` and ``batch[2]`` joined on axis 0
        (no augmentation), rounds the output to two decimals and returns the
        number of occurrences of the smallest distinct value.

        NOTE(review): presumably that smallest value is the empty-background
        level of the map -- confirm against the generator's output range.
        """
        images = tf.concat([batch[1], batch[2]], axis=0)
        rounded = np.round(self.Genr(images).numpy(), 2)
        # np.unique returns values sorted ascending, so index 0 is the minimum.
        _, counts = np.unique(rounded, return_counts=True)
        return int(counts[0])

class DMapNoRank(DMapGAN):
    """``DMapGAN`` trained on the adversarial losses only.

    The rank loss is still evaluated by the inherited ``basic_step`` (so
    ``loss_dict["G"]["rank"]`` must exist) but it is left out of the
    generator objective and out of the reported losses.
    """

    def __init__(self, weights, experiment_manager, **kwargs):
        """Build the networks via ``DMapGAN`` and narrow the loss history.

        Args:
            weights: Forwarded to ``DMapGAN.__init__``.
            experiment_manager: Forwarded to ``DMapGAN.__init__``.
            **kwargs: Forwarded to ``DMapGAN.__init__``.
        """
        # DMapGAN.__init__ already calls self.build(weights); building again
        # here would construct every network twice and discard the first set.
        super().__init__(weights, experiment_manager, **kwargs)
        # Key order must match the tuple returned by train_step.
        self.training_loss = {"advr_G_L": [], "advr_D_L": []}

    @tf.function
    def train_step(self, batch, extra_batch=None):
        """Apply one generator update and one discriminator update.

        Args:
            batch: Tuple ``(dmap_real, img_i, img_j, rank_ij)``.
            extra_batch: Accepted for interface compatibility; not used.

        Returns:
            The unweighted losses ``(advr_G_L, advr_D_L)``.
        """
        dmap_real, img_i, img_j, rank_ij = batch

        # One tape per network so each gradient is taken independently.
        with tf.GradientTape() as G_tape, tf.GradientTape() as D_tape:
            # The rank loss returned in the middle slot is deliberately ignored.
            advr_G_L, _, advr_D_L = self.basic_step(dmap_real, img_i, img_j, rank_ij)
            tot_D_L = self.loss_w_dict["D"]["advr"] * advr_D_L
            tot_G_L = self.loss_w_dict["G"]["advr"] * advr_G_L

        G_grads = G_tape.gradient(tot_G_L, self.Genr.trainable_variables)
        # Differentiate the weighted total so the configured discriminator
        # weight takes effect, mirroring the generator update above.
        D_grads = D_tape.gradient(tot_D_L, self.Disc.trainable_variables)
        self.opt_dict["G"].apply_gradients(zip(G_grads, self.Genr.trainable_variables))
        self.opt_dict["D"].apply_gradients(zip(D_grads, self.Disc.trainable_variables))

        return advr_G_L, advr_D_L

    def get_sparsity(self, batch):
        """Sparsity is not measured for this variant; always returns 0."""
        return 0

