import numpy as np
import torch.nn
import wandb
import os

from datetime import date
from envs.sc2_env_wrapper import StarCraft2Env
from src.agents.LPagent_hier_continuous import LPAgent


# NOTE(review): TRAIN is not read anywhere in this script.
TRAIN = True

# StarCraft II map the environment is built on.
map_name = "3s_vs_5z"

env = StarCraft2Env(
    map_name=map_name,
    window_size_x=400,
    window_size_y=300,
    enemy_obs=True,
)

# Observation size as reported by the environment (presumably per unit --
# confirm against StarCraft2Env.get_obs_size).
state_dim = env.get_obs_size()

# Number of *episodes* the training loop below iterates over.
# NOTE(review): the original intent was ~2 million environment timesteps
# (roughly 15000 episodes); no timestep budget is enforced, so this value
# caps episodes, not steps.
num_episodes = 2_000_000

# Keyword arguments forwarded verbatim to LPAgent; key names and values
# are dictated by LPAgent's constructor.
agent_config = dict(
    # Problem sizes taken from the environment.
    state_dim=state_dim,
    n_ag=env.n_agents,
    n_en=env.n_enemies,
    action_dim=5,
    # Replay memory / training schedule.
    memory_len=5000,
    batch_size=32,
    train_start=100,
    # Exploration.
    epsilon_start=1.0,
    epsilon_decay=1e-6,
    # Learning.
    gamma=0.99,
    hidden_dim=32,
    loss_ftn=torch.nn.MSELoss(),
    lr=5e-4,
    memory_type='ep',
    # Target-network update settings.
    target_tau=0.1,
    target_update_interval=200,
)

agent = LPAgent(**agent_config)
exp_name = date.today().strftime("%Y%m%d") + "_" + agent.name + map_name


def _make_unique_dir(base):
    """Create a new directory named after *base* and return its path.

    Tries *base* first, then ``base_1``, ``base_2``, ... until one can be
    created. The ``os.makedirs`` call itself serves as the existence check
    (it raises ``FileExistsError`` when the path is taken). This means two
    runs started at the same moment can never both claim the same name, which
    a separate ``os.path.exists`` test followed by ``os.makedirs`` could not
    guarantee. Missing parent directories are created as needed.
    """
    candidate = base
    suffix = 0
    while True:
        try:
            os.makedirs(candidate)
            return candidate
        except FileExistsError:
            # Name already taken (by a directory, file, or dangling link):
            # move on to the next numbered variant.
            suffix += 1
            candidate = "{}_{}".format(base, suffix)


dirName = 'result/{}'.format(exp_name)
curr_dir = _make_unique_dir(dirName)

exp_conf = {'directory': curr_dir}

try:
    for e in range(num_episodes):
        env.reset()

        terminated = False
        episode_reward = 0
        # Take a *copy* of the enemy death tracker. If the environment updates
        # the array in place, a bare reference would make the "previous" and
        # "next" snapshots the same object. Every kill delta below would then
        # be zero, and the target re-selection branch would never fire.
        prev_killed_enemies = np.copy(env.death_tracker_enemy)
        if env.n_enemies > 1:
            # Several enemies: the agent chooses the high-level action on the
            # first step, which is recorded as a high-level transition.
            high_action = None
            h_transition = True
        else:
            # Single enemy: every agent's high-level action is fixed to enemy
            # index 0 (high_action indexes the enemy death tracker below).
            high_action = torch.zeros(env.n_agents, dtype=torch.long)
            agent.high_action = high_action
            h_transition = False

        while not terminated:
            state = env.get_obs()
            agent_obs = state[:env.n_agents]
            enemy_obs = state[env.n_agents:]
            avail_actions = env.get_avail_actions()

            # Only a step on which a new high-level action is chosen
            # (high_action is None going in) counts as a high-level transition.
            if high_action is not None:
                h_transition = False

            action, high_action, low_action = agent.get_action(
                agent_obs, enemy_obs, avail_actions, high_action=high_action)

            # The environment's own reward is discarded; a shaped reward is
            # computed below instead.
            _, terminated, _ = env.step(action)

            next_killed_enemies = np.copy(env.death_tracker_enemy)

            next_state = env.get_obs()
            n_agent_obs = next_state[:env.n_agents]
            n_enemy_obs = next_state[env.n_agents:]

            # High-level reward: 20 once every enemy is marked dead, else 0.
            high_r = 20 if next_killed_enemies.sum() == env.n_enemies else 0
            # Low-level reward, one entry per agent: 10 times the change in the
            # death flag of that agent's assigned enemy, multiplied by
            # (1 - death_tracker_ally).
            # NOTE(review): this assumes the trackers hold 0/1 flags, so dead
            # agents are masked to 0 -- confirm against the env wrapper.
            reward = (next_killed_enemies[high_action] - prev_killed_enemies[high_action]) * 10 * (1 - env.death_tracker_ally)

            agent.push(agent_obs, enemy_obs, high_action, low_action, reward, n_agent_obs, n_enemy_obs, terminated,
                       avail_actions, high_r, h_transition)
            episode_reward += sum(reward)

            # An enemy died this step: force a fresh high-level choice next step.
            if prev_killed_enemies.sum() != next_killed_enemies.sum():
                high_action = None
                h_transition = True

            prev_killed_enemies = next_killed_enemies

        if e % 500 == 0:
            agent.save(curr_dir, e)

        if agent.can_fit():
            agent.fit(e)

        print("EP:{}, R:{}".format(e, episode_reward))
finally:
    # Shut the environment down even if training raises or is interrupted,
    # so the game process is not left running.
    env.close()
