import tensorflow as tf
from src.models.templates import AbstractExperimentModel
from src.models.utils import density_map_generator, build_rank_head
from src.models.augment import make_augmentor

class NoDmapRank(AbstractExperimentModel):
    """Experiment model whose generator is optimised with a ranking loss only.

    Two images are augmented, passed through the generator, and the two
    generator outputs are compared by the rank head. The only loss used in
    ``train_step`` is ``loss_dict["G"]["rank"]``.

    Attributes set by this class:
        Genr: model returned by ``density_map_generator``.
        Rank: model returned by ``build_rank_head``.
        Augm: callable returned by ``make_augmentor``.
        training_loss: dict with a single ``"rank_G_L"`` list.

    ``kwargs``, ``opt_dict``, ``loss_dict`` and ``checkpoint`` are read here
    but never assigned; presumably the base class provides them -- confirm
    against ``AbstractExperimentModel``.
    """

    def __init__(self, weights, experiment_manager, **kwargs):
        """[INIT] Run the base initialiser, build sub-models, create the loss log.

        Args:
            weights: forwarded to the base class and to ``build``.
            experiment_manager: forwarded to the base class.
            **kwargs: forwarded to the base class.
        """
        super().__init__(weights, experiment_manager, **kwargs)
        self.build(weights)
        # Key matches the name of the single loss returned by train_step;
        # presumably the training loop appends to this list -- verify caller.
        self.training_loss = dict(rank_G_L=[])

    def build(self, weights):
        """[BUILD] Create the generator, the rank head and the augmentor.

        Args:
            weights: passed to ``density_map_generator`` together with
                ``self.kwargs``.
        """
        self.Genr = density_map_generator(weights, **self.kwargs)
        self.Rank = build_rank_head()

        # Augmentation settings, kept in the same keyword order as before.
        augment_options = {
            "brightness": 0.10,
            "contrast": 0.3,
            "hue": 0.5,
            "zoomout": ((0, 0.40), None, "constant"),
            "rotate": 0.05,
            "flip": "horizontal",
            "shotnoise": 0.07,
        }
        self.Augm = make_augmentor(**augment_options)

    def register_checkpoint(self, path, save_freq):
        """[CHECKPOINT] Register both models and every optimizer for saving.

        Args:
            path: forwarded to the base ``register_checkpoint``.
            save_freq: forwarded to the base ``register_checkpoint``.
        """
        super().register_checkpoint(path, save_freq)

        # Models first: generator as "G-model", rank head as "R-model".
        named_models = (("G-model", self.Genr), ("R-model", self.Rank))
        for label, network in named_models:
            self.checkpoint.register(network, label)

        # Then each optimizer, registered under "<key>-opt".
        for opt_key, optimizer in self.opt_dict.items():
            self.checkpoint.register(optimizer, "{}-opt".format(opt_key))

    def basic_step(self, img_i, img_j, rank_ij):
        """[BASIC_STEP] Forward both images and return the ranking loss.

        Each image is augmented and then fed to the generator (both in
        training mode); the rank head receives the two generator outputs as
        a list ``[i, j]``.

        Args:
            img_i: first input of the pair.
            img_j: second input of the pair.
            rank_ij: target passed as first argument to the ranking loss.

        Returns:
            The value of ``loss_dict["G"]["rank"](rank_ij, prediction)``.
        """
        # Image i is fully processed before image j, as in the original.
        augmented_i = self.Augm(img_i, training=True)
        generated_i = self.Genr(augmented_i, training=True)

        augmented_j = self.Augm(img_j, training=True)
        generated_j = self.Genr(augmented_j, training=True)

        predicted_rank = self.Rank([generated_i, generated_j], training=True)

        ranking_loss_fn = self.loss_dict["G"]["rank"]
        return ranking_loss_fn(rank_ij, predicted_rank)

    @tf.function
    def train_step(self, batch, extra_batch=None):
        """[TRAIN_STEP] Run one optimisation step of the generator.

        Args:
            batch: unpacked as ``((img_i, img_j), rank_ij)``.
            extra_batch: accepted but not used by this model.

        Returns:
            A one-element list containing the ranking loss.

        NOTE(review): gradients are taken with respect to
        ``self.Genr.trainable_variables`` only, so any trainable variables
        of ``self.Rank`` are not updated here -- confirm this is intended.
        """
        image_pair, rank_ij = batch
        img_i, img_j = image_pair

        with tf.GradientTape() as generator_tape:
            # The total generator loss consists of the ranking loss alone.
            rank_G_L = self.basic_step(img_i, img_j, rank_ij)

        generator_vars = self.Genr.trainable_variables
        generator_grads = generator_tape.gradient(rank_G_L, generator_vars)
        self.opt_dict["G"].apply_gradients(zip(generator_grads, generator_vars))

        return [rank_G_L]

    def get_sparsity(self, batch):
        """[SPARSITY] Return the constant 0; ``batch`` is ignored."""
        return 0
