import time
import datetime
import torch
import numpy as np 
import itertools
import gym
import json
import argparse
from functools import reduce
from copy import deepcopy
from argparse import Namespace
from config import parser_main
from tensorboardX import SummaryWriter
from algo.utils.utils import log_loss, log_reward

def _str2bool(value):
    """Convert a command-line string into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is
    ``True`` because every non-empty string is truthy, so a flag declared
    that way can never be switched off with a readable value.

    Args:
        value: The raw option value (a string), or an actual ``bool``.

    Returns:
        ``True`` or ``False``. The empty string maps to ``False``, which is
        the one spelling that already produced ``False`` under ``type=bool``.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ("true", "t", "yes", "y", "1", "on"):
        return True
    if text in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value (true/false), got {!r}".format(value))


# Command-line options of the training script. Only the parser is built
# here; parsing happens right after this block.
parser = argparse.ArgumentParser(description='PyTorch Soft Actor-Critic Args')
parser.add_argument('--scenario', default="HalfCheetah-v2",
                    help='Mujoco Gym environment (default: HalfCheetah-v2)')
parser.add_argument('--policy', default="Gaussian",
                    help='Policy Type: Gaussian | Deterministic (default: Gaussian)')
parser.add_argument('--eval', type=_str2bool, default=True,
                    help='Evaluates a policy every 10 episode (default: True)')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G',
                    help='discount factor for reward (default: 0.99)')
parser.add_argument('--tau', type=float, default=0.005, metavar='G',
                    help='target smoothing coefficient(τ) (default: 0.005)')
parser.add_argument('--critic_lr', type=float, default=0.0003, metavar='G',
                    help='learning rate (default: 0.0003)')
parser.add_argument('--policy_lr', type=float, default=0.0003, metavar='G',
                    help='learning rate (default: 0.0003)')
parser.add_argument('--alpha', type=float, default=0.2, metavar='G',
                    help='Temperature parameter α determines the relative importance of the entropy\
                            term against the reward (default: 0.2)')
parser.add_argument('--automatic_entropy_tuning', type=_str2bool, default=False, metavar='G',
                    help='Automatically adjust α (default: False)')
parser.add_argument('--seed', type=int, default=123456, metavar='N',
                    help='random seed (default: 123456)')
parser.add_argument('--batch_size', type=int, default=256, metavar='N',
                    help='batch size (default: 256)')
# The help text now quotes the real default (it used to say 1000000).
parser.add_argument('--num_steps', type=int, default=1000001, metavar='N',
                    help='maximum number of steps (default: 1000001)')
parser.add_argument('--hidden_dim', type=int, default=256, metavar='N',
                    help='hidden size (default: 256)')
parser.add_argument('--updates_per_step', type=int, default=1, metavar='N',
                    help='model updates per simulator step (default: 1)')
parser.add_argument('--start_steps', type=int, default=10000, metavar='N',
                    help='Steps sampling random actions (default: 10000)')
parser.add_argument('--target_update_interval', type=int, default=1, metavar='N',
                    help='Value target update per no. of updates per step (default: 1)')
# The help text now quotes the real default (it used to say 10000000).
parser.add_argument('--replay_size', type=int, default=1000000, metavar='N',
                    help='size of replay buffer (default: 1000000)')
parser.add_argument('--cuda', action="store_true",
                    help='run on CUDA (default: False)')
# Command-line values as a plain dict, so they can be layered with the JSON
# configuration below.
customed_args = vars(parser.parse_args())

# Shared configuration file, looked up relative to the working directory.
with open("common.json") as config_file:
    json_config = json.load(config_file)

# Layer the three sources; on a key clash the later source wins:
#   "common" section  <  command line  <  "diayn" section
# NOTE(review): this means the "diayn" section overrides command-line
# values -- confirm that precedence is intended.
# (A "sac" section used to be layered here as an alternative.)
args_dict = {
    **json_config["common"],
    **customed_args,
    **json_config["diayn"],
}
args = Namespace(**args_dict)

# Seed the NumPy and PyTorch random generators from the merged settings.
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# TensorboardX: every run writes to its own directory, named after the
# scenario and the wall-clock time at start-up.
run_stamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
logdir = 'logs/url_{}_{}'.format(args.scenario, run_stamp)
writer = SummaryWriter(logdir=logdir)

# Build the Gym environment named by the scenario option.
env = gym.make(args.scenario)
# NOTE(review): the environment is seeded with None, not args.seed, although
# NumPy and PyTorch are seeded from args.seed -- confirm this is intended.
env.seed(None)

# Spaces exposed by the environment.
action_space = env.action_space
obs_shape_list = env.observation_space.shape

# Flattened observation size: the product of all observation dimensions.
obs_shape = reduce(lambda total, dim: total * dim, obs_shape_list)

# An `n` attribute on the action space is used as the marker for
# discrete actions.
discrete_action = hasattr(action_space, 'n')

# Scheduling constants. The training loop reads `update_interval`; it takes
# the number of updates per step and the episode budget from `args`, so
# `updates_per_step` and `max_num_episodes` are not read by that loop.
update_interval, updates_per_step = 1, 1
max_num_episodes = 1000

# Global counters: environment steps taken and policy updates performed.
timesteps = updates = 0

# Reward mixing coefficients handed to the trainer.
rc, src = 1.0, 0.0  # native (external) reward, pseudo reward
print("Reward scale for pseudo reward: {}".format(src))
print("Reward scale for external reward: {}".format(rc))

# Per-episode histories used for the periodic console averages.
running_r, running_sr, running_reward = [], [], []
running_ep_length, running_acc = [], []

# Build the trainer.
from algo.dummy import DiaynTrainer
trainer = DiaynTrainer(obs_shape, action_space, rc, src, args)

# Training loop: one iteration per episode, for exactly args.num_episodes
# episodes. The environment and the TensorBoard writer are released in
# `finally`, so they are also closed when the loop is interrupted by an error.
try:
    for i_episode in itertools.count(1):
        obs = env.reset()
        episode_steps = 0

        # While fewer than args.start_steps steps have been taken the label
        # is drawn at random; afterwards labels are cycled by episode index.
        if args.start_steps > timesteps:
            label = np.random.randint(0, high=args.num_modes)
        else:
            label = i_episode % args.num_modes

        # Trainer cleanup before an episode starts: label setting
        trainer.start_episode(label)
        for _step in range(args.max_episode_len):
            # Sample actions: from the trainer once more than
            # args.start_steps steps have been taken, otherwise uniformly
            # from the action space.
            if args.start_steps < timesteps:
                action, logprob = trainer.act(obs) # No need to use wrapped obs in indie settings
                if len(action.shape) > 1:
                    action = action[0]
                if discrete_action:
                    action = action[0]
            else:
                action = env.action_space.sample()
                logprob = np.array([1.0])

            # Train
            if trainer.can_update_policy() and timesteps % update_interval == 0:
                for _ in range(args.updates_per_step):
                    c1_loss, c2_loss, p_loss, ent_loss, alpha = trainer.update_policy(updates)
                    d_loss = trainer.update_disc()
                    log_loss(writer, (c1_loss, c2_loss, p_loss, ent_loss, alpha, d_loss), updates)
                    updates += 1

            # Execute. A sampled discrete action can be a plain Python int,
            # which has no `.tolist()`; such values are passed through as is.
            env_action = action.tolist() if hasattr(action, 'tolist') else action
            new_obs, reward, done, _ = env.step(env_action)
            timesteps += 1
            episode_steps += 1

            # mask is 0 only when the episode ended with `done`; reaching
            # env.max_steps (when the env exposes it) keeps the mask at 1.
            if hasattr(env, 'max_steps'):
                mask = 1 if episode_steps == env.max_steps else float(not done)
            else:
                mask = float(not done)

            # Record transitions
            trainer.record(obs, action, logprob, reward, new_obs, mask)
            obs = new_obs
            if done:
                break

        # Trainer works after an episode ends
        episode_r, episode_sr, episode_reward, acc = trainer.end_episode()

        # Console log. `episode_steps` is the number of steps actually taken
        # and stays well defined even when the step loop runs zero times.
        log_reward(writer, (episode_r, episode_sr, episode_reward, episode_steps), i_episode)
        running_r.append(episode_r)
        running_sr.append(episode_sr)
        running_reward.append(episode_reward)
        running_ep_length.append(episode_steps)
        running_acc.append(acc)
        print("Episode {}: return {}".format(i_episode, episode_r))
        if i_episode % args.log_interval == 0:
            # Averages over the last `L` episodes.
            L = args.log_interval
            T = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
            avg_r = int(sum(running_r[-L:])/L)
            avg_sr = int(sum(running_sr[-L:])/L)
            avg_reward = int(sum(running_reward[-L:])/L)
            avg_length = int(sum(running_ep_length[-L:])/L)
            avg_acc = sum(running_acc[-L:])/L
            print("[{}] Current episode: {}, avg episode length: {}".format(T, i_episode, avg_length))
            print("[{}] Avg return: {}, avg r: {}, avg sr: {}, avg acc: {:.2f}".format(T, avg_reward, avg_r, avg_sr, avg_acc))

        # Stop after exactly args.num_episodes episodes. (`>` ran one
        # episode more than requested.)
        if i_episode >= args.num_episodes:
            break

    # Models are saved only when training ran to completion.
    trainer.save_models()
finally:
    # Always release the environment and flush/close the event writer.
    try:
        env.close()
    finally:
        writer.close()