import time
import datetime
import torch
import numpy as np 
import itertools
import gym
import json
from functools import reduce
from copy import deepcopy
from argparse import Namespace
from config import parser_main
from tensorboardX import SummaryWriter
from algo.utils.utils import log_loss, log_reward

# Build the run configuration. Later sources win on key collisions:
#   common JSON section < command-line args < "diayn" section < "sac" section.
# NOTE(review): the "diayn"/"sac" JSON sections override command-line values;
# confirm that this precedence is intended.
customed_args = vars(parser_main().parse_args())
with open("common.json") as f:
    json_config = json.load(f)
args_dict = {
    **json_config["common"],
    **customed_args,
    **json_config["diayn"],
    **json_config["sac"],
}
args = Namespace(**args_dict)

# Seed NumPy first, then torch, for reproducibility.
for seed_fn in (np.random.seed, torch.manual_seed):
    seed_fn(args.seed)

# TensorboardX: one log directory per run, named after the scenario and the
# wall-clock start time.
run_stamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
logdir = f"logs/url_{args.scenario}_{run_stamp}"
writer = SummaryWriter(logdir=logdir)

# Load environment and seed both the env and its action space so that
# random warm-up actions are reproducible.
env = gym.make(args.scenario)
env.seed(args.seed)
env.action_space.seed(args.seed)
obs_shape_list = env.observation_space.shape
# Flattened observation size (product of all dimensions). The initializer 1
# makes a scalar observation space (shape == ()) yield 1 instead of raising
# TypeError on an empty reduce; non-empty shapes are unaffected.
obs_shape = reduce((lambda x, y: x * y), obs_shape_list, 1)
action_space = env.action_space
# Action spaces exposing `n` (e.g. gym.spaces.Discrete) are treated as discrete.
discrete_action = hasattr(action_space, 'n')

# Initialize base algorithm state.
# The number of updates per step and the episode budget are taken from
# `args` (`args.updates_per_step`, `args.num_episodes`) by the training loop.
update_interval = 1  # run policy/discriminator updates every N environment steps
timesteps = 0  # total environment steps across all episodes
updates = 0  # total update iterations; used as the TensorBoard step for losses
rc = args.r_scale  # coefficient for native reward
src = args.sr_scale  # coefficient for pseudo reward
print("Reward scale for pseudo reward: {}".format(src))
print("Reward scale for external reward: {}".format(rc))
# Per-episode histories used for the periodic console averages.
running_r = []
running_sr = []
running_reward = []
running_ep_length = []
running_acc = []

# Select the pseudo-reward trainer. The imports are deferred so that only
# the chosen algorithm's module is loaded.
if args.sr_algo == "diayn":
    from algo.diayn2 import DiaynTrainer
    trainer = DiaynTrainer(obs_shape, action_space, rc, src, args)
elif args.sr_algo == "wasserstein":
    from algo.wasserstein import WassersteinTrainer
    trainer = WassersteinTrainer(obs_shape, action_space, rc, src, args)
else:
    print("Argument for sr_algo should be in (diayn / wasserstein)!")
    # A bad configuration is an error: exit with a non-zero status. Raising
    # SystemExit directly also avoids relying on the `site`-provided `exit`.
    raise SystemExit(1)

# Training loop: exactly `num_episodes * num_modes` episodes. Labels cycle
# deterministically (i_episode % num_modes), so each label is used
# `num_episodes` times.
for i_episode in range(1, args.num_episodes * args.num_modes + 1):
    obs = env.reset()
    episode_steps = 0
    label = i_episode % args.num_modes

    # Trainer cleanup before an episode starts: label setting.
    trainer.start_episode(label)
    for t in range(args.max_episode_len):
        # Sample actions: uniformly at random until the trainer's memory holds
        # more than `start_steps` transitions, then from the learned policy.
        if len(trainer.current_memory) > args.start_steps:
            action, logprob = trainer.act(obs)  # No need to use wrapped obs in indie settings
            if len(action.shape) > 1:
                action = action[0]
            if discrete_action:
                action = action[0]
        else:
            action = env.action_space.sample()
            # Placeholder log-probability recorded for random warm-up actions.
            logprob = np.array([1.0])

        # Train: once the trainer allows it, run `args.updates_per_step`
        # policy + discriminator updates every `update_interval` env steps.
        if trainer.can_update_policy() and timesteps % update_interval == 0:
            for _ in range(args.updates_per_step):
                c1_loss, c2_loss, p_loss, ent_loss, alpha = trainer.update_policy(updates)
                d_loss = trainer.update_disc()
                log_loss(writer, (c1_loss, c2_loss, p_loss, ent_loss, alpha, d_loss), updates)
                updates += 1

        # Execute. Plain Python ints (e.g. from Discrete.sample()) have no
        # `.tolist()`, so only convert objects that provide it.
        env_action = action.tolist() if hasattr(action, "tolist") else action
        new_obs, reward, done, _ = env.step(env_action)
        timesteps += 1
        episode_steps += 1
        # If the env exposes `max_steps` and this step reaches it, the mask is
        # forced to 1 regardless of `done`; otherwise mask = not done.
        if hasattr(env, 'max_steps') and episode_steps == env.max_steps:
            mask = 1.0
        else:
            mask = float(not done)

        # Record transitions
        trainer.record(obs, action, logprob, reward, new_obs, mask)
        obs = new_obs
        if done:
            break
    # Trainer works after an episode ends and reports episode statistics.
    episode_r, episode_sr, episode_reward, acc = trainer.end_episode()

    # TensorBoard + console logging. `episode_steps` is the episode length;
    # unlike `t + 1` it is defined even when max_episode_len == 0.
    log_reward(writer, (episode_r, episode_sr, episode_reward, episode_steps), i_episode)
    running_r.append(episode_r)
    running_sr.append(episode_sr)
    running_reward.append(episode_reward)
    running_ep_length.append(episode_steps)
    running_acc.append(acc)
    print("Episode {}: return {}".format(i_episode, episode_r))

    if i_episode % args.log_interval == 0:
        # At least `log_interval` entries exist here, so averaging over the
        # last L entries divides by the true count.
        L = args.log_interval
        T = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        avg_r = int(sum(running_r[-L:])/L)
        avg_sr = int(sum(running_sr[-L:])/L)
        avg_reward = int(sum(running_reward[-L:])/L)
        avg_length = int(sum(running_ep_length[-L:])/L)
        avg_acc = sum(running_acc[-L:])/L
        print("[{}] Current episode: {}, avg episode length: {}".format(T, i_episode, avg_length))
        print("[{}] Avg return: {}, avg r: {}, avg sr: {}, avg acc: {:.2f}".format(T, avg_reward, avg_r, avg_sr, avg_acc))

# Persist trained models, then release the environment and close the
# TensorBoard writer so any buffered events are flushed to disk.
trainer.save_models()
env.close()
writer.close()