import os
import time
import pickle
import gym
import numpy as np
import tensorflow as tf
from tf_agents.specs.tensor_spec import TensorSpec
from tqdm import tqdm
import wandb
import util
from networks import TanhMixtureNormalPolicy, TanhNormalPolicy, ValueNetwork, DeterministicPolicy

# Print NumPy floats with three decimals and without scientific notation.
np.set_printoptions(precision=3, suppress=True)
# Repeat of the wandb import at the top of the file; harmless.
import wandb
# Restrict CUDA to device 0.
# NOTE(review): this unconditionally overwrites any CUDA_VISIBLE_DEVICES value
# exported by the caller or a job scheduler -- confirm that is intended.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
class OSDBC(tf.keras.layers.Layer):
    """Weighted behaviour cloning on top of OptiDICE-style correction ratios.

    Offline policy Optimization via Stationary DIstribution Correction
    Estimation (OptiDICE) supplies the value network ``v`` and the scalar
    ``lamb_v``; the per-sample ratio ``w_v`` derived from them is then used
    (with gradients stopped) to weight a behaviour-cloning policy loss.

    Two policy heads are supported, selected by ``config['bc_loss']``:

    * ``'likelihood'`` -- a ``TanhNormalPolicy`` trained with a weighted
      negative log-likelihood plus a learned entropy coefficient.
    * anything else -- a ``DeterministicPolicy`` trained with a weighted
      squared error.

    Args:
        observation_spec: spec forwarded to the value and policy networks.
        action_spec: spec whose ``shape[0]`` gives the action dimension.
        config: mapping with the hyper-parameters read in ``__init__``.
    """

    def __init__(self, observation_spec, action_spec, config):
        super(OSDBC, self).__init__()

        self._gamma = config['gamma']
        self._policy_extraction = config['policy_extraction']
        self._env_name = config['env_name']
        self._total_iterations = config['total_iterations']
        self._warmup_iterations = config['warmup_iterations']
        self._hidden_sizes = config['hidden_sizes']
        self._batch_size = config['batch_size']
        self._alpha = config['alpha']
        self._f = config['f']
        self._lr = config['lr']
        self._v_l2_reg = config['v_l2_reg']
        self._lamb_scale = config['lamb_scale']
        self._target_entropy = config['target_entropy']

        # Counts completed train_step calls; compared against the warm-up length.
        self._iteration = tf.Variable(0, dtype=tf.int64, name='iteration', trainable=False)
        self._optimizers = dict()

        self._bc_loss = config['bc_loss']
        # Clipping bounds applied to w_v before it weights the policy loss.
        self._lower = config['lower']
        self._higher = config['higher']
        # Weight of the squared-residual term in the v / lamb_v losses.
        self._beta = config['beta']

        # create networks / variables for OSD
        self._v_network = ValueNetwork((observation_spec,), hidden_sizes=self._hidden_sizes, output_activation_fn=None, name='v')
        self._v_network.create_variables()
        self._optimizers['v'] = tf.keras.optimizers.Adam(self._lr)
        self._lamb_v = tf.Variable(0.0, dtype=tf.float32, name='lamb_v')
        self._optimizers['lamb_v'] = tf.keras.optimizers.Adam(self._lr)

        # Piecewise helper functions. For negative inputs the exponential branch
        # is used, for non-negative inputs the polynomial one; tf.minimum(x, 0)
        # keeps exp() from overflowing in the branch tf.where discards.
        self._f_fn = lambda x: tf.where(x < 1, x * (tf.math.log(x + 1e-10) - 1) + 1, 0.5 * (x - 1) ** 2)
        self._f_prime_inv_fn = lambda x: tf.where(x < 0, tf.math.exp(tf.minimum(x, 0)), x + 1)
        self._g_fn = lambda x: tf.where(x < 0, tf.math.exp(tf.minimum(x, 0)) * (tf.minimum(x, 0) - 1) + 1, 0.5 * x ** 2)
        self._r_fn = lambda x: self._f_prime_inv_fn(x)
        self._log_r_fn = lambda x: tf.where(x < 0, x, tf.math.log(tf.maximum(x, 0) + 1))

        # Weights & Biases logging is switched off by this hard-coded flag.
        self._use_wandb = 0
        if self._use_wandb:
            wandb.init(project="osd-bc", group='osd-bc',
            name=config['env_name'], config=config)

        # policy
        # single network
        if self._bc_loss == 'likelihood':
            self._policy_network = TanhNormalPolicy((observation_spec,), action_spec.shape[0], hidden_sizes=self._hidden_sizes,
                                            mean_range=config['mean_range'], logstd_range=config['logstd_range'])
            self._policy_network.create_variables()
            self._optimizers['policy'] = tf.keras.optimizers.Adam(self._lr)

            # Only the likelihood head has an entropy coefficient.
            self._log_ent_coeff = tf.Variable(0.0, dtype=tf.float32, name='ent_coeff')
            self._optimizers['ent_coeff'] = tf.keras.optimizers.Adam(self._lr)
        else:
            self._policy_network = DeterministicPolicy((observation_spec,), action_spec.shape[0], hidden_sizes=self._hidden_sizes,
                                            mean_range=config['mean_range'])
            self._policy_network.create_variables()
            self._optimizers['policy'] = tf.keras.optimizers.Adam(self._lr)

    def v_loss(self, initial_v_values, e_v, w_v, f_w_v, result=None):
        """Compute the value-network loss and related statistics.

        Args:
            initial_v_values: value estimates at sampled initial observations.
            e_v: residual ``r + (1 - terminal) * gamma * v(s') - v(s)``.
            w_v: correction ratio obtained from ``e_v``.
            f_w_v: ``self._g_fn`` evaluated at the same pre-activation as ``w_v``.
            result: optional dict to add the entries to; a new dict is created
                when omitted.

        Returns:
            ``result`` updated with ``v_loss``, its five components
            ``v_loss0``..``v_loss4``, ``v_l2_norm`` and mean/max/min of ``w_v``.
        """
        if result is None:
            result = {}

        # Compute v loss
        v_loss0 = (1 - self._gamma) * tf.reduce_mean(initial_v_values)
        v_loss1 = tf.reduce_mean(- self._alpha * f_w_v)
        v_loss2 = tf.reduce_mean(w_v * (e_v - self._lamb_v))
        v_loss3 = self._lamb_v
        v_loss4 = tf.reduce_mean(tf.square(e_v))
        v_loss = v_loss0 + v_loss1 + v_loss2 + v_loss3 + self._beta * v_loss4

        v_l2_norm = tf.linalg.global_norm(self._v_network.variables)

        if self._v_l2_reg is not None:
            v_loss += self._v_l2_reg * v_l2_norm

        result.update({
            'v_loss0': v_loss0,
            'v_loss1': v_loss1,
            'v_loss2': v_loss2,
            'v_loss3': v_loss3,
            'v_loss4': v_loss4,
            'v_loss': v_loss,
            'v_l2_norm': v_l2_norm,
            'w_v_mean':tf.reduce_mean(w_v),
            'w_v_max':tf.reduce_max(w_v, axis=0),
            'w_v_min': tf.reduce_min(w_v, axis=0)
        })

        return result

    def lamb_v_loss(self, e_v, w_v, f_w_v, result=None):
        """Compute the loss used to update the scalar ``lamb_v``.

        Args:
            e_v: residual computed in ``train_step``.
            w_v: correction ratio obtained from ``e_v``.
            f_w_v: ``self._g_fn`` evaluated at the same pre-activation as ``w_v``.
            result: optional dict to add the entries to; a new dict is created
                when omitted.

        Returns:
            ``result`` updated with ``lamb_v_loss`` and the current ``lamb_v``.
        """
        if result is None:
            result = {}

        # GenDICE regularization: E_D[w(s,a)] = 1
        lamb_v_loss = tf.reduce_mean(- self._alpha * f_w_v + w_v * (e_v - self._lamb_scale * self._lamb_v +self._beta * tf.square(e_v)) + self._lamb_v)
        result.update({
            'lamb_v_loss': lamb_v_loss,
            'lamb_v': self._lamb_v,
        })

        return result

    def policy_loss_new(self, observation, action,w_v, result=None):
        """Weighted negative log-likelihood loss for the stochastic policy.

        The weights are ``w_v`` clipped to ``[self._lower, self._higher]`` and
        the weighted sum is normalised by the mean weight. An entropy term scaled
        by ``exp(self._log_ent_coeff)`` is added, and the loss for the entropy
        coefficient itself is reported separately.

        Args:
            observation: batch of observations fed to the policy network.
            action: dataset actions whose log-probability is maximised.
            w_v: per-sample weights (the caller passes them with gradients stopped).
            result: optional dict to add the entries to; a new dict is created
                when omitted.

        Returns:
            ``result`` updated with ``policy_loss_new``, ``ent_coeff_loss``,
            ``ent_coeff`` and ``negative_entropy_loss``.
        """
        if result is None:
            result = {}

        (_, _, sampled_action_log_prob, _, policy_dists), _ = self._policy_network((observation,))
        negative_entropy_loss = tf.reduce_mean(sampled_action_log_prob)
        w_v = tf.clip_by_value(w_v,self._lower,self._higher)
        optimal_action_log_prob, _ = self._policy_network.log_prob(policy_dists, action, is_pretanh_action=False)

        # 1e-10 guards against division by zero when every weight is clipped to 0.
        policy_loss_new = - tf.reduce_mean(optimal_action_log_prob * (w_v)) /(tf.reduce_mean(w_v)  +1e-10)

        ent_coeff = tf.exp(self._log_ent_coeff)
        policy_loss_new += ent_coeff * negative_entropy_loss

        ent_coeff_loss = - self._log_ent_coeff * (sampled_action_log_prob + self._target_entropy)

        result.update({
            'policy_loss_new': policy_loss_new,
            'ent_coeff_loss': tf.reduce_mean(ent_coeff_loss),
            'ent_coeff': ent_coeff,
            'negative_entropy_loss': negative_entropy_loss
        })
        return result

    def policy_loss_mse(self, observation, action,w_v,  result=None):
        """Weighted squared-error loss for the deterministic policy.

        Args:
            observation: batch of observations fed to the policy network.
            action: dataset actions to regress onto.
            w_v: per-sample weights (the caller passes them with gradients stopped).
            result: optional dict to add the entry to; a new dict is created
                when omitted.

        Returns:
            ``result`` updated with ``policy_loss_mse``.
        """
        if result is None:
            result = {}

        w_v = tf.clip_by_value(w_v,self._lower,self._higher)
        pred_action,_ = self._policy_network((observation,))
        # NOTE(review): the weight is multiplied in before squaring, so the
        # numerator is weighted by w_v**2 while the denominator sums w_v --
        # confirm this is intended.
        policy_loss_mse= tf.reduce_sum(tf.square(tf.multiply(pred_action-action, tf.expand_dims(w_v, axis=1)))) /(tf.reduce_sum(w_v)  +1e-10)
        result.update({
            'policy_loss_mse': policy_loss_mse
        })
        return result

    @tf.function
    def train_step(self, initial_observation, observation, action, reward, next_observation, terminal):
        """Run one optimisation step on a sampled mini-batch.

        The value network and ``lamb_v`` are updated on every call. The policy
        (and, for the likelihood head, the entropy coefficient) is updated only
        once ``self._iteration`` has reached ``self._warmup_iterations``.

        Args:
            initial_observation: batch of initial observations.
            observation: batch of observations.
            action: batch of dataset actions.
            reward: batch of rewards.
            next_observation: batch of successor observations.
            terminal: batch of terminal indicators (1 stops bootstrapping).

        Returns:
            Dict of the loss values and statistics produced by the loss methods.
        """
        use_likelihood = self._bc_loss == 'likelihood'

        with tf.GradientTape(persistent=True, watch_accessed_variables=False) as tape:
            # watch variables for gradient tracking
            tape.watch(self._policy_network.variables)
            tape.watch(self._v_network.variables)
            tape.watch(self._lamb_v)
            if use_likelihood:
                tape.watch(self._log_ent_coeff)

            # Shared network values
            initial_v_values, _ = self._v_network((initial_observation,))
            v_values, _ = self._v_network((observation,))
            next_v_values, _ = self._v_network((next_observation,))

            e_v = reward + (1 - terminal) * self._gamma * next_v_values - v_values

            preactivation_v = (e_v - self._lamb_scale * self._lamb_v) / self._alpha
            w_v = self._r_fn(preactivation_v)
            f_w_v = self._g_fn(preactivation_v)
            loss_result = self.v_loss(initial_v_values, e_v, w_v, f_w_v, result={})

            loss_result = self.lamb_v_loss(e_v, w_v, f_w_v, result=loss_result)

            # The policy must not push gradients back into v through the weights.
            if use_likelihood:
                loss_result = self.policy_loss_new(observation, action, tf.stop_gradient(w_v),  result=loss_result)
            else:
                loss_result = self.policy_loss_mse(observation, action, tf.stop_gradient(w_v),  result=loss_result)

        v_grads = tape.gradient(loss_result['v_loss'], self._v_network.variables)
        self._optimizers['v'].apply_gradients(zip(v_grads, self._v_network.variables))

        lamb_v_grads = tape.gradient(loss_result['lamb_v_loss'], [self._lamb_v])
        self._optimizers['lamb_v'].apply_gradients(zip(lamb_v_grads, [self._lamb_v]))

        # Policy-side updates are skipped during warm-up.
        past_warmup = self._iteration >= self._warmup_iterations

        if use_likelihood:
            policy_grads = tape.gradient(loss_result['policy_loss_new'], self._policy_network.variables)
            tf.cond(past_warmup,
                    lambda: self._optimizers['policy'].apply_gradients(zip(policy_grads, self._policy_network.variables)),
                    lambda: tf.no_op())

            ent_coeff_grads = tape.gradient(loss_result['ent_coeff_loss'], [self._log_ent_coeff])
            tf.cond(past_warmup,
                    lambda: self._optimizers['ent_coeff'].apply_gradients(zip(ent_coeff_grads, [self._log_ent_coeff])),
                    lambda: tf.no_op())
        else:
            policy_grads = tape.gradient(loss_result['policy_loss_mse'], self._policy_network.variables)
            tf.cond(past_warmup,
                    lambda: self._optimizers['policy'].apply_gradients(zip(policy_grads, self._policy_network.variables)),
                    lambda: tf.no_op())

        # A persistent tape holds its resources until the reference is dropped.
        del tape

        self._iteration.assign_add(1)

        return loss_result

    @tf.function
    def step(self, observation):
        """Select actions for a batch of observations.

        observation: batch_size x obs_dim

        Returns:
            ``(action, None)``; the second element is always ``None``.
        """
        observation = tf.convert_to_tensor(observation, dtype=tf.float32)

        if self._bc_loss == 'likelihood':
            action = self._policy_network.deterministic_action((observation,))
        else:
            action, _ = self._policy_network((observation,))

        return action, None
    

        


def _env_flag(name):
    """Interpret environment variable *name* as a boolean switch.

    Unset, empty, ``false``, ``no`` and ``off`` (case-insensitive) are false.
    Integer strings follow ``bool(int(value))``; any other text is true.
    """
    raw = os.environ.get(name, '').strip().lower()
    if raw in ('', 'false', 'no', 'off'):
        return False
    try:
        return bool(int(raw))
    except ValueError:
        return True


def run(config):
    """Train an OSDBC agent on an offline dataset and periodically evaluate it.

    Args:
        config: mapping of hyper-parameters and run options. ``target_entropy``
            is filled in (in place) when it is ``None`` and an entropy target
            is needed.

    Environment variables:
        DISABLE_TQDM: truthy value hides the progress bar.
        DISABLE_STDOUT: truthy value suppresses the periodic log print-out.
    """
    np.random.seed(config['seed'])
    tf.random.set_seed(config['seed'])

    # load dataset
    env = gym.make(config['env_name'])
    env.seed(config['seed'])
    initial_obs_dataset, dataset, dataset_statistics = util.dice_dataset(env, standardize_observation=True, absorbing_state=config['absorbing_state'], standardize_reward=config['standardize_reward'])

    # The likelihood policy loss always adds the target entropy, so it needs a
    # default even when neither entropy-constraint flag is set.
    needs_target_entropy = (
        config['use_policy_entropy_constraint']
        or config['use_data_policy_entropy_constraint']
        or config.get('bc_loss') == 'likelihood'
    )
    if needs_target_entropy and config['target_entropy'] is None:
        config['target_entropy'] = -np.prod(env.action_space.shape)

    print(f'observation space: {env.observation_space.shape}')
    print(f'- high: {env.observation_space.high}')
    print(f'- low: {env.observation_space.low}')
    print(f'action space: {env.action_space.shape}')
    print(f'- high: {env.action_space.high}')
    print(f'- low: {env.action_space.low}')

    def _sample_minibatch(batch_size, reward_scale):
        """Draw a uniform random mini-batch (with replacement) as tensors."""
        initial_indices = np.random.randint(0, dataset_statistics['N_initial_observations'], batch_size)
        indices = np.random.randint(0, dataset_statistics['N'], batch_size)
        sampled_dataset = (
            initial_obs_dataset['initial_observations'][initial_indices],
            dataset['observations'][indices],
            dataset['actions'][indices],
            dataset['rewards'][indices] * reward_scale,
            dataset['next_observations'][indices],
            dataset['terminals'][indices]
        )
        return tuple(map(tf.convert_to_tensor, sampled_dataset))

    # Create an agent
    agent = OSDBC(
        observation_spec=TensorSpec(
            (dataset_statistics['observation_dim'] + 1) if config['absorbing_state'] else dataset_statistics['observation_dim']),
        action_spec=TensorSpec(dataset_statistics['action_dim']),
        config=config
    )

    result_logs = []
    start_iteration = 0

    # Start training
    start_time = time.time()
    last_start_time = time.time()
    for iteration in tqdm(range(start_iteration, config['total_iterations'] + 1), ncols=70, desc='DICE', initial=start_iteration, total=config['total_iterations'] + 1, ascii=True, disable=_env_flag("DISABLE_TQDM")):
        # Sample mini-batch data from dataset
        initial_observation, observation, action, reward, next_observation, terminal = _sample_minibatch(config['batch_size'], config['reward_scale'])

        # Perform gradient descent
        train_result = agent.train_step(initial_observation, observation, action, reward, next_observation, terminal)

        if iteration % config['log_iterations'] == 0:
            train_result = {k: v.numpy() for k, v in train_result.items()}
            if iteration >= config['warmup_iterations']:
                # evaluation via real-env rollout
                eval_score = util.evaluate(env, agent, dataset_statistics, absorbing_state=config['absorbing_state'], pid=config.get('pid'))
                train_result.update({'iteration': iteration, 'eval': eval_score})

            train_result.update({'iter_per_sec': config['log_iterations'] / (time.time() - last_start_time)})

            result_logs.append({'log': train_result, 'step': iteration})
            if agent._use_wandb:
                wandb.log(train_result, step=iteration)

            if not _env_flag('DISABLE_STDOUT'):
                print('=======================================================')
                for k, v in sorted(train_result.items()):
                    print(f'- {k:23s}:{v:15.10f}')
                # Membership test so that an evaluation score of 0.0 is still shown.
                if 'eval' in train_result:
                    print(f'- {"eval":23s}:{train_result["eval"]:15.10f}')
                print(f'config={config}')
                print(f'iteration={iteration} (elapsed_time={time.time() - start_time:.2f}s, {train_result["iter_per_sec"]:.2f}it/s)')
                print('=======================================================', flush=True)

            last_start_time = time.time()

    # Only close a wandb run if the agent actually opened one.
    if agent._use_wandb:
        wandb.finish()
if __name__ == "__main__":
    # Script entry point: build the project's argument parser, parse the
    # command line and pass the options to run() as a plain dict.
    from default_config import get_parser

    parser = get_parser()
    run(vars(parser.parse_args()))
