import json
import os
import sys
from pathlib import Path
from typing import Optional

import d4rl
import fire
import gym
import numpy as np
import pyrallis
import yaml
from pydantic import BaseModel, ConfigDict, field_validator, model_validator, ValidationError

from load_dataset import dataset_from_env
from utils import TrainLogger, redirect_stdout
from actor_critic import DDPG, TD3, SAC, PPO
from collect_trajectory import load_agent
from rlhf_training import (
    baseline_training,
    behavior_regularized_training,
    soft_behavior_regularized_training,
    uncertainty_training,
    load_stochastic_behavior_model,
    baseline_training_ppo,
    soft_behavior_regularized_training_ppo,
)
from cql_training import cql_training


def load_config_cli(**kwargs):
    """Build a validated ``Config`` from keyword arguments.

    Used as the ``fire.Fire`` entry point, so ``kwargs`` are the command-line
    flags supplied by the user; any field not given keeps its ``Config``
    default.

    Returns:
        The constructed ``Config`` instance.

    Raises:
        ValidationError: raised by pydantic when a value does not match the
            declared field type. It is propagated unchanged: pydantic's
            ``ValidationError`` cannot be re-created from a plain message
            string, so wrapping it in a new ``ValidationError(...)`` would
            replace the useful validation report with a ``TypeError``.
    """
    return Config(**kwargs)
    

class Config(BaseModel):
    """All tunable settings of an experiment run, validated by pydantic.

    Every field has a default, so an instance can be built from any subset of
    keyword arguments (see ``load_config_cli``).
    """

    # General PBRL
    clip_len: int = 20
    # None means "not specified"; the launcher substitutes 0 when loading the
    # behavior model and formats the value into the result directory name
    # when ``sub`` is set.
    data_num: Optional[int] = None

    # ORL PBRL
    bin_label_allow_overlap: int = 1
    bin_label_trajectory_batch: int = 0
    num_berno: int = 1
    quick_stop: int = 0
    dataset_size_multiplier: float = 1.0
    reuse_fraction: float = 0.0
    reuse_times: int = 0

    # General Experiment
    alg: str = 'sac'  # Agent class name, matched case-insensitively
    env: str = "halfcheetah-medium-v2"
    learn: str = 'behavior'  # 'test', a TRAIN_TYPE_MAP key, or anything else for CQL
    device: str = "cuda"
    seed: int = 0
    sub: bool = False  # If set, results go under ../result_<data_num>/ instead of ../result/

    # Agent
    gamma: float = 0.99

    # Reward + Behavior Model Training
    hid: int = 64  # Width of each hidden layer passed to the agent
    l: int = 3  # Number of hidden layers passed to the agent

    # Step 2 Learning
    buffer_size: int = 4_000_000  # Replay buffer size
    batch_size: int = 100  # Batch size for all networks
    regu: float = 0.1  # Only forwarded to training when learn == 'behavior'
    epochs: int = 150
    steps_per_epoch: int = 1000
    n_episodes: int = 10  # How many episodes run during evaluation
    max_ep_len: int = 1000

    # CQL
    discount: float = 0.99  # Discount factor
    alpha_multiplier: float = 1.0  # Multiplier for alpha in loss
    use_automatic_entropy_tuning: bool = True  # Tune entropy
    backup_entropy: bool = False  # Use backup entropy
    policy_lr: float = 3e-5  # Policy learning rate
    qf_lr: float = 3e-4  # Critics learning rate
    soft_target_update_rate: float = 5e-3  # Target network update rate
    target_update_period: int = 1  # Frequency of target nets updates
    cql_n_actions: int = 10  # Number of sampled actions
    cql_importance_sample: bool = True  # Use importance sampling
    cql_lagrange: bool = False  # Use Lagrange version of CQL
    cql_target_action_gap: float = -1.0  # Action gap
    cql_temp: float = 1.0  # CQL temperature
    cql_alpha: float = 10.0  # Minimal Q weight
    cql_max_target_backup: bool = False  # Use max target backup
    cql_clip_diff_min: float = -np.inf  # Q-function lower loss clipping
    cql_clip_diff_max: float = np.inf  # Q-function upper loss clipping
    orthogonal_init: bool = True  # Orthogonal initialization
    normalize: bool = True  # Normalize states
    normalize_reward: bool = False  # Normalize reward
    q_n_hidden_layers: int = 3  # Number of hidden layers in Q networks
    # NOTE(review): these two fields used to be declared twice in this class
    # (1.0 / 0.0 here, then 5.0 / -1.0 under "AntMaze hacks"). The later
    # assignment always won, so 5.0 / -1.0 are the defaults that were actually
    # in effect and are kept here; confirm they are the intended values.
    reward_scale: float = 5.0  # Reward scale for normalization
    reward_bias: float = -1.0  # Reward bias for normalization

    # AntMaze hacks
    bc_steps: int = 0  # Number of BC steps at start
    policy_log_std_multiplier: float = 1.0

    # Logging
    use_wandb: bool = True
    project: str = 'PBRL'
    name: str = 'initalsetup'
    print_logs: bool = True

# Dispatch table: value of ``config.learn`` -> training routine to call.
# Several modes deliberately share a routine; the launcher below tells them
# apart through the keyword arguments it passes (a ``behavior_model`` in
# ``ac_kwargs`` and/or ``true_reward=True``).
TRAIN_TYPE_MAP = {
    "baseline": baseline_training,
    "uncertainty": uncertainty_training,
    # Behavior-regularized variants.
    "behavior": behavior_regularized_training,
    "brac": soft_behavior_regularized_training,
    # "naive" variants reuse the baseline routines.
    "naive-behavior": baseline_training,
    "naive-behavior-ppo": baseline_training_ppo,
    "naive-behavior-ppo-true-reward": baseline_training_ppo,
    "brac-ppo": soft_behavior_regularized_training_ppo,
    "naive-true-reward": baseline_training,
}

    
def _resolve_agent_class(alg):
    """Return the agent class named by ``alg`` (case-insensitive).

    The name is looked up in an explicit table of the agent classes imported
    by this script rather than being passed to ``eval``, so a mistyped or
    unexpected command-line value cannot be evaluated as Python code.

    Raises:
        ValueError: if ``alg`` does not name one of the known agent classes.
    """
    # Built at call time so the names are only resolved when an agent is needed.
    agent_classes = {'DDPG': DDPG, 'TD3': TD3, 'SAC': SAC, 'PPO': PPO}
    try:
        return agent_classes[alg.upper()]
    except KeyError:
        raise ValueError(
            f"Unknown alg {alg!r}; expected one of {sorted(agent_classes)}"
        ) from None


if __name__ == '__main__':
    config = fire.Fire(load_config_cli)

    if config.learn == 'test':
        # Evaluation-only mode: build an agent, optionally load saved weights,
        # print the mean of the values returned by test_agent() and exit.
        agent = _resolve_agent_class(config.alg)(
            env_name=config.env,
            ac_kwargs=dict(hidden_sizes=[256] * 2),
            gamma=config.gamma,
            num_test_episodes=30,
        )
        # NOTE(review): an env name containing neither 'expert' nor 'medium'
        # is evaluated without any load_agent() call - confirm that is intended.
        if 'expert' in config.env:
            load_agent(agent, 'exp')
        elif 'medium' in config.env:
            load_agent(agent, 'med')

        rets = agent.test_agent()
        pfm = sum(rets) / len(rets)
        print('policy performance:', pfm)

        sys.exit()

    # Results live in ../result/<learn>_<env>/, or in ../result_<data_num>/...
    # when the ``sub`` flag is set.
    result_root = Path(f'../result_{config.data_num}') if config.sub else Path('../result')
    res_path = result_root / f'{config.learn}_{config.env}' / f'outputs_{config.seed}.json'
    res_path.parent.mkdir(parents=True, exist_ok=True)

    # The log file is handed over to redirect_stdout and deliberately not
    # closed here. NOTE(review): presumably it replaces stdout for the rest of
    # the run - confirm in utils.redirect_stdout.
    redirect_stdout(open(res_path.parent / f'log_{config.seed}', 'w'))
    logger = TrainLogger.from_config(config)

    if config.learn in TRAIN_TYPE_MAP:
        dataset = dataset_from_env(config)
        ac_kwargs = dict(hidden_sizes=[config.hid] * config.l)

        # These modes additionally hand a stochastic behavior model to the agent.
        if config.learn in ('naive-behavior', 'naive-behavior-ppo',
                            'naive-behavior-ppo-true-reward', 'brac-ppo'):
            data_num = config.data_num if config.data_num is not None else 0
            ac_kwargs['behavior_model'] = load_stochastic_behavior_model(
                config.env, data_num, config.seed)

        agent = _resolve_agent_class(config.alg)(
            env_name=config.env,
            ac_kwargs=ac_kwargs,
            gamma=config.gamma,
            terminate=True,
        )

        kwargs = {
            'dataset': dataset,
            'agent': agent,
            'epochs': config.epochs,
            'steps_per_epoch': config.steps_per_epoch,
            'seed': config.seed,
            'logger': logger,
        }
        # Only the 'behavior' mode receives the regularization weight.
        if config.learn == 'behavior':
            kwargs['regu'] = config.regu
        if config.learn in ('naive-behavior-ppo-true-reward', 'naive-true-reward'):
            kwargs['true_reward'] = True

        true_pfms, sim_pfms = TRAIN_TYPE_MAP[config.learn](**kwargs)
    else:
        # Any ``learn`` value that is neither 'test' nor a TRAIN_TYPE_MAP key
        # is handled by CQL training, which takes the whole config.
        true_pfms, sim_pfms = cql_training(config, logger)

    with open(res_path, 'w') as f:
        json.dump({'true_pfms': true_pfms, 'sim_pfms': sim_pfms}, f)