import os
import time

import gym
import numpy as np
import time

import torch
from tqdm import tqdm
import cluster
import wandb
import os
# Turn off HDF5 file locking for this process. This is set ahead of the
# project imports below, presumably so that HDF5-backed dataset loading
# picks it up -- TODO confirm which import relies on it.
os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"


# Print numpy arrays with three decimals and without scientific notation.
np.set_printoptions(suppress=True, precision=3)

from src.smodice_pytorch import SMODICE
from src.doi_clean import DOI as DOIClean
from src.doi_clean import DOISingleStep as DOICleanSingleStep

import src.utils as utils
from src.utils import configure_wandb
from src.utils import red_print, on_cluster
import logging

# Re-apply the numpy print options after the project imports above.
np.set_printoptions(suppress=True, precision=3)

# Environment names for which run() builds the expert environment from the
# "<env_name>-expert-<version>" dataset id.
MUJOCO = ["hopper", "walker2d", "halfcheetah", "ant"]


def run(config):
    """Train a DOI agent on offline data and periodically evaluate it.

    The function builds the environments and datasets, creates the agent
    selected by ``config['algo_type']``, attaches the classifier rewards and
    the ``w_e`` / ``e_v`` arrays stored in a SMODICE checkpoint to the
    dataset and the agent, and then runs the training loop with wandb
    logging, evaluation rollouts and checkpointing.

    Args:
        config: Settings object. It is accessed both by attribute
            (``config.debug``) and by item (``config['seed']``), so it has to
            support both.

    Returns:
        The dict ``{'done': 1.0}`` once the training loop has finished.

    Raises:
        AssertionError: If debugging is enabled while running on the cluster,
            or if the arrays read from the SMODICE checkpoint differ in
            length.
        Exception: If ``config['algo_type']`` names no available agent.
    """
    ########################################
    # Setup

    # load checkpoints if existing
    # NOTE(review): the returned values are not used to resume training;
    # `start_iteration` below is always 0 and `ckpt` is rebuilt when saving.
    ckpt, pretrain_ckpt = utils.maybe_load_last_checkpoint()
    configure_wandb(config, project="smodice-doi", entity='al-group')
    if config.debug:
        red_print('##### DEBUGGING MODE #####')
    # make sure we are not debugging on cluster
    assert not (config.debug and on_cluster())

    ############### ENVIRONMENTS AND DATA ###########################

    version = 'v2'
    offline_dataset, expert_dataset = None, None
    actions_max, actions_min = None, None
    if config.env_name != 'solo12':
        if 'kitchen' in config['env_name']:
            version = 'v0'
        # Load environment; with `mismatch` the "random" dataset id is used
        # for the offline environment instead of the configured dataset.
        if not config['mismatch']:
            env = gym.make(f"{config['env_name']}-{config['dataset']}-{version}")
        else:
            env = gym.make(f"{config['env_name']}-random-{version}")

        if config['env_name'] not in MUJOCO:
            if config['env_name'] != 'kitchen':
                expert_env = gym.make(f"{config['env_name']}-{config['dataset']}-{version}")
            else:
                expert_env = gym.make("kitchen-complete-v0")
        else:
            print('Making env expert...', f"{config['env_name']}-expert-{version}")
            expert_env = gym.make(f"{config['env_name']}-expert-{version}")
        # Seeding
        np.random.seed(config['seed'])
        torch.manual_seed(config['seed'])
        env.seed(config['seed'])
        expert_env.seed(config['seed'])
    else:
        # solo12: environment and both datasets come from a project helper,
        # and the same environment object serves as offline and expert env.
        expert_env, expert_dataset, offline_dataset, actions_max, actions_min = utils.get_solo_env_and_data(config)
        env = expert_env

    print(f'Loading expert dataset...')
    traj_iterator = utils.sequence_dataset(expert_env, dataset=expert_dataset)
    # Only the first trajectory yielded by the iterator is used as expert data.
    expert_traj = next(traj_iterator)

    # Load offline dataset
    if config['num_expert_traj'] == 0:
        initial_obs_dataset, dataset, dataset_statistics = utils.dice_dataset(
            env,
            standardize_observation=config['standardize_obs'],
            absorbing_state=config['absorbing_state'],
            standardize_reward=config['standardize_reward'],
            dataset=offline_dataset,
        )
    else:
        initial_obs_dataset, dataset, dataset_statistics = utils.dice_combined_dataset(
            expert_env,
            env,
            num_expert_traj=config['num_expert_traj'],
            num_offline_traj=config['num_offline_traj'],
            standardize_observation=config['standardize_obs'],
            absorbing_state=config['absorbing_state'],
            standardize_reward=config['standardize_reward'],
            offline_dataset=offline_dataset,
            expert_dataset=expert_dataset,
        )

    # Normalize expert observations and potentially add absorbing state
    if config['standardize_obs']:
        expert_obs_dim = expert_traj['observations'].shape[1]
        # The statistics are cut to the expert observation width. The same
        # slice is applied to next_observations so both arrays are normalised
        # with identical statistics.
        obs_mean = dataset_statistics['observation_mean'][:expert_obs_dim]
        obs_std = dataset_statistics['observation_std'][:expert_obs_dim] + 1e-10
        expert_traj['observations'] = (expert_traj['observations'] - obs_mean) / obs_std
        if 'next_observations' in expert_traj:
            expert_traj['next_observations'] = (expert_traj['next_observations'] - obs_mean) / obs_std
    if config['absorbing_state']:
        expert_traj = utils.add_absorbing_state(expert_traj)
    if config['use_policy_entropy_constraint'] or config['use_data_policy_entropy_constraint']:
        if config['target_entropy'] is None:
            # Default target: negative number of action dimensions.
            config['target_entropy'] = -np.prod(env.action_space.shape)

    # One extra observation dimension is reserved when absorbing states are used.
    state_dim = dataset_statistics['observation_dim'] + 1 if config['absorbing_state'] else dataset_statistics['observation_dim']

    # Policy, value and classifier see disc_idxs, discriminator sees div_idxs
    disc_idxs = list(range(state_dim)) if config['disc_idxs'] == 'all' else config['disc_idxs']
    div_idxs = list(range(state_dim)) if config['div_idxs'] == 'all' else config['div_idxs']

    ########################################################
    ################## CREATE AGENT ########################
    ########################################################

    print(f'Creating {config["algo_type"]} agent!')
    # Both agent variants take identical constructor arguments.
    agent_classes = {
        'doi-clean': DOIClean,
        'doi-clean-single': DOICleanSingleStep,
    }
    if config['algo_type'] not in agent_classes:
        raise Exception('Agent not available!')
    agent = agent_classes[config['algo_type']](
        observation_spec=state_dim,
        disc_idxs=disc_idxs, div_idxs=div_idxs,  # what to use for policy and discriminator
        action_spec=dataset_statistics['action_dim'], behavior_policy=None,
        config=config
    )

    #########################################################
    ##################### PRETRAINING #######################
    #########################################################

    # Quantities produced by a previous SMODICE run; the checkpoint location
    # is fixed and only the file name depends on the configuration.
    smodice_ckpt = torch.load(f'/is/rg/al/Data/doi/{config.env_name}_{config.dataset}_{config.num_expert_traj}_smodice_last.pth.tar')

    classifier_rewards = smodice_ckpt['classifier_rewards'].reshape(-1, 1)
    expert_w_e = smodice_ckpt['w_e'].reshape(-1, 1)
    expert_e_v = smodice_ckpt['e_v'].reshape(-1, 1)

    assert len(classifier_rewards) == len(expert_w_e) == len(expert_e_v)
    dataset['classifier_rewards'] = classifier_rewards
    dataset['expert_w_e'] = expert_w_e.cpu()
    dataset['expert_e_v'] = expert_e_v.cpu()

    random_loader = utils.SmodiceLoader(config.batch_size,
                                        dataset=dataset,
                                        initial_obs_dataset=initial_obs_dataset,
                                        dataset_statistics=dataset_statistics,
                                        shuffle=True)
    agent.expert_w_e = expert_w_e.T  # for internal broadcasting
    agent.expert_e_v = expert_e_v.T

    batch_iterator = iter(random_loader)
    start_iteration = 0
    last_start_time = time.time()
    start_time = time.time()

    # Kitchen rollouts are capped at 280 steps; other environments pass None.
    # The value does not change between iterations, so it is computed once.
    max_steps = 280 if 'kitchen' in config['env_name'] else None

    # NOTE(review): os.environ.get returns a string when DISABLE_TQDM is set,
    # so any non-empty value (including "0") is truthy here -- confirm this
    # is intended.
    for iteration in tqdm(range(start_iteration, config['total_iterations'] + 1), ncols=70, desc='DOI', initial=start_iteration, total=config['total_iterations'] + 1, ascii=True, disable=os.environ.get("DISABLE_TQDM", False)):

        # Restart the loader once it is exhausted so batches never run out.
        try:
            batch = next(batch_iterator)
        except StopIteration:
            batch_iterator = iter(random_loader)
            batch = next(batch_iterator)

        if config.use_gt:
            # Replace the classifier rewards by the clamped log-odds of the
            # batch's `experts` entries.
            batch['classifier_rewards'] = torch.clamp(torch.log(batch['experts']) - torch.log(1. - batch['experts']), -2, 2)

        # Perform gradient descent
        # data loader is passed because we have per-stage training
        train_result = agent.train_step(batch, data_loader=random_loader, iteration=iteration)

        # log with wandb
        if iteration % config.wandb_log_interval == 0:
            log_dict = {f"train/{k}": v for k, v in train_result.items()}
            log_dict['train_iter'] = iteration
            wandb.log(log_dict)
            if config.use_al_logger:
                for k, v in train_result.items():
                    # skip non-scalar values; plain Python numbers carry no
                    # `ndim` attribute and count as scalars
                    if getattr(v, 'ndim', 0) > 0:
                        continue
                    if isinstance(v, torch.Tensor):
                        v = v.cpu().detach().numpy()
                    if (hasattr(v, 'ndim') and v.ndim == 0 and np.isfinite(v)) or np.isscalar(v):
                        print('logging', k, v)
                        k = k.replace('/', '-')
                        # NOTE(review): `logger` is not defined anywhere in
                        # this file, so this line raises NameError when
                        # reached; it has to be bound before
                        # `use_al_logger` can be enabled.
                        logger.log(v, k, to_hdf=True)

        # Logging
        if iteration % config['log_iterations'] == 0 or iteration == config['total_iterations'] - 1:
            # Only tensor-valued entries are kept, converted to numpy.
            train_result = {k: v.cpu().detach().numpy() for k, v in train_result.items() if isinstance(v, torch.Tensor)}
            # evaluation via real-env rollout
            eval_mean, eval_std, info, _obs_per_skill, _obs_per_skill_agent = utils.evaluate_skills(
                env, agent, dataset_statistics,
                proj=False,
                solo=config.env_name == 'solo12',
                absorbing_state=config['absorbing_state'],
                iteration=iteration, max_steps=max_steps, make_gif=config.log_video,
                num_evaluation=config['num_evaluation'],
                action_max=actions_max, action_min=actions_min,
                disc_idxs=disc_idxs,
                div_idxs=div_idxs)

            # Prefer the projected distance when the evaluation reports it.
            # The membership test uses the same key that is read afterwards.
            proj_key = 'successor_eval/proj_succ_min_dist'
            succ_dist = info[proj_key] if proj_key in info else info['successor_eval/succ_min_dist']
            train_result.update({'iteration': iteration, 'eval/mu': eval_mean, 'eval/std': eval_std, 'eval/min_succ_dist': succ_dist})

            # NOTE(review): `allogger` is not imported in this file, so
            # building the 'logger' entry raises NameError as written.
            ckpt = {
                'iteration': iteration,
                'state_dict': agent.get_state_dict(),
                'logger': dict(allogger.get_logger("root").step_per_key),
                'dataset_statistics': dataset_statistics,
                'disc_idxs': disc_idxs,
                'div_idxs': div_idxs
            }

            utils.save_checkpoint(ckpt)

            if Ticker.should_restart():
                wandb.alert(title='Restart', text='Restarting the job at iteration {}'.format(iteration))
                wandb.finish(0)
                cluster.exit_for_resume()

            # Split the batch rows by their `experts` flag (1 vs. 0).
            expert_index = (batch['experts'] == 1).nonzero(as_tuple=False)
            offline_index = (batch['experts'] == 0).nonzero(as_tuple=False)
            if 'w_e' in train_result:
                w_e = train_result['w_e']
                w_e_expert = w_e[expert_index.cpu()].mean()
                w_e_offline = w_e[offline_index.cpu()].mean()
                w_e_ratio = w_e_expert / w_e_offline
                w_e_overall = w_e.mean()

                train_result.update({'w_e': w_e_overall, 'w_e_expert': w_e_expert, 'w_e_offline': w_e_offline, 'w_e_ratio': w_e_ratio})
            train_result.update({'iter_per_sec': config['log_iterations'] / (time.time() - last_start_time)})
            wandb.log(train_result)

            if not int(os.environ.get('DISABLE_STDOUT', 0)):
                print(f'=======================================================')

                # The evaluation mean is stored under 'eval/mu'. The format
                # spec assumes a scalar value -- TODO confirm.
                eval_mu = train_result.get('eval/mu')
                if eval_mu is not None:
                    print(f'- {"eval":23s}:{eval_mu:15.10f}')
                print(f'iteration={iteration} (elapsed_time={time.time() - start_time:.2f}s, {train_result["iter_per_sec"]:.2f}it/s)')
                print(f'=======================================================', flush=True)

            last_start_time = time.time()
            # The flush gets its own timer so that `start_time` keeps marking
            # the start of training for the elapsed-time report above.
            flush_start = time.time()
            # NOTE(review): relies on the unimported `allogger` as well.
            allogger.get_root().flush(children=True)
            print(f'logger flushed in {time.time() - flush_start:.2f}s')

    # Report whatever the last iteration left in `train_result`.
    final_log = {f'final/{k}': v for k, v in train_result.items()}
    wandb.log(final_log)

    print('Training finished')
    return {'done':  1.0}

import smart_settings
from smodice.utils import Ticker
import sys
if __name__ == "__main__":

    # The settings file is the first command-line argument; sys.argv[0] is the
    # path of this script itself and cannot be parsed as settings.
    if len(sys.argv) < 2:
        sys.exit(f'usage: {sys.argv[0]} <settings-file>')
    config = smart_settings.load(sys.argv[1])

    Ticker(config)  # has global state
    Ticker.device = 'cuda'

    metrics = run(config)
    cluster.save_metrics_params(metrics, config)
    wandb.finish(0)