import argparse
import gym
import gym_hybrid
import numpy as np 
import torch
import wandb
from easydict import EasyDict
from tqdm import tqdm 

from src.chpo.Chpo_core import combined_shape
from src.chpo.Chpo_core import discount_cumsum
from src.chpo.Chpo_core import statistics_scalar
from src.chpo.Chpo_model import CHPO_Model
from src.chpo.Chpo_policy import CHPOPolicy



class CHPOBuffer:
    """
    A buffer for storing trajectories experienced by a CHPO agent interacting
    with the environment, and using Generalized Advantage Estimation (GAE-Lambda)
    for calculating the advantages of state-action pairs.

    The buffer holds exactly ``size`` timesteps in pre-allocated numpy arrays.
    Usage per epoch: call ``store`` once per step, ``finish_path`` whenever a
    trajectory ends or is cut off, and ``get`` once the buffer is full.

    Alongside the reward signal the buffer also records a per-step cost and a
    cost-value estimate.
    """

    # Lower bound applied to the advantage std in get(); prevents NaN/inf
    # advantages when every advantage in the batch is (almost) identical.
    _ADV_STD_FLOOR = 1e-8

    def __init__(self, obs_dim, discrete_act_dim, parameter_act_dim, size, device='cpu', gamma=0.99, lam=0.95):
        """
        Allocate the storage arrays.

        Args:
            obs_dim: length of one (flat) observation vector.
            discrete_act_dim: width of the discrete-action logit vector.
            parameter_act_dim: width of the continuous action-parameter vector.
            size: number of timesteps the buffer holds (one epoch of data).
            device: torch device the tensors returned by ``get`` are placed on.
            gamma: discount factor used for returns and GAE.
            lam: GAE-Lambda smoothing coefficient.
        """
        self.obs_buf = np.zeros((size, obs_dim), dtype=np.float32)
        self.next_obs_buf = np.zeros((size, obs_dim), dtype=np.float32)
        self.discrete_act_buf = np.zeros(size, dtype=np.int64)
        self.parameter_act_buf = np.zeros((size, parameter_act_dim), dtype=np.float32)
        self.adv_buf = np.zeros(size, dtype=np.float32)
        self.rew_buf = np.zeros(size, dtype=np.float32)
        self.ret_buf = np.zeros(size, dtype=np.float32)
        self.val_buf = np.zeros(size, dtype=np.float32)
        # NOTE: the attribute name keeps its historical spelling ("discreate")
        # because code outside this class may access it by name.
        self.logp_discreate_act_buf = np.zeros(size, dtype=np.float32)
        self.logp_parameter_act_buf = np.zeros(size, dtype=np.float32)
        self.done_buf = np.zeros(size, dtype=np.float32)
        self.logit_action_type_buf = np.zeros((size, discrete_act_dim), dtype=np.float32)
        self.logit_action_argsmu_buf = np.zeros((size, parameter_act_dim), dtype=np.float32)
        self.logit_action_argssigma_buf = np.zeros((size, parameter_act_dim), dtype=np.float32)
        # NOTE(review): costadv_buf and costret_buf are allocated here and
        # returned by get() (as 'adc' and 'cost_ret') but are never written by
        # this class, so they stay all-zero unless another component fills
        # them. Confirm whether the policy recomputes them or whether a cost
        # GAE step is missing from finish_path().
        self.costadv_buf = np.zeros(size, dtype=np.float32)
        self.cost_buf = np.zeros(size, dtype=np.float32)
        self.costret_buf = np.zeros(size, dtype=np.float32)
        self.costval_buf = np.zeros(size, dtype=np.float32)
        self.gamma, self.lam = gamma, lam
        # ptr: next free slot; path_start_idx: first slot of the current trajectory.
        self.ptr, self.path_start_idx, self.max_size = 0, 0, size
        self.device = device

        # Undiscounted cost-to-go of each step, filled in by finish_path().
        self.epret_cost_buf = np.zeros(size, dtype=np.float32)

    @staticmethod
    def _normalize_adv(values, mean, std):
        """Return ``(values - mean) / std`` with ``std`` floored at ``_ADV_STD_FLOOR``."""
        # max() leaves std untouched whenever it is already above the floor,
        # so results are unchanged for any non-degenerate batch.
        return (values - mean) / max(std, CHPOBuffer._ADV_STD_FLOOR)

    def store(self, obs, next_obs, discrete_act, parameter_act, rew, val, logp_discrete_act, logp_parameter_act, done, logit_action_type, logit_action_argsmu, logit_action_argssigma,cost,cost_value):
        """
        Append one timestep of agent-environment interaction to the buffer.

        Every argument is written to the slot at ``self.ptr`` of its matching
        array, after which ``self.ptr`` advances by one.

        Raises:
            AssertionError: if the buffer is already full.
        """
        assert self.ptr < self.max_size, 'buffer is full'     # buffer has to have room so you can store
        self.obs_buf[self.ptr] = obs
        self.next_obs_buf[self.ptr] = next_obs
        self.discrete_act_buf[self.ptr] = discrete_act
        self.parameter_act_buf[self.ptr] = parameter_act
        self.rew_buf[self.ptr] = rew
        self.val_buf[self.ptr] = val
        self.logp_discreate_act_buf[self.ptr] = logp_discrete_act
        self.logp_parameter_act_buf[self.ptr] = logp_parameter_act
        self.done_buf[self.ptr] = done
        self.logit_action_type_buf[self.ptr] = logit_action_type
        self.logit_action_argsmu_buf[self.ptr] = logit_action_argsmu
        self.logit_action_argssigma_buf[self.ptr] = logit_action_argssigma
        self.cost_buf[self.ptr] = cost
        self.costval_buf[self.ptr] = cost_value
        self.ptr += 1

    def finish_path(self, last_val=0, last_val_cost = 0):
        """
        Call this at the end of a trajectory, or when one gets cut off
        by an epoch ending. This looks back in the buffer to where the
        trajectory started, and uses rewards and value estimates from
        the whole trajectory to compute advantage estimates with GAE-Lambda,
        as well as compute the rewards-to-go for each state, to use as
        the targets for the value function.

        The "last_val" argument should be 0 if the trajectory ended
        because the agent reached a terminal state (died), and otherwise
        should be V(s_T), the value function estimated for the last state.
        This allows us to bootstrap the reward-to-go calculation to account
        for timesteps beyond the arbitrary episode horizon (or epoch cutoff).

        "last_val_cost" is appended to the trajectory's cost sequence before
        the undiscounted cost-to-go is accumulated into ``epret_cost_buf``
        (presumably the cost-value bootstrap for the last state -- confirm
        against the caller).
        """

        path_slice = slice(self.path_start_idx, self.ptr)
        rews = np.append(self.rew_buf[path_slice], last_val)
        vals = np.append(self.val_buf[path_slice], last_val)
        costs = np.append(self.cost_buf[path_slice], last_val_cost)

        # the next two lines implement GAE-Lambda advantage calculation
        deltas = rews[:-1] + self.gamma * vals[1:] - vals[:-1]
        self.adv_buf[path_slice] = discount_cumsum(deltas, self.gamma * self.lam)

        # the next line computes rewards-to-go, to be targets for the value function
        self.ret_buf[path_slice] = discount_cumsum(rews, self.gamma)[:-1]

        # cost-to-go with a discount of 1, i.e. a plain suffix sum of the costs
        self.epret_cost_buf[path_slice] = discount_cumsum(costs, 1)[:-1]

        # the next trajectory starts where this one ended
        self.path_start_idx = self.ptr

    def get(self):
        """
        Call this at the end of an epoch to get all of the data from
        the buffer, with advantages appropriately normalized (shifted to have
        mean zero and std one). Also, resets some pointers in the buffer.

        Returns:
            dict mapping field name to a ``torch.float32`` tensor on
            ``self.device``. Every field, including the integer
            ``discrete_act``, is converted to float32.

        Raises:
            AssertionError: if the buffer is not completely full.
        """
        assert self.ptr == self.max_size, 'buffer is not full'    # buffer has to be full before you can get
        self.ptr, self.path_start_idx = 0, 0
        # the next two lines implement the advantage normalization trick;
        # the std is floored so a constant-advantage batch gives zeros, not NaN
        adv_mean, adv_std = statistics_scalar(self.adv_buf)
        self.adv_buf = self._normalize_adv(self.adv_buf, adv_mean, adv_std)
        data = dict(
            obs=self.obs_buf,
            next_obs = self.next_obs_buf,
            discrete_act=self.discrete_act_buf,
            parameter_act=self.parameter_act_buf,
            ret=self.ret_buf,
            adv=self.adv_buf,
            logp_discrete_act = self.logp_discreate_act_buf,
            logp_parameter_act = self.logp_parameter_act_buf,
            done = self.done_buf,
            reward = self.rew_buf,
            logit_action_type = self.logit_action_type_buf,
            logit_action_argsmu = self.logit_action_argsmu_buf,
            logit_action_argssigma = self.logit_action_argssigma_buf,
            cost_ret=self.costret_buf,
            adc=self.costadv_buf,
            cost=self.cost_buf,
            epret_cost = self.epret_cost_buf
        )
        # NOTE(review): torch.as_tensor may avoid copying, in which case the
        # returned tensors alias the numpy buffers and later store() calls
        # would overwrite them -- consume the batch before collecting again.
        return {k: torch.as_tensor(v, dtype=torch.float32, device=self.device) for k,v in data.items()}


def set_seed(seed):
    '''
    Seed the global random number generators of torch and numpy.

    The same integer ``seed`` is handed first to ``torch.manual_seed`` and
    then to ``np.random.seed``.
    '''
    for seed_fn in (torch.manual_seed, np.random.seed):
        seed_fn(seed)

    
def main(args):
    '''
    Build the environment, buffer, model and policy, then run the
    evaluate / collect / update loop for ``args.max_train_epochs`` epochs.

    Args:
        args: parsed command-line namespace; the attributes read here are
            wandb, seed, env, steps_per_epoch, device, share_encoder,
            encoder_hidden_size_list, sigma_type, fixed_sigma_value,
            bound_type, adv_norm, value_norm, batch_size, cost_limit,
            cost_distance, rc_ratio, max_train_epochs, eval_epoch and
            train_steps_per_epoch.

    The environment is always closed on exit, including when training
    raises an exception.
    '''
    if args.wandb:
        wandb.init(project='chpo')

    # Set the seed
    set_seed(args.seed)

    # Create the environment
    env = gym.make(args.env)
    try:
        env.seed(args.seed)
        # Sizes taken from the env spaces. The action space is indexed as
        # (discrete part, continuous-parameter part) -- presumably a gym
        # Tuple(Discrete, Box); confirm against gym_hybrid.
        obs_dim = env.observation_space.shape[0]
        discrete_act_dim = env.action_space[0].n
        parameter_act_dim = env.action_space[1].shape[0]

        torch.set_num_threads(3)

        # Define the replay buffer (holds exactly one epoch of steps)
        buf = CHPOBuffer(
            obs_dim=obs_dim,
            discrete_act_dim=discrete_act_dim,
            parameter_act_dim=parameter_act_dim,
            size=args.steps_per_epoch,
            device=args.device,
        )

        # Define the CHPO Model
        chpo_model = CHPO_Model(
            obs_shape=obs_dim,
            discrete_act_dim=discrete_act_dim,
            parameter_act_dim=parameter_act_dim,
            share_encoder=args.share_encoder,
            encoder_hidden_size_list=args.encoder_hidden_size_list,
            sigma_type=args.sigma_type,
            fixed_sigma_value=args.fixed_sigma_value,
            bound_type=args.bound_type,
        )

        # Define the CHPO Policy
        chpo_policy = CHPOPolicy(
            env_id=args.env,
            buf=buf,
            model=chpo_model,
            device=args.device,
            adv_norm=args.adv_norm,
            value_norm=args.value_norm,
            wandb_flag=args.wandb,
            env=env,
            share_encoder=args.share_encoder,
            batch_size=args.batch_size,
            cost_limit=args.cost_limit,
            cost_distance=args.cost_distance,
            seed=args.seed,
            rc_ratio=args.rc_ratio,
        )

        # Train, evaluate and collect for the CHPO policy
        for epoch in tqdm(range(args.max_train_epochs), desc='Train Loop:'):

            # evaluate the CHPO policy (runs before collection, so epoch 0
            # evaluates the untrained policy)
            chpo_policy.evaluate(
                eval_epoch=args.eval_epoch,
                step=epoch,
            )

            # collect the samples and store them in the buffer
            ep_len = chpo_policy.rollout(
                steps_per_epoch=args.steps_per_epoch
            )

            # update the CHPO policy
            chpo_policy.update(
                data=buf,
                train_iters=args.train_steps_per_epoch,
                train_epoch=epoch,
                train_all_epoch=args.max_train_epochs,
                ep_len=ep_len,
            )
    finally:
        # Release environment resources even if training fails part-way.
        env.close()


def str2bool(value):
    """
    Convert a command-line token to a bool.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    "False") as True; this parser understands the usual spellings instead.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


def build_arg_parser():
    """Return the argument parser holding every training option and its default."""
    parser = argparse.ArgumentParser()
    # set the environment
    parser.add_argument('--env', type=str, default='Moving-v0') #Moving-v0 Sliding-v0 HardMove-v0 Perpendicular_safe-v0
    # set the parameter for the training
    parser.add_argument('--steps_per_epoch', type=int, default=3200, help='The epoch number of the buffer')
    parser.add_argument('--max_train_epochs', type=int, default=4000, help='The max epochs of the train')
    parser.add_argument('--train_steps_per_epoch', type=int, default=10, help='The train step of each epoch')
    parser.add_argument('--batch_size', type=int, default=320, help='The batch size during train' )
    parser.add_argument('--random_epochs', type=int, default=3, help='get the sample by the random policy')
    # Set the parameter for the network
    # nargs='+' with type=int parses "--encoder_hidden_size_list 256 128" into
    # [256, 128]; type=list would split the token into single characters.
    parser.add_argument('--encoder_hidden_size_list', type=int, nargs='+', default=[256, 128, 64, 64], help='The hidden size for the encoder network')
    parser.add_argument('--sigma_type', type=str, default='fixed')
    parser.add_argument('--fixed_sigma_value', type=float, default=0.3)
    parser.add_argument('--bound_type', type=str, default='tanh')
    # Set the others parameter
    parser.add_argument('--eval_epoch', type=int, default=5)
    parser.add_argument('--seed', '-s', type=int, default=0)
    parser.add_argument('--device', default='cpu', help='Set the training device')
    parser.add_argument('--wandb','-wan',action='store_true',help='Flag for logging data via wandb')
    # Boolean options take an explicit value, e.g. "--share_encoder true".
    parser.add_argument('--share_encoder', type=str2bool, default=False)
    parser.add_argument('--cost_limit',type=float, default=1.0)
    parser.add_argument('--value_norm', type=str2bool, default=True)
    parser.add_argument('--adv_norm', type=str2bool, default=True)
    parser.add_argument('--cost_distance',type=float, default=0)
    parser.add_argument('--rc_ratio',type=float, default=4)
    return parser


if __name__ =='__main__':
    args = build_arg_parser().parse_args()

    main(args)
    print('success!')
