import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
from jax import config

import sys
sys.path.append('..')
sys.path.append('./')

import numpy as np
import sys
from absl import app, flags
from parse import args
import jax
from collections import deque
import random
import wandb
import gym
from pprint import pprint
import d4rl

from sources.utils import ConfigArgs, evaluate, log_wandb
from sources.algos.UNIQ.algo import UNIQ
from sources.dataset.mix_dataset import CombinedDataset
from sources.dataset.d4rl_dataset import get_d4rl_dataset
from sources.utils.env_wrappers import EpisodeMonitor, SinglePrecision


def make_mixed_dataset(args):
    """Creates a combined dataset from multiple sources with good/bad labels.

    Args:
        args: Configuration holding parallel per-source lists:
            ``mixed_name_list`` (dataset names passed to ``get_d4rl_dataset``),
            ``mixed_size_list`` (sizes, converted with ``int``), and
            ``is_good_list`` / ``is_bad_list`` (the label given to each source).
    Returns:
        mixed_dataset: ``CombinedDataset`` over all sources.
        mix_dataset_name: identifier made of one ``<task>-<size // 1000>k,``
            chunk per source, where ``<task>`` is the second ``-``-separated
            field of the dataset name (note the trailing comma).
    Raises:
        ValueError: if ``mixed_size_list`` and ``mixed_name_list`` differ in
            length, or if either label list is shorter than ``mixed_name_list``.
    """
    names = args.mixed_name_list
    sizes = args.mixed_size_list
    n_sources = len(names)
    # zip() would silently drop sources on a length mismatch, and a short label
    # list would otherwise fail later with a bare IndexError, so check up front.
    if len(sizes) != n_sources:
        raise ValueError(
            f'mixed_name_list has {n_sources} entries but mixed_size_list has {len(sizes)}')
    if len(args.is_good_list) < n_sources or len(args.is_bad_list) < n_sources:
        raise ValueError(
            f'is_good_list ({len(args.is_good_list)}) and is_bad_list ({len(args.is_bad_list)}) '
            f'must provide a label for each of the {n_sources} mixed datasets')

    dataset_ls = []
    is_good_ls = []
    is_bad_ls = []
    mix_dataset_name = ''
    for mixed_name, mixed_size, is_good, is_bad in zip(
            names, sizes, args.is_good_list, args.is_bad_list):
        dataset_ls.append(get_d4rl_dataset(mixed_name, int(mixed_size)))
        is_good_ls.append(is_good)
        is_bad_ls.append(is_bad)
        print(f'Loaded {mixed_name} with size {mixed_size}, shape {dataset_ls[-1].observations.shape}, ',
              f'is_good: {is_good_ls[-1]}, is_bad: {is_bad_ls[-1]}')
        task_name = mixed_name.split('-')[1]
        mix_dataset_name += f'{task_name}-{int(mixed_size)//1000}k,'

    mixed_dataset = CombinedDataset(dataset_ls, is_good_ls, is_bad_ls)

    # Summary of how many transitions carry each label.
    print('-'*100)
    print('mixed_dataset')
    print('Total samples: ', mixed_dataset.observations.shape)
    print('Good samples: ', mixed_dataset.observations[mixed_dataset.is_good==1].shape)
    print('Bad samples: ', mixed_dataset.observations[mixed_dataset.is_bad==1].shape)
    print('-'*100)
    return mixed_dataset, mix_dataset_name

def create_expert_dataset_and_env(args):
    """Build the seeded evaluation environment and load the expert dataset.

    The gym env named ``args.env_name`` is wrapped first with
    ``EpisodeMonitor`` and then with ``SinglePrecision``. The env and both of
    its spaces are seeded with ``args.seed``.

    Args:
        args: Configuration object; reads ``env_name`` and ``seed``.
    Returns:
        env: The wrapped, seeded gym environment.
        expert_dataset: Result of ``get_d4rl_dataset(args.env_name, 1000)``.
    """
    seed = args.seed
    wrapped = SinglePrecision(EpisodeMonitor(gym.make(args.env_name)))

    # Seed the environment itself, then its action and observation spaces.
    wrapped.seed(seed)
    for space in (wrapped.action_space, wrapped.observation_space):
        space.seed(seed)

    return wrapped, get_d4rl_dataset(args.env_name, 1000)

def make_bad_dataset(args, start_idx=900000):
    """Creates a dataset of suboptimal/bad demonstrations.

    Every source is labelled ``is_good=0`` and ``is_bad=1``.

    Args:
        args: Configuration holding parallel lists ``bad_name_list`` (dataset
            names passed to ``get_d4rl_dataset``) and ``bad_size_list`` (sizes,
            converted with ``int``).
        start_idx: Forwarded as ``start_idx`` to ``get_d4rl_dataset`` for every
            source. The default 900000 is the offset that was previously
            hard-coded. NOTE(review): presumably the index of the first
            transition loaded; confirm against ``get_d4rl_dataset``.
    Returns:
        bad_dataset: ``CombinedDataset`` of all bad sources.
        bad_dataset_name: identifier made of one ``<task>-<size // 1000>k,``
            chunk per source, where ``<task>`` is the second ``-``-separated
            field of the dataset name (note the trailing comma).
    Raises:
        ValueError: if ``bad_name_list`` and ``bad_size_list`` differ in length.
    """
    names = args.bad_name_list
    sizes = args.bad_size_list
    # zip() would silently drop sources on a length mismatch.
    if len(names) != len(sizes):
        raise ValueError(
            f'bad_name_list has {len(names)} entries but bad_size_list has {len(sizes)}')

    dataset_ls = []
    is_good_ls = []
    is_bad_ls = []
    bad_dataset_name = ''

    for bad_name, bad_size in zip(names, sizes):
        dataset_ls.append(get_d4rl_dataset(bad_name, int(bad_size), start_idx=start_idx))
        is_good_ls.append(0)
        is_bad_ls.append(1)
        print(f'Loaded {bad_name} with size {bad_size}, shape {dataset_ls[-1].observations.shape}, ',
              f'is_good: {is_good_ls[-1]}, is_bad: {is_bad_ls[-1]}')
        task_name = bad_name.split('-')[1]
        bad_dataset_name += f'{task_name}-{int(bad_size)//1000}k,'

    bad_dataset = CombinedDataset(dataset_ls, is_good_ls, is_bad_ls)

    # Summary of how many transitions carry each label.
    print('-'*100)
    print('bad_dataset')
    print('Total samples: ', bad_dataset.observations.shape)
    print('Good samples: ', bad_dataset.observations[bad_dataset.is_good==1].shape)
    print('Bad samples: ', bad_dataset.observations[bad_dataset.is_bad==1].shape)
    print('-'*100)
    return bad_dataset, bad_dataset_name

def _build_run_name(args):
    """Derive the run name and expand short dataset names into full ids.

    The robot name is the part of ``args.env_name`` before the first ``-``.
    Each entry of ``args.bad_name_list`` and ``args.mixed_name_list`` adds a
    ``|<first char>=<size // 1000>k`` tag to the run name. The entry itself is
    then rewritten IN PLACE to ``<robot>-<name>-v2``, so later dataset loaders
    see the full names.

    Returns:
        (robot_name, run_name), with run_name shaped like
        ``[Bad|x=Nk...]_[Mix|y=Mk...]``.
    """
    robot_name = args.env_name.split('-')[0]

    def expand(names, sizes):
        # Mutates `names` (the list held by args) element by element.
        tags = ''
        for i, short_name in enumerate(names):
            tags += f'|{short_name[0]}={int(sizes[i])//1000}k'
            names[i] = f'{robot_name}-{short_name}-v2'
        return tags

    bad_tags = expand(args.bad_name_list, args.bad_size_list)
    mix_tags = expand(args.mixed_name_list, args.mixed_size_list)
    return robot_name, f'[Bad{bad_tags}]_[Mix{mix_tags}]'


def _state_normalization(dataset, enabled):
    """Return the ``(shift, scale)`` pair that is passed to ``dataset.sample``.

    When enabled, shift is the negated mean and scale is ``1 / (std + 1e-3)``
    of ``dataset.observations``, both taken over axis 0. Otherwise the
    identity pair ``(0, 1)`` is returned.
    """
    if enabled:
        print('Normalizing states')
        shift = - dataset.observations.mean(axis=0)
        scale = 1 / (dataset.observations.std(axis=0) + 1e-3)
        return shift, scale
    print('Not normalizing states')
    return 0, 1


def _round_info(info, ndigits=3):
    """Round scalar ``jax.Array`` / ``float`` values of ``info`` in place; return ``info``."""
    for key, value in info.items():
        if isinstance(value, jax.Array):
            info[key] = round(value.item(), ndigits)
        elif isinstance(value, float):
            info[key] = round(value, ndigits)
    return info


def _evaluate_discriminator_ratio(agent, bad_dataset, mixed_dataset, shift, scale):
    """Print discriminator ratio diagnostics and return the ratio multiplier.

    Samples 10000 transitions from each dataset and computes, per transition,
    ``agent.bad_disc(s, a) / agent.mix_disc(s, a)``. The multiplier is the
    mean ratio on the bad batch divided by the mean ratio over the mixed-batch
    samples whose ``is_bad`` label is set. The labelled means divide by the
    count of labelled samples, so they are undefined when a batch has none.

    Args:
        agent: UNIQ agent exposing ``bad_disc`` and ``mix_disc``.
        bad_dataset: Dataset of bad demonstrations.
        mixed_dataset: Dataset of mixed demonstrations (with ``is_bad`` labels).
        shift: State normalization shift forwarded to ``sample``.
        scale: State normalization scale forwarded to ``sample``.
    Returns:
        mul_ratio: The multiplier described above, as produced by ``.mean()``.
    """
    print('test discriminator ratio')
    bad_batch = bad_dataset.sample(batch_size=10000, shift=shift, scale=scale)
    mix_batch = mixed_dataset.sample(batch_size=10000, shift=shift, scale=scale)
    visible_bad_ratio = (agent.bad_disc(bad_batch.observations, bad_batch.actions)
                         / agent.mix_disc(bad_batch.observations, bad_batch.actions))
    mix_ratio = (agent.bad_disc(mix_batch.observations, mix_batch.actions)
                 / agent.mix_disc(mix_batch.observations, mix_batch.actions))

    is_bad = mix_batch.is_bad
    mul_ratio = visible_bad_ratio.mean() / ((mix_ratio * is_bad).sum() / is_bad.sum())
    mix_ratio *= mul_ratio

    ratio_info = {
        'final_ratio/visible_bad_ratio': round(visible_bad_ratio.mean().item(), 3),
        'final_ratio/mix_ratio': round(mix_ratio.mean().item(), 3),
        'final_ratio/hidden_bad_ratio': round(((mix_ratio * is_bad).sum() / is_bad.sum()).item(), 3),
        'final_ratio/hidden_good_ratio': round(((mix_ratio * (1 - is_bad)).sum() / (1 - is_bad).sum()).item(), 3),
        'final_ratio/mul_ratio': round(mul_ratio.item(), 1),
    }
    pprint(ratio_info)
    print('-'*50)
    return mul_ratio


def main(_):
    """Main training loop for the UNIQ algorithm.

    Key steps:
    1. Set up the save dir, random seeds and run name (this expands the
       dataset-name lists in ``args``).
    2. Create the env and load the bad and mixed datasets.
    3. Compute state normalization from the mixed dataset if enabled.
    4. wandb logging (not implemented yet; enabling it raises).
    5. Initialize the UNIQ agent.
    6. Load the discriminator from disk, or train it and save it.
    7. Run the main training loop with periodic evaluation and logging.

    Args:
        _: Remaining argv passed by ``app.run``; unused.
    Note:
        Configuration comes from the module-level ``args`` object, which this
        function mutates (``save_dir``, the dataset name lists, ``scale_mix``).
    """
    # Setup paths, random seeds and run name.
    args.save_dir = os.path.join(args.save_dir, args.exp_name)
    random.seed(args.seed)
    np.random.seed(args.seed)

    robot_name, run_name = _build_run_name(args)
    args.save_dir = os.path.join(args.save_dir, robot_name, run_name)

    # Load datasets; the expert dataset returned with the env is not used here.
    env, _ = create_expert_dataset_and_env(args)
    bad_dataset, _ = make_bad_dataset(args)
    print('bad_dataset')
    print(f'Loaded bad dataset with size {bad_dataset.observations.shape}')
    print('-'*100)
    mixed_dataset, _ = make_mixed_dataset(args)

    # Statistics come from the mixed dataset only but are applied to every
    # batch sampled below (bad and mixed alike).
    shift, scale = _state_normalization(mixed_dataset, args.state_norm)

    # Setup wandb.
    if args.use_wandb:
        # TODO: init wandb here; the intended algo name is 'UNIQ', suffixed
        # with '_Q' when args.update_Q_inference is set.
        raise NotImplementedError('Init wandb here')

    # Print run name and args.
    print('-'*100)
    print(run_name)
    print('args')
    pprint(args.flag_values_dict())
    print('-'*100)

    # Create agent arguments.
    agent_args = ConfigArgs(f=args.f,
                            sample_random_times=args.sample_random_times,
                            grad_pen=args.grad_pen,
                            lambda_gp=args.lambda_gp,
                            noise=args.noise,
                            max_clip=args.max_clip,
                            alpha=args.alpha,
                            num_v_updates=args.num_v_updates,
                            log_loss=args.log_loss,
                            noise_std=args.noise_std,
                            eval_interval=args.eval_interval,
                            v_update=args.v_update,
                            clip_threshold=args.clip_threshold,
                            update_Q_inference=args.update_Q_inference)

    hidden_dims = tuple([args.hidden_size] * args.num_layers)
    print(f'hidden_dims: {hidden_dims}')

    # Initialize UNIQ agent; single sampled obs/action fix the input shapes.
    agent = UNIQ(args.seed,
                 observations=env.observation_space.sample()[np.newaxis],
                 actions=env.action_space.sample()[np.newaxis],
                 max_steps=args.max_steps,
                 double_q=args.double,
                 actor_lr=args.actor_lr,
                 critic_lr=args.critic_lr,
                 disc_lr=args.disc_lr,
                 value_lr=args.value_lr,
                 hidden_dims=hidden_dims,
                 discount=args.discount,
                 expectile=args.expectile,
                 actor_temperature=args.actor_temperature,
                 dropout_rate=args.dropout_rate,
                 layernorm=args.layernorm,
                 tau=args.tau,
                 reward_gap=args.reward_gap,
                 weight_decay=args.weight_decay,
                 args=agent_args)

    # Load the discriminator, or train it from scratch and save it. The
    # directory name encodes the settings that affect discriminator training.
    disc_load_dir = os.path.join(
        args.save_dir,
        f's({args.seed})_Snorm({args.state_norm})_noise({args.noise_std})'
        f'_hidden({args.hidden_size})_layers({args.num_layers})_step({args.num_disc_train//1000}k)')
    if agent.load_discriminator(disc_load_dir):
        print(f'Loaded discriminator from {disc_load_dir}')
    else:
        print(f'No discriminator found in {disc_load_dir}, training from scratch')
        for step in range(args.num_disc_train):
            # The agent clears both flags once neither discriminator needs training.
            if not (agent.train_bad or agent.train_mix):
                print(f'early stopping at step {step}')
                break

            bad_batch = bad_dataset.sample(batch_size=5000, shift=shift, scale=scale)
            mix_batch = mixed_dataset.sample(batch_size=5000, shift=shift, scale=scale)
            info = agent.update_discriminator(bad_batch, mix_batch, step=step)
            if step % args.eval_interval == 0:
                print('-'*50)
                pprint(_round_info(info))
                _evaluate_discriminator_ratio(agent, bad_dataset, mixed_dataset, shift, scale)

        print('-'*50)
        print('finished training discriminator')
        print('save discriminator')
        agent.save_discriminator(disc_load_dir)

    print('-'*50)

    # Print the ratio diagnostics of the final discriminator.
    # NOTE(review): the multiplier returned here is discarded and scale_mix is
    # pinned to 1.0, matching the existing training behaviour. Assign the
    # returned value to args.scale_mix to use the estimate instead.
    _evaluate_discriminator_ratio(agent, bad_dataset, mixed_dataset, shift, scale)
    args.scale_mix = 1.0
    print(f'scale_mix: {args.scale_mix}')
    best_eval_returns = -np.inf
    last_20_returns = deque(maxlen=20)
    print('------- start training -------')

    # Main training loop (inclusive of max_steps so the final step is evaluated
    # when max_steps is a multiple of eval_interval).
    for step in range(0, args.max_steps + 1):
        bad_batch = bad_dataset.sample(batch_size=args.batch_size, shift=shift, scale=scale)
        mixed_batch = mixed_dataset.sample(batch_size=args.batch_size, shift=shift, scale=scale)
        update_info = agent.update(bad_batch=bad_batch, mixed_batch=mixed_batch,
                                   scale_mix=args.scale_mix, step=step)

        if step % args.eval_interval == 0:
            eval_stats = evaluate(agent, env, args.eval_episodes,
                                  shift=shift, scale=scale)
            last_20_returns.append(eval_stats['return'])
            print(f"Eval in step {step} return: {eval_stats['return']:.2f}")
            if eval_stats['return'] > best_eval_returns:
                best_eval_returns = eval_stats['return']

            update_info['eval/return'] = eval_stats['return']
            update_info['eval/last_20_return'] = np.mean(last_20_returns)
            update_info['eval/last_20_return_std'] = np.std(last_20_returns)
            update_info['eval/last_20_return_min'] = np.min(last_20_returns)
            update_info['eval/last_20_return_max'] = np.max(last_20_returns)
            update_info['eval/last_20_return_stderr'] = np.std(last_20_returns) / np.sqrt(len(last_20_returns))
            update_info['eval/best_return'] = best_eval_returns
            log_wandb(_round_info(update_info), args, step)
    

if __name__ == '__main__':
    # Script entry point: absl's app.run calls main(argv); main ignores argv
    # and reads its configuration from the module-level `args` object.
    app.run(main)
