from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import random
import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
import tensorflow as tf
from tqdm import tqdm
import wrappers
import oil_igi
import utils
import time
import pickle
from numpy.linalg import inv

# Expose only GPU 0 to CUDA for this process.
# NOTE(review): this assignment runs after `import tensorflow`; it presumably
# still takes effect because TensorFlow initializes GPU devices lazily on first
# use — consider moving it next to the other environment variables that are
# set before the import.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

def evaluate_d4rl(env, actor, train_env_id, num_episodes):
    """Roll out ``actor`` in ``env`` and report per-episode averages.

    Args:
        env: Gym-style environment with ``reset()`` and a ``step(action)``
            returning ``(next_state, reward, done, info)``.
        actor: Policy object; ``actor.step(state)[0].numpy()`` must yield the
            action to execute.
        train_env_id: Environment id used for training. If it contains
            ``'ant'`` (case-insensitive), every observation is cut to its first
            27 entries and a trailing ``0.`` is appended before the actor sees it.
        num_episodes: Number of episodes to run (must be non-zero).

    Returns:
        ``(mean_score, mean_timesteps)``: the undiscounted reward and the step
        count, each summed over all episodes and divided by ``num_episodes``.
    """
    truncate_ant_obs = 'ant' in train_env_id.lower()
    return_sum = 0
    step_count = 0

    for _ in range(num_episodes):
        obs = env.reset()
        finished = False
        while not finished:
            if truncate_ant_obs:
                obs = np.concatenate((obs[:27], [0.]), -1)
            act = actor.step(obs)[0].numpy()
            obs, rew, finished, _ = env.step(act)
            return_sum += rew
            step_count += 1

    return return_sum / num_episodes, step_count / num_episodes

def _load_suboptimal_data(dataset_dir, env_id, dataset_names, dataset_nums,
                          expert_dataset_name, expert_num_traj):
    """Load every suboptimal dataset and concatenate them into float32 arrays.

    If a suboptimal dataset has the same name as the expert dataset, loading
    starts at trajectory ``expert_num_traj`` so that the trajectories already
    loaded as expert data are skipped.

    Returns:
        ``(states, actions, next_states, dones, lens)``: the per-dataset arrays
        returned by ``utils.load_d4rl_data``, concatenated along axis 0 and
        cast to float32.
    """
    states, actions, next_states, dones, lens = [], [], [], [], []
    for name, num_traj in zip(dataset_names, dataset_nums):
        start_idx = expert_num_traj if (expert_dataset_name == name) else 0
        (ds_states, ds_actions, ds_next_states, ds_dones, ds_lens) = utils.load_d4rl_data(
            dataset_dir, env_id, name, num_traj, start_idx=start_idx)
        states.append(ds_states)
        actions.append(ds_actions)
        next_states.append(ds_next_states)
        dones.append(ds_dones)
        lens.append(ds_lens)

    return (np.concatenate(states).astype(np.float32),
            np.concatenate(actions).astype(np.float32),
            np.concatenate(next_states).astype(np.float32),
            np.concatenate(dones).astype(np.float32),
            np.concatenate(lens).astype(np.float32))


def _timestep_distributions(total_lens, gamma):
    """Compute the per-timestep statistics used for IGI sampling.

    Args:
        total_lens: Trajectory lengths of the combined dataset.
        gamma: Discount factor.

    Returns:
        ``(max_timestep, num_of_t, p_t, geom)``. ``num_of_t[t]`` (float32) is
        the number of trajectories whose length exceeds ``t``. ``p_t`` is
        ``num_of_t`` normalized to sum to one. ``geom`` (float32) is
        ``(1 - gamma) * gamma**t`` for ``t = 0..max_timestep``, normalized to
        sum to one.
    """
    max_timestep = np.max(total_lens) - 1

    num_of_t = np.zeros(max_timestep + 1, np.float32)
    for length in total_lens:
        # Each trajectory contributes one sample at every timestep it covers.
        num_of_t[:length] += 1
    p_t = num_of_t / np.sum(num_of_t)

    # Renormalized because the geometric series is truncated at max_timestep.
    geom = (1 - gamma) * gamma ** np.arange(max_timestep + 1)
    geom = (geom / np.sum(geom)).astype(np.float32)
    return max_timestep, num_of_t, p_t, geom


def _trajectory_start_indices(total_lens):
    """Return the flat row index at which each trajectory starts.

    Entry ``i`` equals ``sum(total_lens[:i])`` (exclusive prefix sum).
    """
    return np.concatenate(([0], np.cumsum(total_lens)[:-1]))


def run(config):
    """Train an IGI-weighted DICE imitator on D4RL data, then evaluate it.

    The function runs these steps in order:

    1. Seed the Python, NumPy and TensorFlow RNGs.
    2. Load the expert and suboptimal trajectories.
    3. Normalize the states with the suboptimal-data mean and std.
    4. Build the train and eval environments with the same normalization.
    5. Optionally add absorbing states.
    6. Build ``oil_igi.DICE``, resuming from a pickle checkpoint if requested.
    7. Run the update loop, evaluating every ``log_interval`` iterations and
       checkpointing every ``save_interval`` iterations.
    8. Run a final 100-episode evaluation and save the checkpoint.

    Args:
        config: Dict of run options. This function reads ``seed``, ``env_id``,
            ``expert_dataset_info``, ``suboptimal_dataset_name``,
            ``suboptimal_dataset_num``, ``dataset_dir``, ``using_absorbing``,
            ``save_interval``, ``save_folder_name``, ``gamma``, ``resume``,
            ``total_iterations``, ``batch_size`` and ``log_interval``. The
            whole dict is also passed to ``oil_igi.DICE``. Note:
            ``config['total_iterations']`` is incremented by one in place.

    Raises:
        ValueError: If ``config['suboptimal_dataset_name']`` is empty. State
            normalization statistics are computed from suboptimal data, so at
            least one suboptimal dataset is required.
    """
    seed = config['seed']
    tf.random.set_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    env_id = config['env_id']
    is_ant = 'ant' in env_id.lower()

    expert_dataset_name, expert_num_traj = config['expert_dataset_info']
    expert_num_traj = int(expert_num_traj)

    suboptimal_dataset_names = config['suboptimal_dataset_name']
    suboptimal_num_trajs = config['suboptimal_dataset_num']
    if len(suboptimal_dataset_names) == 0:
        raise ValueError(
            'config["suboptimal_dataset_name"] must list at least one dataset: '
            'state normalization statistics are computed from suboptimal data')

    # One line per configured suboptimal dataset, then a blank line.
    for name, num_traj in zip(suboptimal_dataset_names, suboptimal_num_trajs):
        print(f'number of {name} trajectories for suboptimal dataset: {num_traj}')
    print()

    dataset_dir = config['dataset_dir']

    (expert_states, expert_actions, expert_next_states, expert_dones, expert_lens) = utils.load_d4rl_data(
        dataset_dir, env_id, expert_dataset_name, expert_num_traj, start_idx=0)

    (suboptimal_states, suboptimal_actions, suboptimal_next_states,
     suboptimal_dones, suboptimal_lens) = _load_suboptimal_data(
        dataset_dir, env_id, suboptimal_dataset_names, suboptimal_num_trajs,
        expert_dataset_name, expert_num_traj)

    # Combined dataset: suboptimal rows first, expert rows last. total_lens
    # follows the same order, so the trajectory offsets computed from it
    # below line up with the rows.
    total_states = np.concatenate([suboptimal_states, expert_states]).astype(np.float32)
    total_actions = np.concatenate([suboptimal_actions, expert_actions]).astype(np.float32)
    total_next_states = np.concatenate([suboptimal_next_states, expert_next_states]).astype(np.float32)
    total_dones = np.concatenate([suboptimal_dones, expert_dones]).astype(np.float32)
    total_lens = np.concatenate([suboptimal_lens, expert_lens]).astype(np.int64)

    print('number of expert demonstraions: {}'.format(expert_states.shape[0]))
    print('number of suboptimal demonstraions: {}\n'.format(suboptimal_states.shape[0]))

    # Normalize all states with statistics of the suboptimal states only. The
    # 1e-3 keeps the scale finite for dimensions with zero variance.
    shift = -np.mean(suboptimal_states, 0)
    scale = 1.0 / (np.std(suboptimal_states, 0) + 1e-3)

    expert_states = (expert_states + shift) * scale
    expert_next_states = (expert_next_states + shift) * scale
    total_states = (total_states + shift) * scale
    total_next_states = (total_next_states + shift) * scale

    if is_ant:
        # NOTE(review): this pads the normalization with 84 identity entries
        # (zero shift, unit scale). Presumably they cover env observation
        # dimensions absent from the dataset states. Confirm against
        # wrappers.create_il_env.
        shift_env = np.concatenate((shift, np.zeros(84)))
        scale_env = np.concatenate((scale, np.ones(84)))
    else:
        shift_env = shift
        scale_env = scale

    env = wrappers.create_il_env(env_id, seed, shift_env, scale_env, normalized_box_actions=False)
    eval_env = wrappers.create_il_env(env_id, seed + 1, shift_env, scale_env, normalized_box_actions=False)

    if config['using_absorbing']:
        (expert_states, expert_actions, expert_next_states,
         expert_dones, _) = utils.add_absorbing_states(expert_states, expert_actions, expert_next_states, expert_dones, expert_lens, env)
        (total_states, total_actions, total_next_states,
         total_dones, total_lens) = utils.add_absorbing_states(total_states, total_actions, total_next_states, total_dones, total_lens, env)
    else:
        # Append an all-zero column. Presumably this matches the indicator
        # column that add_absorbing_states adds, so state widths agree in
        # both branches.
        expert_states = np.c_[expert_states, np.zeros(len(expert_states), dtype=np.float32)]
        expert_next_states = np.c_[expert_next_states, np.zeros(len(expert_next_states), dtype=np.float32)]
        total_states = np.c_[total_states, np.zeros(len(total_states), dtype=np.float32)]
        total_next_states = np.c_[total_next_states, np.zeros(len(total_next_states), dtype=np.float32)]

    # Ant uses a fixed 28-dim observation, matching the truncation plus
    # trailing zero applied in evaluate_d4rl.
    observation_dim = 28 if is_ant else env.observation_space.shape[0]

    # Create imitator
    action_dim = env.action_space.shape[0]
    imitator = oil_igi.DICE(
        observation_dim,
        action_dim,
        config=config,
    )

    print("Save interval :", config['save_interval'])

    # checkpoint dir
    checkpoint_dir = f"checkpoint/{config['env_id']}/" \
                     f"{config['expert_dataset_info'][0]}_{config['expert_dataset_info'][1]}_" \
                     f"{config['suboptimal_dataset_name']}_{config['suboptimal_dataset_num']}/{config['save_folder_name']}"
    os.makedirs(checkpoint_dir, exist_ok=True)

    checkpoint_filepath = f"{checkpoint_dir}/{config['gamma']}_{config['seed']}.pickle"

    if config['resume'] and os.path.exists(checkpoint_filepath):
        checkpoint_data = imitator.load(checkpoint_filepath)
        training_info = checkpoint_data['training_info']
        # Checkpoints are written after an iteration's update and before the
        # counter is incremented, so resume from the next iteration.
        training_info['iteration'] += 1
        print(f"Checkpoint '{checkpoint_filepath}' is resumed")
    else:
        print(f"No checkpoint is found: {checkpoint_filepath}")
        training_info = {
            'iteration': 0,
            'logs': [],
        }

    # The loop runs iterations 0..total_iterations inclusive. The +1 is
    # presumably there so that the last configured iteration is also logged
    # and saved.
    config['total_iterations'] = config['total_iterations'] + 1

    # IGI: timestep distributions over the combined dataset.
    max_timestep, num_of_t, P_T, geom = _timestep_distributions(total_lens, config['gamma'])
    initial_prob = utils.IGI(num_of_t, config['gamma'], max_timestep, geom)

    # Flat row index of each trajectory's first transition.
    sums = _trajectory_start_indices(total_lens)

    # Start training
    start_time = time.time()
    start_iteration = training_info['iteration']
    # NOTE(review): DISABLE_TQDM is read as a raw string, so any non-empty
    # value (including "0") disables the progress bar.
    with tqdm(total=config['total_iterations'], initial=training_info['iteration'], desc='',
              disable=os.environ.get("DISABLE_TQDM", False), ncols=70) as pbar:
        while training_info['iteration'] < config['total_iterations']:
            # Draw one timestep per batch element: initial-state timesteps
            # from the IGI distribution, transition timesteps from P_T. The
            # log of a zero probability is -inf, which gives that category
            # zero sampling probability.
            init_timestep = tf.random.categorical(tf.math.log([initial_prob]), config['batch_size'])[0]
            total_timestep = tf.random.categorical(tf.math.log([P_T]), config['batch_size'])[0]

            # Presumably maps each sampled timestep to a row index, using the
            # trajectory offsets in `sums`.
            init_indices = utils.sampling_indices_list(total_lens, init_timestep, sums)
            total_indices = utils.sampling_indices_list(total_lens, total_timestep, sums)

            # Expert rows are sampled uniformly rather than by timestep. The
            # stated rationale is that the expert set is a single trajectory.
            # NOTE(review): expert_num_traj comes from config; confirm this holds.
            expert_indices = np.random.randint(0, len(expert_states), size=config['batch_size'])

            info_dict = imitator.update(
                total_states[init_indices],
                expert_states[expert_indices],
                expert_actions[expert_indices],
                expert_next_states[expert_indices],
                total_states[total_indices],
                total_actions[total_indices],
                total_next_states[total_indices],
            )

            if training_info['iteration'] % config['log_interval'] == 0:

                average_returns, evaluation_timesteps = evaluate_d4rl(eval_env, imitator, env_id, 10)

                info_dict.update({'eval': average_returns})
                elapsed_time = time.time() - start_time
                # Count only iterations run in this process, so the rate is not
                # inflated after resuming from a checkpoint.
                iterations_per_sec = (training_info['iteration'] - start_iteration) / elapsed_time
                print(f'Eval: ave returns={average_returns}'
                      f' ave episode length={evaluation_timesteps}'
                      f' / elapsed_time={elapsed_time} ({iterations_per_sec} it/sec)')
                print('=========================')
                for key, val in info_dict.items():
                    print(f'{key:25}: {val:8.3f}')
                print('=========================')

                training_info['logs'].append({'step': training_info['iteration'], 'log': info_dict})
                print(f'timestep {training_info["iteration"]} - log update...')
                print('Done!', flush=True)

            # Save checkpoint
            if training_info['iteration'] % config['save_interval'] == 0 and training_info['iteration'] > 0:
                imitator.save(checkpoint_filepath, training_info)
            training_info['iteration'] += 1
            pbar.update(1)

    print('last procedure, evaluating for 100 episodes...')
    last_evaluation, _ = evaluate_d4rl(eval_env, imitator, env_id, 100)
    training_info['last_evaluation'] = last_evaluation
    imitator.save(checkpoint_filepath, training_info)
    print(f'Done! Your final evaluation is {last_evaluation}')
    

if __name__ == "__main__":
    # Project-local argument parser. It is imported inside the guard, so
    # importing this module does not require the `configuration` module.
    from configuration import get_parser

    print('=================latest===================')
    config = vars(get_parser().parse_args())

    print("Start DICE running with using IGI\n")
    print(f"discount factor: {config['gamma']}")
    run(config)
