import numpy as np
import torch
#import gym
import argparse
import os
import yaml
from tools4mujoco import TD3, New_Trans_RB, env_constructor
import copy

import copy
########################################################

from torch.utils.tensorboard import SummaryWriter

# Run tensors on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


# Module-level list; nothing in the visible part of this file appends to or reads it.
# NOTE(review): presumably a leftover for collecting evaluation results - confirm before removing.
evaluations = []


def eval_transformer_2stage(policy, args, eval_episodes=10, context=None):
    """Roll out the transformer actor without exploration noise and report the mean return.

    A separate environment is built with ``args.seed + 100``. For every step the
    actor receives the window of most recent observations and its output,
    clipped to ``[-1, 1]``, is executed in the environment.

    Args:
        policy: Agent exposing ``policy.trans`` with ``eval()``, ``train()`` and
            ``actor_forward(...)``.
        args: Parsed CLI namespace; ``env``, ``seed``, ``obs_indices`` and
            ``obs_mode`` are read here.
        eval_episodes: Number of episodes to roll out.
        context: Maximum number of most recent observations passed to the
            actor. ``None`` disables truncation, i.e. the whole episode history
            collected so far is passed.

    Returns:
        The sum of all rewards collected over every episode divided by
        ``eval_episodes``.
    """
    eval_env, _, _ = env_constructor(args.env, seed=args.seed+100, obs_indices=args.obs_indices)
    use_images = args.obs_mode != 'state'
    total_reward = 0.

    policy.trans.eval()
    try:
        # Pure inference: no autograd graph is needed for the rollout.
        with torch.no_grad():
            for _ in range(eval_episodes):
                # NOTE(review): the env is re-seeded with the same value before every
                # reset; with a deterministic env/policy this presumably makes all
                # evaluation episodes identical - confirm this is intended.
                eval_env.seed(args.seed+100)
                state = eval_env.reset()
                done = False

                # Observation window with a length-1 axis inserted at dim 1
                # (original author's note: (n_e, 1, s_d)).
                st_states = copy.deepcopy(state).unsqueeze(1).to(device=device, dtype=torch.float32)
                if use_images:
                    img_state = state[f"{args.obs_mode}"]
                    img_states = copy.deepcopy(img_state).unsqueeze(1).to(device=device, dtype=torch.float32)

                while not done:
                    # Keep only the `context` most recent observations.
                    if context is not None and st_states.shape[1] > context:
                        st_states = st_states[:, -context:, :]
                        if use_images:
                            img_states = img_states[:, -context:, :]

                    st_s = st_states.unsqueeze(1)
                    if use_images:
                        sampled_action = policy.trans.actor_forward(st_s, img_states.unsqueeze(1))
                    else:
                        sampled_action = policy.trans.actor_forward(st_s, show_percentage=False)

                    sampled_action = sampled_action.detach().cpu()
                    action = np.clip(sampled_action.numpy()[:, 0, ], -1, 1)

                    # Only the first row of the action batch is executed.
                    state, r, done, _ = eval_env.step(action[0])
                    total_reward += r.item()

                    # Append the new observation to the window along dim 1.
                    st_cur_state = copy.deepcopy(state).unsqueeze(1).to(device=device, dtype=torch.float32)
                    st_states = torch.cat([st_states, st_cur_state], dim=1)
                    if use_images:
                        img_state = state[f"{args.obs_mode}"]
                        img_cur_state = copy.deepcopy(img_state).unsqueeze(1).to(device=device, dtype=torch.float32)
                        img_states = torch.cat([img_states, img_cur_state], dim=1)
    finally:
        # Always hand the model back in training mode, even if the rollout failed.
        policy.trans.train()

    avg_reward = total_reward / eval_episodes

    print("---------------------------------------")
    print(f"Env={args.env} | Seed={args.seed} | Evaluation over {eval_episodes} episodes: {avg_reward:.3f}")
    print("---------------------------------------")

    return avg_reward


def _obs_to_window(obs):
    """Copy an observation, insert a length-1 axis at dim 1 and move it to `device` as float32."""
    return copy.deepcopy(obs).unsqueeze(1).to(device=device, dtype=torch.float32)


def _begin_episode(env, obs_mode):
    """Reset `env` and build the per-episode observation history and actor windows.

    Returns ``(out_states, out_img_states, st_states, img_states)``; the image
    entries are an empty list / ``None`` when ``obs_mode == 'state'``.
    """
    state = env.reset()
    use_images = obs_mode != 'state'
    img_state = state[f"{obs_mode}"] if use_images else None

    out_states = [state]
    out_img_states = [img_state] if use_images else []

    st_states = _obs_to_window(state)
    img_states = _obs_to_window(img_state) if use_images else None
    return out_states, out_img_states, st_states, img_states


def second_stage(config=None, args=None):
    """Run the second training stage: online TD3 training of the transformer policy.

    The loop collects transitions (uniformly random actions for the first
    ``args.start_timesteps`` steps, the actor's output plus Gaussian noise
    afterwards), pushes context windows into ``policy.new_trans_RB`` and calls
    ``policy.stage_2_train`` once per step after the warm-up. Every
    ``args.eval_freq`` steps the policy is evaluated over 3 episodes and the
    result is written to TensorBoard as ``Eval_reward`` under
    ``MuJoCo_runs/{env}/seed={seed}``.

    Args:
        config: Configuration mapping with ``model_config`` and ``train_config``
            sections. ``model_config`` is modified in place: ``num_layers``,
            ``d_model``, ``n_heads``, ``dim_feedforward`` and ``critic_mode``
            are overwritten with fixed values.
        args: Parsed CLI namespace (see the ``__main__`` section of this file).

    Returns:
        None.
    """
    use_images = args.obs_mode != 'state'

    # A single environment is enough: it provides the dimensions for the policy
    # and is the one used for data collection.
    env, state_dim, action_dim = env_constructor(args.env, seed=args.seed, obs_indices=args.obs_indices)

    # Seed before the policy is built so that network initialisation is
    # covered by the seed as well, not only exploration.
    env.action_space.seed(args.seed)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)

    kwargs = {
        "state_dim": state_dim,
        "action_dim": action_dim,
        "max_action": 1,
        "discount": args.discount,
        "tau": args.tau,
    }

    # Architecture values are fixed here and override whatever the config holds.
    model_config = config['model_config']
    model_config['num_layers'] = 1
    model_config['d_model'] = 256
    model_config['n_heads'] = 2
    model_config['dim_feedforward'] = 512
    model_config['critic_mode'] = 'FC'
    critic_mode = model_config['critic_mode']

    if args.policy == "TD3":
        # float() keeps these numeric even if argparse delivered them as strings.
        kwargs["policy_noise"] = float(args.policy_noise)
        kwargs["noise_clip"] = float(args.noise_clip)
        kwargs["policy_freq"] = args.policy_freq
        kwargs["grad_clip"] = args.grad_clip

    context_length = config['train_config']['context_length']

    policy = TD3(args.num_envs, args.obs_mode, context_length, model_config, **kwargs)

    path2run = f"MuJoCo_runs/{args.env}/seed={args.seed}"
    experiment = SummaryWriter(log_dir=path2run)

    try:
        policy.new_trans_RB = New_Trans_RB(args.num_envs, config['train_config']['replay_buffer_size'], context_length, state_dim, action_dim, args.obs_mode)
        policy.experiment = experiment
        policy.trans.train()
        if critic_mode == 'FC':
            policy.critic.train()

        # Per-episode histories; actions/rewards/dones are aligned with env steps.
        out_actions = []
        out_rewards = []
        out_dones = []
        out_states, out_img_states, st_states, img_states = _begin_episode(env, args.obs_mode)

        avg_ret = 0.
        episode_num = 0

        for t in range(int(args.max_timesteps)):

            # Keep only the `context_length` most recent observations for the actor.
            if st_states.shape[1] > context_length:
                st_states = st_states[:, -context_length:, :]
                if use_images:
                    img_states = img_states[:, -context_length:, :]

            st_s = st_states.unsqueeze(1)
            if use_images:
                img_s = img_states.unsqueeze(1)

            if t < args.start_timesteps:
                # Warm-up: uniformly random actions from the action space.
                action = env.action_space.sample()
                out_actions.append(torch.Tensor(action).unsqueeze(0))
            else:
                # Action selection is inference only; no autograd graph is needed.
                with torch.no_grad():
                    if use_images:
                        sampled_action = policy.trans.actor_forward(st_s, img_s)
                    else:
                        sampled_action = policy.trans.actor_forward(st_s)
                sampled_action = sampled_action.detach().cpu()[:, 0, ]

                exploration_noise = torch.normal(mean=0.0, std=args.expl_noise, size=sampled_action.shape)
                # Clip BEFORE storing so the replay buffer holds the action that
                # is actually executed in the environment.
                sampled_action = (sampled_action + exploration_noise).clamp(-1, 1)

                out_actions.append(sampled_action)
                # Copy so the env cannot alias the tensor kept in `out_actions`.
                action = sampled_action.numpy()[0].copy()

            state, reward, done, _ = env.step(action)
            avg_ret += reward.item()

            out_dones.append(done.to(float).reshape(-1, 1))
            out_rewards.append(reward.reshape(-1, 1))

            out_states.append(state)
            if use_images:
                img_state = state[f"{args.obs_mode}"]
                out_img_states.append(img_state)

            # Once the actor window is full, store the transition: the window
            # before the step and the window shifted by one observation.
            if t >= context_length-1 and st_states.shape[1] == context_length:
                states2RB = out_states[-context_length-1:-1]
                next_states2RB = out_states[-context_length:]
                act2RB = out_actions[-1]
                ret2RB = out_rewards[-1]
                done2RB = out_dones[-1]

                if use_images:
                    img_states2RB = out_img_states[-context_length-1:-1]
                    img_next_states2RB = out_img_states[-context_length:]
                    policy.new_trans_RB.recieve_traj(states2RB, act2RB, ret2RB, done2RB, next_states2RB, img_states2RB, img_next_states2RB)
                else:
                    policy.new_trans_RB.recieve_traj(states2RB, act2RB, ret2RB, done2RB, next_states2RB)

            # Extend the actor window with the newest observation.
            st_states = torch.cat([st_states, _obs_to_window(state)], dim=1)
            if use_images:
                img_states = torch.cat([img_states, _obs_to_window(img_state)], dim=1)

            if t >= args.start_timesteps:
                policy.stage_2_train(args.batch_size)

            if done:
                print(f"Total T: {t+1} Episode Num: {episode_num+1} Reward: {avg_ret:.3f}")
                avg_ret = 0

                out_actions = []
                out_rewards = []
                out_dones = []
                out_states, out_img_states, st_states, img_states = _begin_episode(env, args.obs_mode)

                episode_num += 1

            # Periodic evaluation once the warm-up phase is over.
            if ((t + 1) % args.eval_freq == 0) and (t >= args.start_timesteps):
                avg_reward = eval_transformer_2stage(policy, args, 3, context_length)
                experiment.add_scalar('Eval_reward', avg_reward, t)
    finally:
        # Close the TensorBoard writer even when training is interrupted.
        experiment.close()

if __name__ == "__main__":
    # Script entry point: load the YAML config, parse CLI options, start training.

    # NOTE(review): FullLoader can build more than plain data types; acceptable for a
    # trusted local config, but prefer yaml.safe_load if this file may come from elsewhere.
    with open("sh_config.yaml") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

    parser = argparse.ArgumentParser()
    parser.add_argument("--policy", default="TD3")
    parser.add_argument("--env", default="HalfCheetah-v4")
    # NOTE(review): no `type` is given, so a value passed on the command line arrives as a
    # plain string - confirm how env_constructor expects to receive the indices.
    parser.add_argument("--obs_indices", default=None) #REDUCE MDP ENV TO POMDP   #Cth [0,1,2,3,8,9,10,11,12] | Hppr [0,1,2,3,4] | Ant [0,1,2,3,4,5,6,7,8,9,10,11,12]
    parser.add_argument("--obs_mode", default="state")
    parser.add_argument("--num_envs", default=1, type=int)
    parser.add_argument("--seed", default=3, type=int)
    # NOTE(review): the two flags below have no `type`; any value given on the command
    # line (including "False") is a non-empty, truthy string. They are not read in this file.
    parser.add_argument("--trans_critic", default=False)
    parser.add_argument("--separate_trans_critic", default=False)
    parser.add_argument("--start_timesteps", default=25000, type=int)
    # argparse does not apply `type` to non-string defaults, so the default must already be an int.
    parser.add_argument("--eval_freq", default=2000, type=int)
    parser.add_argument("--max_timesteps", default=1300000, type=int)
    parser.add_argument("--grad_clip", default=1000000, type=int)
    parser.add_argument("--expl_noise", default=0.1, type=float)
    parser.add_argument("--batch_size", default=256, type=int)
    parser.add_argument("--discount", default=0.99, type=float)
    parser.add_argument("--tau", default=0.007, type=float)
    # Without type=float a value supplied on the command line would stay a string.
    parser.add_argument("--policy_noise", default=0.2, type=float)
    parser.add_argument("--noise_clip", default=0.2, type=float)
    parser.add_argument("--policy_freq", default=2, type=int)
    args = parser.parse_args()

    second_stage(config=config, args=args)
                        