import tensorflow as tf
import numpy as np
import logging
# Raise the threshold of the 'tensorflow' Python logger to ERROR so that
# lower-severity records emitted through it are dropped.
logging.getLogger('tensorflow').setLevel(logging.ERROR)

class DIFFIL(tf.keras.Model):

    def __init__(
        self,
        agent,
        batch_size,
        make_decoder,
        make_encoder,
        make_label,
        make_label_frame,
        make_fwgan,
        expert_buffer,
        source_random_buffer=None,
        target_random_buffer=None,
        feature_size=32,
        recon_loss=1,
        feature_consistency_loss=1,
        disc_loss=50,
        gen_loss=10,
        seq_label_source=10,
        seq_label_target=0.001,
        frame_label_loss=10,
        sampler=None,
    ):
        """Create the sub-networks, optimizers and loss weights.

        Args:
            agent: Object whose ``train(agent_buffer, l_batch_size, l_updates,
                l_act_delay)`` method is invoked at the end of ``train``.
            batch_size: Per-buffer batch size; the training steps slice their
                inputs in multiples of this value.
            make_decoder: Zero-argument factory returning the pair
                ``(source_decoder, target_decoder)``.
            make_encoder: Zero-argument factory returning the encoder.
            make_label: Zero-argument factory returning the sequence-label
                network.
            make_label_frame: Zero-argument factory returning the frame-label
                network.
            make_fwgan: Zero-argument factory returning the feature critic.
            expert_buffer: Buffer sampled in ``train`` through
                ``get_balance_batch_with_step``.
            source_random_buffer: Buffer sampled the same way in ``train``.
                NOTE(review): defaults to ``None`` but ``train`` calls it
                unconditionally.
            target_random_buffer: Same remark as ``source_random_buffer``.
            feature_size: Stored on the instance; not read elsewhere in this
                class.
            recon_loss: Weight of the reconstruction term (further multiplied
                by 65536 in ``training_Model``).
            feature_consistency_loss: Weight of the feature-consistency term.
            disc_loss: Weight of the critic's Wasserstein term.
            gen_loss: Weight of the generator (adversarial) term.
            seq_label_source: Weight of the sequence-label loss on the first
                two batch-sized slices.
            seq_label_target: Weight of the sequence-label loss on the last
                two batch-sized slices.
            frame_label_loss: Weight of the frame-label loss.
            sampler: Stored on the instance; not read elsewhere in this class.
        """
        super().__init__()

        # Factories are invoked in a fixed order and the resulting networks
        # are attached in a fixed order; keep both stable.
        source_decoder, target_decoder = make_decoder()
        self._source_decoder = source_decoder
        self._target_decoder = target_decoder
        self._encoder = make_encoder()
        self._feature_wgan = make_fwgan()
        self._label_net = make_label()
        self._label_net_frame = make_label_frame()

        # Collaborators supplied by the caller.
        self.agent = agent
        self._se_buff = expert_buffer
        self.sampler = sampler
        self._sr_buff = source_random_buffer
        self._tr_buff = target_random_buffer

        # Separate optimizers for the critic and for everything else.
        self.discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
        self.model_optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

        # Unreduced MSE so that callers receive one value per sample.
        self.mse_loss_batch = tf.keras.losses.MeanSquaredError(
            reduction=tf.keras.losses.Reduction.NONE
        )

        # Plain scalars: sizes and loss weights.
        self.feature_size = feature_size
        self._batch_size = batch_size
        self.recon_scale = recon_loss
        self.fcon_scale = feature_consistency_loss
        self.disc_loss = disc_loss
        self.gen_loss = gen_loss
        self.seq_label_loss_source = seq_label_source
        self.seq_label_loss_target = seq_label_target
        self.frame_label_loss = frame_label_loss

    def call(self, inputs):
        """Push ``inputs`` through the encoder and every head, then return 0.

        The encoder output is fed, in order, to the source decoder, the target
        decoder, the sequence-label network, the frame-label network and the
        feature critic. All head outputs are discarded.
        NOTE(review): presumably this exists only so that one forward call
        touches (and builds) every sub-network -- confirm with the caller.
        """
        latent = self._encoder(inputs)
        heads = (
            self._source_decoder,
            self._target_decoder,
            self._label_net,
            self._label_net_frame,
            self._feature_wgan,
        )
        for head in heads:
            head(latent)

        return 0

    @staticmethod
    def _ce_gan_loss_custom(labels, probs, lb):
        """Binary cross-entropy of ``probs`` against ``labels``.

        Args:
            labels: Target values.
            probs: Predicted values.
            lb: Forwarded as the ``label_smoothing`` argument.

        Returns:
            The result of ``tf.losses.binary_crossentropy``; callers in this
            class slice it along its first axis.
        """
        cross_entropy = tf.losses.binary_crossentropy
        return cross_entropy(labels, probs, label_smoothing=lb)

    def training_WGAN(self, model_data):
        """Run one optimizer step on the feature critic (``_feature_wgan``).

        ``model_data`` is encoded once; the first ``2 * batch_size`` feature
        rows are treated as source features and the remaining rows as target
        features. The critic loss is the weighted difference of mean critic
        scores plus a gradient penalty evaluated on random per-row blends of
        the two halves.
        NOTE(review): the blend requires both halves to have the same number
        of rows, i.e. ``model_data`` is assumed to hold ``4 * batch_size``
        rows -- confirm against the caller.

        Args:
            model_data: Batch passed to the encoder.

        Returns:
            Tuple ``(disc_loss, grad_norms, gradient_penalty)``.
        """
        penalty_weight = 10
        n = self._batch_size

        # One blend coefficient per row, broadcast across the feature axis.
        mix = tf.random.uniform(shape=[2 * n, 1], minval=0, maxval=1, dtype='float32')

        # Encoded outside the tape: only critic variables are updated here.
        encoded = self._encoder(model_data)
        src_feat = encoded[:2 * n, ]
        tgt_feat = encoded[2 * n:, ]

        with tf.GradientTape() as critic_tape:
            blended = mix * src_feat + (1 - mix) * tgt_feat

            # Inner tape yields d(critic score)/d(blended input); the outer
            # tape records that computation so the penalty is differentiable
            # with respect to the critic variables.
            with tf.GradientTape() as input_tape:
                input_tape.watch(blended)
                blended_score = self._feature_wgan(blended)
                (input_grads,) = input_tape.gradient(blended_score, [blended])

            # Per-row gradient norm; the epsilon keeps sqrt finite at zero.
            grad_norms = tf.sqrt(tf.reduce_sum(tf.square(input_grads), axis=[1]) + 1e-9)
            gradient_penalty = tf.reduce_mean(tf.square(grad_norms - 1))

            src_score = self._feature_wgan(src_feat)
            tgt_score = self._feature_wgan(tgt_feat)

            src_mean = tf.reduce_mean(src_score)
            tgt_mean = tf.reduce_mean(tgt_score)
            wasserstein_term = -src_mean + tgt_mean
            disc_loss = self.disc_loss * wasserstein_term + penalty_weight * gradient_penalty

        # Read the variable list after the forward passes above.
        critic_vars = self._feature_wgan.trainable_variables
        critic_grads = critic_tape.gradient(disc_loss, critic_vars)
        self.discriminator_optimizer.apply_gradients(zip(critic_grads, critic_vars))

        return disc_loss, grad_norms, gradient_penalty

    def training_Model(self, model_data, label_data, dense_timesteps):
        """Run one optimizer step on the encoder, both decoders and both label nets.

        The objective combines reconstruction, feature consistency, an
        adversarial term scored by the feature critic, and the sequence- and
        frame-label cross-entropies. The critic is used in the loss but its
        variables are not part of the update.

        Row layout, with ``n = batch_size``: the first ``2 * n`` rows of
        ``model_data`` are handled as source data and the rest as target data.
        ``label_data`` is split into four ``n``-row blocks, of which only the
        first is labelled 1 for the sequence loss.

        Args:
            model_data: Batch used for reconstruction, consistency and the
                adversarial term.
            label_data: Batch used for both label networks.
            dense_timesteps: Sequence whose first ``n`` entries become the
                frame-label targets of the first block.
                NOTE(review): they are fed to a binary cross-entropy, so they
                are presumably already scaled to [0, 1] -- confirm.

        Returns:
            Tuple ``(65536 * recon_loss, feature_consistency_loss, gen_loss,
            sequence_label_loss, frame_label_loss)``.
        """
        n = self._batch_size
        # NOTE(review): undocumented constant (equals 256 ** 2) applied to the
        # reconstruction term and to the reported reconstruction loss.
        base_recon_scale = 65536

        # Sequence targets: ones for the first block, zeros for the other three.
        true_label = np.concatenate([tf.ones([n, 1]), tf.zeros([3 * n, 1])], axis=0)

        # Frame targets: the first block carries its timesteps, the rest zero.
        first_block_steps = np.array(dense_timesteps[:n], dtype=np.float64).reshape(n, 1)
        scaled_steps = first_block_steps * tf.ones([n, 1])
        timelabel = np.concatenate([scaled_steps, tf.zeros([3 * n, 1])], axis=0)

        with tf.GradientTape() as g_tape:
            latent = self._encoder(model_data)
            latent_for_labels = self._encoder(label_data)
            latent_frozen = tf.stop_gradient(latent)

            via_target = self._target_decoder(latent)
            via_source = self._source_decoder(latent)

            # Decoding from frozen features keeps the consistency term from
            # back-propagating into the encoder through this first pass.
            via_target_frozen = self._target_decoder(latent_frozen)
            via_source_frozen = self._source_decoder(latent_frozen)

            # Cross-decode, then re-encode: target rows through the source
            # decoder, source rows through the target decoder.
            reencoded_tail = self._encoder(via_source_frozen[2 * n:])
            reencoded_head = self._encoder(via_target_frozen[:2 * n])
            feature_consistency_loss = self.mse_loss_batch(
                latent_frozen,
                tf.concat([reencoded_head, reencoded_tail], axis=0),
            )

            sequence_label_estim = self._label_net(latent_for_labels)

            # Each half is reconstructed by its own decoder.
            recon_loss = self.mse_loss_batch(
                model_data,
                tf.concat([via_source[:2 * n], via_target[2 * n:]], axis=0),
            )
            sequence_label_loss = self._ce_gan_loss_custom(true_label, sequence_label_estim, 0)

            # The frame head sees frozen features, so its loss only reaches
            # the frame-label network.
            frame_label_estim = self._label_net_frame(tf.stop_gradient(latent_for_labels))
            frame_label_loss = self._ce_gan_loss_custom(timelabel, frame_label_estim, 0)

            critic_scores = self._feature_wgan(latent)
            gen_loss = tf.reduce_mean(critic_scores[:2 * n]) - tf.reduce_mean(critic_scores[2 * n:])

            src_weight = self.seq_label_loss_source
            tgt_weight = self.seq_label_loss_target
            frame_weight = self.frame_label_loss
            label_loss = (
                src_weight * tf.reduce_mean(sequence_label_loss[:n])
                + src_weight * tf.reduce_mean(sequence_label_loss[n:2 * n])
                + tgt_weight * tf.reduce_mean(sequence_label_loss[2 * n:3 * n])
                + tgt_weight * tf.reduce_mean(sequence_label_loss[3 * n:4 * n])
                + frame_weight * tf.reduce_mean(frame_label_loss[:n])
                + frame_weight * tf.reduce_mean(frame_label_loss[n:2 * n])
            )

            reconstruction_scale = base_recon_scale * self.recon_scale
            g_loss = (
                reconstruction_scale * tf.reduce_mean(recon_loss)
                + self.fcon_scale * tf.reduce_mean(feature_consistency_loss)
                + self.gen_loss * gen_loss
                + label_loss
            )

        # Everything except the critic is updated by this step.
        updated_vars = (
            self._encoder.trainable_variables
            + self._source_decoder.trainable_variables
            + self._target_decoder.trainable_variables
            + self._label_net.trainable_variables
            + self._label_net_frame.trainable_variables
        )
        grads = g_tape.gradient(g_loss, updated_vars)
        self.model_optimizer.apply_gradients(zip(grads, updated_vars))

        return (
            base_recon_scale * recon_loss,
            feature_consistency_loss,
            gen_loss,
            sequence_label_loss,
            frame_label_loss,
        )

    def process_data(self, a, b, c, d, init_num):
        def split_and_shuffle(data1, data2, init_num):
            data1_main, data1_tail = data1[:-init_num], data1[-init_num:]
            data2_main, data2_tail = data2[:-init_num], data2[-init_num:]

            combined = np.concatenate([data1_main, data2_main])
            np.random.shuffle(combined)

            split_idx = len(combined) // 2
            part1 = combined[:split_idx]
            part2 = combined[split_idx:]

            part1 = np.concatenate([part1, data1_tail])
            part2 = np.concatenate([part2, data2_tail])

            return part1, part2

        shuffled_source1, shuffled_source2 = split_and_shuffle(a, b, init_num)
        shuffled_target1, shuffled_target2 = split_and_shuffle(c, d, init_num)

        return shuffled_source1, shuffled_source2, shuffled_target1, shuffled_target2

    def train(self, agent_buffer, l_batch_size=128, l_updates=1, l_act_delay=1,
              d_updates=1):
        """Update the representation networks, refresh rewards, then train the agent.

        For each of ``d_updates`` iterations one batch is drawn from
        ``agent_buffer``, the expert buffer, the target-random buffer and the
        source-random buffer (in that order). The critic is updated every
        iteration; the encoder, decoders and label networks are updated on
        every fifth iteration, starting with the first. Afterwards
        ``agent_buffer.update_reward()`` is called and ``self.agent.train`` is
        run on ``agent_buffer``.

        NOTE(review): the source-random and target-random buffers default to
        ``None`` in the constructor but are required here.

        Args:
            agent_buffer: Buffer providing ``get_balance_batch_with_step`` and
                ``update_reward``; also handed to the agent.
            l_batch_size: Forwarded to ``self.agent.train``.
            l_updates: Forwarded to ``self.agent.train``.
            l_act_delay: Forwarded to ``self.agent.train``.
            d_updates: Number of critic update iterations.

        Returns:
            Always 0.
        """
        # Number of critic updates per update of the other networks.
        training_ratio = 5

        batch = self._batch_size
        init_num = int(batch // 32)

        def draw(buffer):
            # Keep the first returned batch's images and its timestep list.
            sampled, step_lists = buffer.get_balance_batch_with_step(batch, init_num)
            return sampled[0]['ims'], step_lists[0]

        for step in range(d_updates):
            # Sampling order is fixed: agent, expert, target-random, source-random.
            tl_ims, tl_steps = draw(agent_buffer)
            se_ims, se_steps = draw(self._se_buff)
            tr_ims, tr_steps = draw(self._tr_buff)
            sr_ims, sr_steps = draw(self._sr_buff)

            # Mix rows within the source pair and within the target pair.
            source_a, source_b, target_a, target_b = self.process_data(
                se_ims, sr_ims, tl_ims, tr_ims, init_num)

            # Label inputs keep the unmixed per-buffer order so that block
            # positions still identify where each row came from.
            label_data = tf.concat([se_ims, sr_ims, tl_ims, tr_ims], axis=0)
            model_data = np.concatenate([source_a, source_b, target_a, target_b], axis=0)
            dense_timesteps = tf.concat([se_steps, sr_steps, tl_steps, tr_steps], axis=0)

            # Critic step (every iteration).
            self.training_WGAN(model_data)

            # Encoder / decoder / label-network step (every training_ratio-th).
            if step % training_ratio == 0:
                self.training_Model(model_data, label_data, dense_timesteps)

        banner = "=*" * 20
        print(banner)
        print("Update reward estimation")
        agent_buffer.update_reward()
        print(banner)

        print("Start SAC training")
        self.agent.train(agent_buffer, l_batch_size, l_updates, l_act_delay)
        print("SAC training complete")

        return 0

