import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(__file__))) # add the project root directory to sys.path
import time
import wandb
import jax
import jax.numpy as jnp
import numpy as np
import pdb

from flax.training import train_state
from flax.training import checkpoints

from arguments import get_args
from functools import partial
from typing import Any

from rl.RAPCPPO_utils import _rapcppo_update, _env_step
from env.env_list import get_env
from model.actorcritic import Policy_Network, Value_Network, Policy_Network_Discrete, Lagrange_Network, Phi_Network

from rl.plot_utils import calculate_consumption
from rl.utils import optimizer, tree_index1
from rl.gae import Transition_reach,calculate_gae2, calculate_gae_reach4, calculate_done, calculate_phi_targets, calculate_phi_targets_success

class TrainState(train_state.TrainState):
    """Flax ``TrainState`` extended with three extra fields.

    In this file every instance is created with ``mean`` and ``variance`` set
    to zero arrays of the observation shape and ``count`` set to ``1e-4``.

    NOTE(review): the names suggest running observation-normalisation
    statistics, but nothing in this file reads or updates them -- confirm
    against the code that consumes these states (e.g. rl.RAPCPPO_utils).
    """
    # Extra state carried alongside params/optimizer state; typed loosely
    # because the concrete array types are decided by the caller.
    mean: Any
    variance: Any
    count: Any

def train(env, env_params, config, rng):
    """Train the policy, energy-value, reach-value and phi networks.

    Builds the four networks and their ``TrainState``s, then runs
    ``config["NUM_UPDATES"] // config["STEP_SCAN"]`` outer iterations. Each
    outer iteration scans the jitted ``_train`` step ``STEP_SCAN`` times and
    prints the success rate and mean energy consumption computed from the
    first scanned step of that iteration.

    Args:
        env: Environment object; ``reset``, ``action_space`` and
            ``observation_space`` are called on it here.
        env_params: Parameters forwarded to every environment call.
        config: Dict of hyper-parameters (see the keys read below).
        rng: JAX PRNG key used for network initialisation and rollouts.

    Returns:
        The tuple ``(train_state_policy, train_state_energy,
        train_state_reach, train_state_phi)`` after the last update.
        (Previously nothing was returned, so the trained parameters were
        unreachable by the caller while checkpointing is disabled.)
    """

    def _train(train_state_total, ent_gamma):
        """Collect one rollout and run the update epochs (outer scan body).

        ``train_state_total`` is ``(policy, energy, reach, phi, rng)`` and
        ``ent_gamma`` holds ``[entropy_weight, reach_gamma]`` for this step.
        Returns the updated carry and a dict with batch and loss information.
        """
        train_state_policy, train_state_energy, train_state_reach, train_state_phi, rng = train_state_total

        # INIT ENV
        rng, _rng = jax.random.split(rng)
        reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
        obsv, env_state = jax.vmap(env.reset, in_axes=(0, None))(reset_rng, env_params)
        rng, _rng = jax.random.split(rng)
        runner_state = (train_state_policy, train_state_energy,
                        train_state_reach, train_state_phi, env_state, obsv, _rng)

        # COLLECT TRAJECTORY
        runner_state, traj_batch = jax.lax.scan(
            env_step, runner_state, None, config["NUM_STEPS"]
        )

        # CALCULATE ADVANTAGE
        (train_state_policy, train_state_energy, train_state_reach, train_state_phi,
         env_state, last_obs, rng) = runner_state

        # Bootstrap values for the observation that follows the last step.
        last_val = train_state_energy.apply_fn(train_state_energy.params, last_obs)
        last_val_reach = train_state_reach.apply_fn(train_state_reach.params, last_obs)

        # Append the final env state's g/h (and the bootstrap reach value) as
        # one extra trailing row after the rollout rows.
        # NOTE(review): assumes env_state.g/h and last_val_reach are 1-D over
        # the env axis so that expand_dims(...).T is a single row -- confirm.
        g_append = jnp.concatenate((traj_batch.g, jnp.expand_dims(env_state.g, axis=1).T))
        h_append = jnp.concatenate((traj_batch.h, jnp.expand_dims(env_state.h, axis=1).T))
        V_reach_append = jnp.concatenate((traj_batch.value_reach, jnp.expand_dims(last_val_reach, axis=1).T))

        # Drop the last row of the done flags, which were computed on the
        # appended h array.
        done = calculate_done(traj_batch, h_append)[:-1, :]

        # Build success-based phi supervision with mask using g and h
        phi_targets, phi_mask = calculate_phi_targets_success(ent_gamma[1], g_append[1:, :], h_append[1:, :])

        advantages_reach, targets_reach = calculate_gae_reach4(
            ent_gamma[1], config["GAE_LAMBDA"], g_append, V_reach_append, done, h_append
        )
        advantages_V, targets_V = calculate_gae2(
            config["GAMMA_ENERGY"], config["GAE_LAMBDA"], traj_batch, done, last_val
        )

        # UPDATE NETWORK
        update_state = (train_state_policy, train_state_energy, train_state_reach, train_state_phi,
                        traj_batch, advantages_reach, targets_reach, advantages_V, targets_V,
                        phi_targets, phi_mask, rng)

        # Every update epoch receives the same entropy weight.
        epoch_ent = jnp.ones(config["UPDATE_EPOCHS"]) * ent_gamma[0]
        update_state, loss_info = jax.lax.scan(
            update_epoch, update_state, epoch_ent, config["UPDATE_EPOCHS"]
        )
        train_state_policy, train_state_energy, train_state_reach, train_state_phi = update_state[:4]
        rng = update_state[-1]

        return ((train_state_policy, train_state_energy, train_state_reach, train_state_phi, rng),
                {"batch_info": (traj_batch, targets_reach, targets_V, done), "loss_info": loss_info,
                 "reach_gamma": ent_gamma[1], "entropy_weight": ent_gamma[0]})

    update_epoch = partial(_rapcppo_update, config)
    env_step = partial(_env_step, env, env_params)
    training = jax.jit(_train)

    policy_lr = config["POLICY_LR"]
    value_lr = config["VALUE_LR"]
    phi_lr = config["PHI_LR"]

    obs_shape = env.observation_space(env_params).shape
    hidden_width = config.get("HIDDEN_LAYER_WIDTH", 256)

    def _init_train_state(network, tx, init_rng):
        """Initialise ``network`` on a zero observation and wrap it in a TrainState."""
        params = network.init(init_rng, jnp.zeros(obs_shape))
        return TrainState.create(
            apply_fn=network.apply,
            params=params,
            tx=tx,
            mean=jnp.zeros(obs_shape),
            variance=jnp.zeros(obs_shape),
            count=1e-4,
        )

    # INIT POLICY NETWORK
    policy_kwargs = dict(
        activation=config["POLICY_ACTIVATION"],
        network_depth=config.get("POLICY_NETWORK_DEPTH", 2),
        hidden_width=hidden_width,
    )
    # Explicit comparison kept on purpose: the flag's type is decided by
    # get_args, which is not visible from this function.
    if config["DISCRETE"] == False:
        policy_network = Policy_Network(env.action_space(env_params).shape[0], **policy_kwargs)
    else:
        policy_network = Policy_Network_Discrete(env.action_space(env_params).n, **policy_kwargs)
    rng, _rng = jax.random.split(rng)
    train_state_policy = _init_train_state(policy_network, optimizer(config, lr=policy_lr), _rng)

    # INIT VALUE ENERGY NETWORK
    value_kwargs = dict(
        activation=config["VALUE_ACTIVATION"],
        network_depth=config.get("VALUE_NETWORK_DEPTH", 2),
        hidden_width=hidden_width,
    )
    # One optimizer definition is shared by the energy and reach critics;
    # each TrainState still gets its own optimizer state from create().
    value_tx = optimizer(config, lr=value_lr)
    rng, _rng = jax.random.split(rng)
    train_state_energy = _init_train_state(Value_Network(**value_kwargs), value_tx, _rng)

    # INIT VALUE FIND NETWORK
    rng, _rng = jax.random.split(rng)
    train_state_reach = _init_train_state(Value_Network(**value_kwargs), value_tx, _rng)

    # No Lagrange network in this variant

    # INIT PHI NETWORK (a Value_Network is used in place of Phi_Network)
    phi_network = Value_Network(
        network_depth=config.get("PHI_NETWORK_DEPTH", 2),
        hidden_width=hidden_width,
        activation=config.get("PHI_ACTIVATION", "tanh"),
    )
    rng, _rng = jax.random.split(rng)
    train_state_phi = _init_train_state(phi_network, optimizer(config, lr=phi_lr), _rng)

    step_scan = config["STEP_SCAN"]
    total_timesteps = config["NUM_UPDATES"] // step_scan

    for timestep in range(total_timesteps):

        t0 = time.time()

        # Entropy weight for this outer step, optionally annealed linearly
        # towards zero over the outer iterations.
        ent = jnp.ones(step_scan) * config["ENT_COEF"]
        if config['ANNEAL_ENT'] == True:
            ent = ent * (total_timesteps - timestep) / total_timesteps

        # Reach discount: moves linearly from GAMMA_REACH_INIT towards
        # GAMMA_REACH_FINAL at twice the nominal rate, capped at the final value.
        gamma_init = config["GAMMA_REACH_INIT"]
        gamma_final = config['GAMMA_REACH_FINAL']
        gamma_reach = jnp.ones(step_scan) * jnp.minimum(
            gamma_final,
            gamma_init + (gamma_final - gamma_init) * timestep * 2 / total_timesteps,
        )

        # Column 0: entropy weight, column 1: reach gamma (read in _train).
        xs = jnp.stack((ent, gamma_reach), axis=1)

        update_state, result = jax.lax.scan(
            training, (train_state_policy, train_state_energy, train_state_reach, train_state_phi, rng),
            xs, step_scan
        )

        train_state_policy, train_state_energy, train_state_reach, train_state_phi, rng = update_state

        # Only the first scanned step's rollout is used for the printed metrics.
        traj_batch, _targets_reach, _targets_V, _done = tree_index1(result['batch_info'], 0)

        consumption, success_rate, _ = calculate_consumption(traj_batch)

        # NOTE: checkpointing is currently disabled.
        # checkpoints.save_checkpoint(ckpt_dir='/home/panjd/py_projects/RC/RAPC-PPO/networks/{}/'.format(config["DIR"]),
        #                             target={"policy_network":train_state_policy,
        #                                    "energy_network":train_state_energy,
        #                                    "reach_network":train_state_reach,
        #                                    "phi_network":train_state_phi},
        #                             step=timestep,
        #                             overwrite=True,
        #                             keep_every_n_steps=20)

        t1 = time.time()

        # NOTE: wandb logging is currently disabled.
        # loss_info = result['loss_info']
        # wandb.log({
        #     "task_success_rate": success_rate, "average_energy_consumption": np.mean(consumption),
        #            "actor_loss": jnp.mean(loss_info["actor_loss"]), "entropy_loss": jnp.mean(loss_info["entropy_loss"]),
        #            "energy_loss": jnp.mean(loss_info["energy_loss"]), "reach_loss": jnp.mean(loss_info["reach_loss"]),
        #            "phi_loss": jnp.mean(loss_info["phi_loss"]),
        #            "reach_gamma": result['reach_gamma'][0], "entropy_weight": result['entropy_weight'][0]})
        print("Training step {}: Success Rate: {}  Avg Energy: {}".format(timestep, success_rate, np.mean(consumption)))
        print("Time {}".format(t1-t0))

    return train_state_policy, train_state_energy, train_state_reach, train_state_phi

if __name__ == "__main__":
    # Parse the command line into a plain dict so derived values can be added.
    config = vars(get_args(sys.argv[1:]))
    # Derived sizes: number of rollout/update iterations and minibatch size.
    config["NUM_UPDATES"] = int(
        config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
    )
    config["MINIBATCH_SIZE"] = int(
        config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
    )

    # NOTE(review): these variables are set after `import jax`; this relies
    # on JAX not having initialised its backend yet -- keep them ahead of the
    # first JAX computation.
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    # os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true"
    # os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".99"
    # Environment values must be strings; str() also accepts an integer id.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(config['CUDA_USE'])

    # Create the output tree. exist_ok=True makes this idempotent and also
    # repairs a partially created tree, which the previous "only if the root
    # is missing" check never did.
    model_dir = "model/{}".format(config['DIR'])
    os.makedirs(model_dir, exist_ok=True)
    for sub_dir in ("reach", "policy", "value", "total",
                    "target", "value_target", "state_traj"):
        os.makedirs("{}/{}".format(model_dir, sub_dir), exist_ok=True)

    env = get_env(config)
    wandb.init(project='EC-EFPPO-{}'.format(config["EXP_NAME"]), name=config["NAME"], config=config)
    env_params = env.default_params
    if config['EXP_NAME'] == 'WindField':
        # The WindField experiment selects its section through the env params.
        env_params = env_params.replace(index=config['SECTION'])
    rng = jax.random.PRNGKey(10) #20
    out = train(env, env_params, config, rng)
