import numpy as np

import tensorflow as tf
tf.compat.v1.disable_eager_execution()

from opelab.core.baseline import Baseline
    

class IPSKernelTF(Baseline):
    """State density-ratio estimator trained with a Gaussian-kernel loss.

    A small network ``w(s)`` (see :meth:`state_to_w`) is fitted by minimising
    a kernel-weighted quadratic form of

        x = w(s) * policy_ratio / mean(w(s)) - w(s') / mean(w(s'))

    over mini-batches of transitions.  :meth:`evaluate` then returns the
    self-normalised average of the rewards weighted by
    ``w(s) * target.prob / behavior.prob``.

    The implementation uses TensorFlow graph mode (``tf.compat.v1``).
    """

    def __init__(self, obs_dim: int, w_hidden, widths: float | None=None,
                 learning_rate: float=0.001, reg_weight: float=0.0,
                 batch_size: int=256, iters: int=20000, epsilon: float=0.0):
        """Store the hyper-parameters; no TensorFlow ops are created here.

        Args:
            obs_dim: Size of the last axis of the state placeholders.
            w_hidden: Iterable of hidden-layer widths for the ``w`` network.
            widths: Kernel bandwidth.  When ``None`` the bandwidth is the
                median pairwise distance of a random sample of states.
            learning_rate: Adam learning rate.
            reg_weight: Multiplier applied to the summed regularisation
                losses of the ``w`` scope.
            batch_size: Number of transitions drawn per mini-batch.
            iters: Number of training iterations.
            epsilon: Constant added to both policy probabilities before they
                are divided to form the ratio fed to the loss.
        """
        self.batch_size = batch_size
        self.iters = iters
        self.epsilon = epsilon
        self.widths = widths
        self.obs_dim = obs_dim
        self.w_hidden = w_hidden
        self.reg_weight = reg_weight
        self.learning_rate = learning_rate

    def _compile(self):
        """Create placeholders, the loss and the training op in the current default graph."""

        # Placeholders: two independent batches (suffix "2" marks the second)
        # plus the scalar kernel bandwidth.
        self.state = tf.compat.v1.placeholder(tf.float32, [None, self.obs_dim])
        self.med_dist = tf.compat.v1.placeholder(tf.float32, [])
        self.next_state = tf.compat.v1.placeholder(tf.float32, [None, self.obs_dim])

        self.state2 = tf.compat.v1.placeholder(tf.float32, [None, self.obs_dim])
        self.next_state2 = tf.compat.v1.placeholder(tf.float32, [None, self.obs_dim])
        self.policy_ratio = tf.compat.v1.placeholder(tf.float32, [None])
        self.policy_ratio2 = tf.compat.v1.placeholder(tf.float32, [None])

        # Density ratio for state and next state.  All four calls share the
        # same variables through the AUTO_REUSE scope in state_to_w.
        w = self.state_to_w(self.state, self.obs_dim, self.w_hidden)
        w_next = self.state_to_w(self.next_state, self.obs_dim, self.w_hidden)
        w2 = self.state_to_w(self.state2, self.obs_dim, self.w_hidden)
        w_next2 = self.state_to_w(self.next_state2, self.obs_dim, self.w_hidden)
        norm_w = tf.reduce_mean(w)
        norm_w_next = tf.reduce_mean(w_next)
        norm_w2 = tf.reduce_mean(w2)
        norm_w_next2 = tf.reduce_mean(w_next2)
        self.output = w

        # Per-sample residuals, each normalised by its batch mean.
        x = w * self.policy_ratio / norm_w - w_next / norm_w_next
        x2 = w2 * self.policy_ratio2 / norm_w2 - w_next2 / norm_w_next2

        # Gaussian kernel between the next states of the two batches, then the
        # quadratic form x^T K x2 normalised by the total kernel mass.
        diff_xx = tf.expand_dims(self.next_state, 1) - tf.expand_dims(self.next_state2, 0)
        K_xx = tf.exp(-tf.reduce_sum(tf.square(diff_xx), axis=-1) / (2.0 * self.med_dist * self.med_dist))
        loss_xx = tf.matmul(tf.matmul(tf.expand_dims(x, 0), K_xx), tf.expand_dims(x2, 1))
        self.loss = tf.squeeze(loss_xx) / tf.reduce_sum(K_xx)
        self.reg_loss = self.reg_weight * tf.reduce_sum(
            tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES, 'w'))
        self.train_op = tf.compat.v1.train.AdamOptimizer(self.learning_rate).minimize(self.loss + self.reg_loss)

    def state_to_w(self, state, obs_dim, hidden_dim_dr):
        """Build the ``w`` network on ``state`` and return its softplus output.

        Args:
            state: Tensor whose last axis has size ``obs_dim``.
            obs_dim: Input feature dimension.
            hidden_dim_dr: Iterable of hidden-layer widths.

        Returns:
            ``log(1 + exp(z))`` of the squeezed output-layer pre-activation,
            which is strictly positive.
        """
        with tf.compat.v1.variable_scope('w', reuse=tf.compat.v1.AUTO_REUSE):

            # Hidden layers: affine -> normalisation with the moments of the
            # batch being fed (no running averages are kept) -> tanh.
            pdim, pin = obs_dim, state
            for i, dim in enumerate(hidden_dim_dr):
                W = tf.compat.v1.get_variable(f'W{i}', initializer=tf.compat.v1.random.normal(shape=[pdim, dim]))
                b = tf.compat.v1.get_variable(f'b{i}', initializer=tf.zeros([dim]))
                z = tf.matmul(pin, W) + b
                mean_z, var_z = tf.nn.moments(z, [0])
                scale_z = tf.compat.v1.get_variable(f'scale_z{i}', initializer=tf.ones([dim]))
                beta_z = tf.compat.v1.get_variable(f'beta_z{i}', initializer=tf.zeros([dim]))
                l = tf.tanh(tf.nn.batch_normalization(z, mean_z, var_z, beta_z, scale_z, 1e-10))
                pdim, pin = dim, l

            # Output layer: only these two variables carry an L2 regulariser.
            W = tf.compat.v1.get_variable('Wo',
                                          initializer=0.01 * tf.compat.v1.random.normal(shape=[pdim, 1]),
                                          regularizer=tf.keras.regularizers.l2(0.5))
            b = tf.compat.v1.get_variable('bo',
                                          initializer=tf.zeros([1]),
                                          regularizer=tf.keras.regularizers.l2(0.5))
            z = tf.matmul(pin, W) + b
            return tf.math.log(1. + tf.exp(tf.squeeze(z)))

    def get_density_ratio(self, sess, states):
        """Run the ``w`` network on ``states`` in ``sess`` and return the values."""
        return sess.run(self.output, feed_dict={ self.state: states })

    @staticmethod
    def _feature_range(S):
        """Return ``(min, span)`` per feature, with zero spans replaced by 1.

        Replacing a zero span avoids the 0/0 = NaN that min-max scaling would
        otherwise produce for a feature that is constant over the dataset;
        such a feature is mapped to 0 instead.
        """
        S_min = np.min(S, axis=0)
        span = np.max(S, axis=0) - S_min
        return S_min, np.where(span == 0, 1.0, span)

    def train(self, sess, SASR, policy0, policy1):
        """Fit the ``w`` network on the transitions in ``SASR``.

        Args:
            sess: Session whose graph contains the ops made by :meth:`_compile`.
            SASR: Iterable of trajectories; each is indexed with the keys
                ``'states'``, ``'actions'``, ``'next-states'`` and ``'rewards'``.
            policy0: Object whose ``prob(state, action)`` is the denominator
                of the policy ratio.
            policy1: Object whose ``prob(state, action)`` is the numerator
                of the policy ratio.
        """
        S, SN, PI1, PI0, REW = [], [], [], [], []
        for tau in SASR:
            for state, action, next_state, reward in zip(
                tau['states'], tau['actions'], tau['next-states'], tau['rewards']):
                PI1.append(policy1.prob(state, action))
                PI0.append(policy0.prob(state, action))
                S.append(state)
                SN.append(next_state)
                REW.append(reward)

        # Min-max scale states; next states reuse the statistics of the states.
        S = np.array(S)
        S_min, S_span = self._feature_range(S)
        S = (S - S_min) / S_span
        SN = (np.array(SN) - S_min) / S_span

        PI1 = np.array(PI1)
        PI0 = np.array(PI0)
        REW = np.array(REW)
        N = S.shape[0]

        # Smoothed ratio fed to the loss; computed once instead of per feed.
        ratio = (PI1 + self.epsilon) / (PI0 + self.epsilon)

        if self.widths is None:
            # Median heuristic on a random sample (drawn with replacement).
            s = S[np.random.choice(N, 4096)]
            med_dist = np.median(np.sqrt(np.sum(np.square(s[None,:,:] - s[:, None,:]), axis=-1)))
            print(med_dist)
        else:
            med_dist = self.widths

        def feed(first, second):
            # Feed dict for the two batches selected by the given index arrays.
            return {
                self.med_dist: med_dist,
                self.state: S[first],
                self.next_state: SN[first],
                self.policy_ratio: ratio[first],
                self.state2: S[second],
                self.next_state2: SN[second],
                self.policy_ratio2: ratio[second],
            }

        # Guard against iters < 10, where iters // 10 would be a zero modulus.
        log_every = max(1, self.iters // 10)

        for i in range(self.iters):

            # Periodic progress report on two freshly drawn batches.
            if i % log_every == 0:
                subsamples1 = np.random.choice(N, self.batch_size)
                subsamples2 = np.random.choice(N, self.batch_size)
                test_loss, reg_loss = sess.run(
                    [self.loss, self.reg_loss], feed_dict=feed(subsamples1, subsamples2))
                DENR = self.get_density_ratio(sess, S)
                # NOTE(review): this diagnostic uses the raw ratio without
                # epsilon, unlike the ratio fed to the loss; confirm intent.
                T = DENR * PI1 / PI0
                est = np.sum(T * REW) / np.sum(T)
                print(f"iter {i} loss {test_loss} reg {reg_loss} estimate {est}")

            # Training step: the same batch is fed to both sides of the kernel.
            subsamples = np.random.choice(N, self.batch_size)
            sess.run(self.train_op, feed_dict=feed(subsamples, subsamples))

    def evaluate(self, data, target, behavior, gamma, reward_estimator=None):
        """Train on ``data`` and return the weighted-reward estimate.

        The graph is built in a private ``tf.Graph`` so that repeated calls,
        or other estimators using the same variable-scope names, neither
        accumulate ops in the process-wide default graph nor reuse each
        other's variables.

        Args:
            data: Iterable of trajectories (see :meth:`train`).
            target: Policy whose ``prob`` forms the numerator of the ratio.
            behavior: Policy whose ``prob`` forms the denominator of the ratio.
            gamma: Accepted but not used by this estimator.
            reward_estimator: Accepted but not used by this estimator.

        Returns:
            ``sum(T * r) / sum(T)`` with ``T = w(s) * target.prob / behavior.prob``.
        """
        with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
            self._compile()
            sess.run(tf.compat.v1.global_variables_initializer())
            self.train(sess, data, behavior, target)

            S = []
            POLICY_RATIO = []
            REW = []
            for tau in data:
                for state, action, next_state, reward in zip(
                    tau['states'], tau['actions'], tau['next-states'], tau['rewards']):
                    POLICY_RATIO.append(target.prob(state, action) / behavior.prob(state, action))
                    S.append(state)
                    REW.append(reward)

            # Same scaling as in train(): statistics come from the same data.
            S = np.array(S)
            S_min, S_span = self._feature_range(S)
            S = (S - S_min) / S_span
            POLICY_RATIO = np.array(POLICY_RATIO)
            REW = np.array(REW)

            DENR = self.get_density_ratio(sess, S)
            T = DENR * POLICY_RATIO
            result = np.sum(T * REW) / np.sum(T)
        del self.train_op
        return result


class IPSGANTF(Baseline):
    """State density-ratio estimator trained with an adversarial test function.

    A network ``w(s)`` is trained to minimise, and a test-function network
    ``f(s')`` to maximise, the normalised objective

        mean((w(s) * policy_ratio - w(s')) * f(s')) / (mean(w(s)) * rms(f(s')))

    :meth:`evaluate` then returns the self-normalised average of the rewards
    weighted by ``w(s) * target.prob / behavior.prob``.

    The implementation uses TensorFlow graph mode (``tf.compat.v1``).
    """

    def __init__(self, obs_dim: int, w_hidden, f_hidden,
                 learning_rate: float=0.001, gau: int=0,
                 reg_weight: float=0.0, batch_size: int=256, iters: int=20000):
        """Store the hyper-parameters; no TensorFlow ops are created here.

        Args:
            obs_dim: Size of the last axis of the state placeholders.
            w_hidden: Hidden width of the ``w`` network when ``gau == 0``,
                otherwise the number of mixture components.
            f_hidden: Hidden width of the ``f`` network.
            learning_rate: Adam learning rate.
            gau: ``0`` selects :meth:`state_to_w_batch_norm`; any other value
                selects :meth:`state_to_w_gau_mix`.
            reg_weight: Multiplier applied to the summed regularisation losses.
            batch_size: Number of transitions drawn per mini-batch.
            iters: Number of training iterations.
        """
        self.batch_size = batch_size
        self.iters = iters
        self.obs_dim = obs_dim
        self.w_hidden = w_hidden
        # Needed by _compile(); it was previously never stored.
        self.f_hidden = f_hidden
        self.learning_rate = learning_rate
        self.reg_weight = reg_weight
        self.gau = gau

    def _compile(self):
        """Create placeholders, the loss and both training ops in the current default graph."""

        # Placeholders.
        self.state = tf.compat.v1.placeholder(tf.float32, [None, self.obs_dim])
        self.next_state = tf.compat.v1.placeholder(tf.float32, [None, self.obs_dim])
        self.policy_ratio = tf.compat.v1.placeholder(tf.float32, [None])

        # Density ratio for state and next state (variables are shared
        # between the two calls through the AUTO_REUSE scope).
        if self.gau == 0:
            w = self.state_to_w_batch_norm(self.state, self.obs_dim, self.w_hidden)
            w_next = self.state_to_w_batch_norm(self.next_state, self.obs_dim, self.w_hidden)
        else:
            w = self.state_to_w_gau_mix(self.state, self.obs_dim, self.w_hidden)
            w_next = self.state_to_w_gau_mix(self.next_state, self.obs_dim, self.w_hidden)
        norm_w = tf.reduce_mean(w)
        self.output = w

        # Test function on the next state; the small constant keeps the
        # normaliser away from zero.
        f_next = self.state_to_f(self.next_state, self.obs_dim, self.f_hidden)
        norm_f = tf.sqrt(tf.reduce_mean(f_next * f_next)) + 1e-15

        # Loss function.
        x = w * self.policy_ratio - w_next
        self.reg_loss = self.reg_weight * tf.reduce_sum(
            tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES, 'w'))
        self.reg_loss_f = self.reg_weight * tf.reduce_sum(
            tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.REGULARIZATION_LOSSES, 'f'))
        self.loss = tf.reduce_mean(x * f_next) / (norm_w * norm_f)

        with tf.compat.v1.variable_scope('optimizer'):
            optimizer = tf.compat.v1.train.AdamOptimizer(self.learning_rate)
            f_vars = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, 'f')
            w_vars = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TRAINABLE_VARIABLES, 'w')
            # f ascends the loss, w descends it (plus its regulariser).
            # reg_loss_f is built above but is not part of either objective.
            self.train_op_f = optimizer.minimize(-self.loss, var_list=f_vars)
            self.train_op_w = optimizer.minimize(self.loss + self.reg_loss, var_list=w_vars)

    def state_to_w_gau_mix(self, state, obs_dim, num_component):
        """Return the ratio of two Gaussian-mixture scores of ``state``.

        The input is first standardised with the moments of the batch being
        fed.  Both mixtures receive the same ``initial_mu`` initializer tensor.
        """
        with tf.compat.v1.variable_scope('w', reuse=tf.compat.v1.AUTO_REUSE):
            mean_state, var_state = tf.nn.moments(state, [0])
            state = (state - mean_state) / tf.sqrt(var_state)
            initial_mu = tf.compat.v1.random_normal(shape=[num_component, obs_dim], stddev=np.sqrt(1.))
            dpi = self.gaussian_mixture(state, obs_dim, num_component, 0.0, 'dpi', initial_mu)
            dpi0 = self.gaussian_mixture(state, obs_dim, num_component, 0.0, 'dpi0', initial_mu)
            # The small constants keep the quotient finite when both scores vanish.
            return (dpi + 1e-15) / (dpi0 + 1e-15)

    def gaussian_mixture(self, state, obs_dim, num_component, std_min, scope, initial_mu):
        """Return an unnormalised mixture score for every row of ``state``.

        Each component contributes
        ``exp(alpha - ||s - mu||^2 / (20 * exp(log_sigma) + std_min))`` and the
        contributions are averaged over components.

        Args:
            state: Tensor whose last axis matches the last axis of ``mu``.
            obs_dim: Not used inside this method.
            num_component: Number of mixture components.
            std_min: Constant added to the denominator of the exponent.
            scope: Variable-scope name holding this mixture's parameters.
            initial_mu: Initializer tensor for the component means.
        """
        with tf.compat.v1.variable_scope(scope, reuse=tf.compat.v1.AUTO_REUSE):
            alpha = tf.compat.v1.get_variable('alpha', initializer=tf.zeros([num_component]))
            mu = tf.compat.v1.get_variable('mu', initializer=initial_mu)
            log_sigma = tf.compat.v1.get_variable('log_sigma', initializer=tf.zeros([num_component]))
            log_prob = alpha - tf.reduce_sum(
                tf.square(tf.expand_dims(state, 1) - tf.expand_dims(mu, 0)), axis=-1) / (20 * tf.exp(log_sigma) + std_min)
            prob = tf.reduce_mean(tf.exp(log_prob), axis=-1)
            return tf.squeeze(prob)

    def state_to_w_batch_norm(self, state, obs_dim, hidden_dim_dr):
        """Build the one-hidden-layer ``w`` network and return its softplus output.

        The hidden pre-activation is normalised with the moments of the batch
        being fed (no running averages are kept).  ``hidden_dim_dr`` is a
        single integer width.
        """
        with tf.compat.v1.variable_scope('w', reuse=tf.compat.v1.AUTO_REUSE):

            # First layer
            W1 = tf.compat.v1.get_variable('W1', initializer=tf.compat.v1.random_normal(shape=[obs_dim, hidden_dim_dr]))
            b1 = tf.compat.v1.get_variable('b1', initializer=tf.zeros([hidden_dim_dr]))
            z1 = tf.matmul(state, W1) + b1
            mean_z1, var_z1 = tf.nn.moments(z1, [0])
            scale_z1 = tf.compat.v1.get_variable('scale_z1', initializer=tf.ones([hidden_dim_dr]))
            beta_z1 = tf.compat.v1.get_variable('beta_z1', initializer=tf.zeros([hidden_dim_dr]))
            l1 = tf.tanh(tf.nn.batch_normalization(z1, mean_z1, var_z1, beta_z1, scale_z1, 1e-10))

            # Second layer (only these variables carry an L2 regulariser)
            W2 = tf.compat.v1.get_variable('W2',
                                           initializer=0.1 * tf.compat.v1.random_normal(shape=[hidden_dim_dr, 1]),
                                           regularizer=tf.keras.regularizers.l2(1.))
            b2 = tf.compat.v1.get_variable('b2',
                                           initializer=tf.zeros([1]),
                                           regularizer=tf.keras.regularizers.l2(1.))
            z2 = tf.matmul(l1, W2) + b2
            return tf.math.log(1. + tf.exp(tf.squeeze(z2)))

    def state_to_w(self, state, obs_dim, hidden_dim_dr):
        """Build a plain tanh ``w`` network (no normalisation) with softplus output.

        This variant is not called by :meth:`_compile`.  It uses the same
        variable names in scope ``'w'`` as :meth:`state_to_w_batch_norm`.
        """
        with tf.compat.v1.variable_scope('w', reuse=tf.compat.v1.AUTO_REUSE):

            # First layer
            W1 = tf.compat.v1.get_variable('W1',
                                           initializer=0.5 * tf.compat.v1.random_normal(shape=[obs_dim, hidden_dim_dr]),
                                           regularizer=tf.keras.regularizers.l2(1.))
            b1 = tf.compat.v1.get_variable('b1',
                                           initializer=tf.zeros([hidden_dim_dr]),
                                           regularizer=tf.keras.regularizers.l2(1.))
            z1 = tf.matmul(state, W1) + b1
            l1 = tf.tanh(z1)

            # Second layer
            W2 = tf.compat.v1.get_variable('W2',
                                           initializer=0.5 * tf.compat.v1.random_normal(shape=[hidden_dim_dr, 1]),
                                           regularizer=tf.keras.regularizers.l2(1.))
            b2 = tf.compat.v1.get_variable('b2',
                                           initializer=tf.zeros([1]),
                                           regularizer=tf.keras.regularizers.l2(1.))
            z2 = tf.matmul(l1, W2) + b2
            return tf.math.log(1. + tf.exp(tf.squeeze(z2)))

    def state_to_f(self, state, obs_dim, hidden_dim_dr):
        """Build the one-hidden-layer test-function network ``f`` (linear output).

        The hidden pre-activation is normalised with the moments of the batch
        being fed.  ``hidden_dim_dr`` is a single integer width.
        """
        with tf.compat.v1.variable_scope('f', reuse=tf.compat.v1.AUTO_REUSE):

            W4 = tf.compat.v1.get_variable('W4',
                                           initializer=tf.compat.v1.random_normal(shape=[obs_dim, hidden_dim_dr]),
                                           regularizer=tf.keras.regularizers.l2(1.))
            b4 = tf.compat.v1.get_variable('b4',
                                           initializer=tf.zeros([hidden_dim_dr]),
                                           regularizer=tf.keras.regularizers.l2(1.))
            z1 = tf.matmul(state, W4) + b4
            mean_z1, var_z1 = tf.nn.moments(z1, [0])
            scale_z1 = tf.compat.v1.get_variable('scale_z1', initializer=tf.ones([hidden_dim_dr]))
            beta_z1 = tf.compat.v1.get_variable('beta_z1', initializer=tf.zeros([hidden_dim_dr]))
            l1 = tf.tanh(tf.nn.batch_normalization(z1, mean_z1, var_z1, beta_z1, scale_z1, 1e-10))

            W5 = tf.compat.v1.get_variable('W5',
                                           initializer=tf.compat.v1.random_normal(shape=[hidden_dim_dr, 1]),
                                           regularizer=tf.keras.regularizers.l2(1.))
            b5 = tf.compat.v1.get_variable('b5',
                                           initializer=tf.zeros([1]),
                                           regularizer=tf.keras.regularizers.l2(1.))
            z2 = tf.matmul(l1, W5) + b5
            return tf.squeeze(z2)

    def get_density_ratio(self, sess, states):
        """Run the ``w`` network on ``states`` in ``sess`` and return the values."""
        return sess.run(self.output, feed_dict={ self.state: states })

    def train(self, sess, SASR, policy0, policy1):
        """Alternate updates of ``f`` and ``w`` on the transitions in ``SASR``.

        Args:
            sess: Session whose graph contains the ops made by :meth:`_compile`.
            SASR: Iterable of trajectories; each is indexed with the keys
                ``'states'``, ``'actions'``, ``'next-states'`` and ``'rewards'``.
            policy0: Object whose ``prob(state, action)`` is the denominator
                of the policy ratio.
            policy1: Object whose ``prob(state, action)`` is the numerator
                of the policy ratio.
        """
        S, SN, POLICY_RATIO, REW = [], [], [], []
        for tau in SASR:
            for state, action, next_state, reward in zip(
                tau['states'], tau['actions'], tau['next-states'], tau['rewards']):
                S.append(state)
                SN.append(next_state)
                POLICY_RATIO.append(policy1.prob(state, action) / policy0.prob(state, action))
                REW.append(reward)
        S = np.array(S)
        SN = np.array(SN)
        POLICY_RATIO = np.array(POLICY_RATIO)
        REW = np.array(REW)

        N = S.shape[0]
        # Guard against iters < 20, where iters // 20 would be a zero modulus.
        log_every = max(1, self.iters // 20)

        for i in range(self.iters):
            # Periodic progress report on the full dataset.
            if i % log_every == 0:
                test_loss, test_reg_loss = sess.run([self.loss, self.reg_loss], feed_dict={
                    self.state: S,
                    self.next_state: SN,
                    self.policy_ratio: POLICY_RATIO
                })
                DENR = self.get_density_ratio(sess, S)
                T = DENR * POLICY_RATIO
                est = np.sum(T * REW) / np.sum(T)
                print(f"iter {i} loss {test_loss} reg {test_reg_loss} estimate {est}")

            subsamples = np.random.choice(N, self.batch_size)
            batch = {
                self.state: S[subsamples],
                self.next_state: SN[subsamples],
                self.policy_ratio: POLICY_RATIO[subsamples],
            }

            # Five ascent steps on f for every descent step on w, all on the
            # same mini-batch.
            for _ in range(5):
                sess.run(self.train_op_f, feed_dict=batch)
            sess.run(self.train_op_w, feed_dict=batch)

    def evaluate(self, data, target, behavior, gamma, reward_estimator=None):
        """Train on ``data`` and return the weighted-reward estimate.

        The graph is built in a private ``tf.Graph`` so that repeated calls,
        or other estimators using the same variable-scope names, neither
        accumulate ops in the process-wide default graph nor reuse each
        other's variables.

        Args:
            data: Iterable of trajectories (see :meth:`train`).
            target: Policy whose ``prob`` forms the numerator of the ratio.
            behavior: Policy whose ``prob`` forms the denominator of the ratio.
            gamma: Accepted but not used by this estimator.
            reward_estimator: Accepted but not used by this estimator.

        Returns:
            ``sum(T * r) / sum(T)`` with ``T = w(s) * target.prob / behavior.prob``.
        """
        with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
            self._compile()
            sess.run(tf.compat.v1.global_variables_initializer())
            self.train(sess, data, behavior, target)

            S, REW, policy_ratio = [], [], []
            for tau in data:
                for state, action, next_state, reward in zip(
                    tau['states'], tau['actions'], tau['next-states'], tau['rewards']):
                    S.append(state)
                    REW.append(reward)
                    policy_ratio.append(target.prob(state, action) / behavior.prob(state, action))

            w = self.get_density_ratio(sess, np.array(S))
            policy_ratio = np.array(policy_ratio)
            REW = np.array(REW)
            T = w * policy_ratio
            return np.sum(T * REW) / np.sum(T)
