import matplotlib
import matplotlib.pyplot as plt
from itertools import count
import torch.optim as optim
import torch
import math
import numpy as np
from environments.gridworld import GridworldEnv
from utils import plot_rewards, plot_durations, plot_state, get_screen

from .memory_replay import ReplayMemory, Transition
from .network import DQN, select_action, optimize_model, optimize_policy, PolicyNetwork

def _obs_to_tensor(observation, input_size):
    """Convert a NumPy observation into a float tensor viewed as (-1, input_size)."""
    return torch.from_numpy(observation).type(torch.FloatTensor).view(-1, input_size)


def _stack_logs(per_env_logs):
    """Turn a list of per-environment logs into an array that ``np.save`` accepts.

    Environments generally finish different numbers of episodes, so the inner
    lists are ragged. NumPy >= 1.24 refuses to build an array from ragged
    nested sequences implicitly, so in that case a 1-D object array (one list
    per environment) is built explicitly; reading it back then requires
    ``np.load(..., allow_pickle=True)``. Equal-length logs keep the plain
    2-D numeric layout.
    """
    lengths = {len(log) for log in per_env_logs}
    if len(lengths) <= 1:
        return np.asarray(per_env_logs)
    stacked = np.empty(len(per_env_logs), dtype=object)
    for i, log in enumerate(per_env_logs):
        stacked[i] = list(log)
    return stacked


def trainD(file_name="Distral_2col", list_of_envs=None, batch_size=128, gamma=0.999,
            c_ent=0.1, c_kl=0.1,
            eps_start=0.9, eps_end=0.05, eps_decay=5,
            is_plot=False, num_episodes=2000,
            max_num_steps_per_episode=1000, learning_rate=1e-3,
            memory_replay_size=1000000, memory_policy_size=1000):
    """Train one DQN per environment plus a single shared PolicyNetwork.

    Each outer iteration alternates between:
      1. one environment step per env (action from ``select_action``);
      2. one ``optimize_model`` step per env, once its replay memory holds at
         least ``batch_size`` transitions;
      3. one ``optimize_policy`` step for the shared policy.
    Training stops once every environment has finished ``num_episodes``
    episodes. Per-env reward and duration logs are saved with ``np.save`` to
    ``<file_name>-distral-2col-rewards.npy`` and
    ``<file_name>-distral-2col-durations.npy``.

    Args:
        file_name: Prefix for the saved ``.npy`` log files.
        list_of_envs: Environments to train on. They must expose
            ``action_space.n``, ``observation_space.shape``, ``reset()``,
            ``step()``, ``close()`` and ``episode_total_reward``. All envs
            are assumed to share the action/observation sizes of the first.
            Defaults to fresh ``[GridworldEnv(5), GridworldEnv(4),
            GridworldEnv(6)]`` built on each call. All envs are closed when
            training ends.
        batch_size: Minimum replay size before optimizing, and the batch size
            passed to the optimizers.
        gamma: Discount factor passed to the optimizers.
        c_ent, c_kl: Entropy / KL coefficients, either scalars shared by all
            envs or lists/tuples with one value per env. They are combined
            into ``alpha = c_ent / (c_ent + c_kl)`` and
            ``beta = 1 / (c_ent + c_kl)``.
        eps_start, eps_end, eps_decay: Exploration parameters forwarded to
            ``select_action``. The logged factor is
            ``eps_end + (eps_start - eps_end) * exp(-episodes_done / eps_decay)``.
        is_plot: If true, call ``plot_rewards`` after every finished episode.
        num_episodes: Episodes each environment must complete.
        max_num_steps_per_episode: Currently NOT enforced (see TODO below);
            kept for interface compatibility.
        learning_rate: Adam learning rate for all networks.
        memory_replay_size, memory_policy_size: Forwarded to ``ReplayMemory``.

    Returns:
        ``(models, policy, episode_rewards, episode_durations)``. Each log is a
        list with one list per environment.

    Raises:
        ValueError: If ``list_of_envs`` is empty, or ``c_ent`` / ``c_kl``
            provide fewer values than there are environments.
    """
    if list_of_envs is None:
        # Built per call. A list literal in the signature would be created once
        # at import time and shared across calls, including envs closed at the
        # end of a previous call.
        list_of_envs = [GridworldEnv(5), GridworldEnv(4), GridworldEnv(6)]
    if len(list_of_envs) == 0:
        raise ValueError("list_of_envs must contain at least one environment")
    num_envs = len(list_of_envs)

    # Scalars are broadcast to every environment; sequences give per-env values.
    if not isinstance(c_ent, (list, tuple)):
        c_ent = [c_ent] * num_envs
    if not isinstance(c_kl, (list, tuple)):
        c_kl = [c_kl] * num_envs
    alpha = [a / (a + b) for a, b in zip(c_ent, c_kl)]
    beta = [1.0 / (a + b) for a, b in zip(c_ent, c_kl)]
    if len(alpha) < num_envs:
        raise ValueError("c_ent and c_kl must provide one value per environment")

    num_actions = list_of_envs[0].action_space.n
    input_size = list_of_envs[0].observation_space.shape[0]
    policy = PolicyNetwork(input_size, num_actions)
    models = [DQN(input_size, num_actions) for _ in range(num_envs)]
    memories = [ReplayMemory(memory_replay_size, memory_policy_size)
                for _ in range(num_envs)]

    # NOTE(review): networks move to the GPU here while states are built on the
    # CPU below; presumably select_action/optimize_* handle device placement — confirm.
    if torch.cuda.is_available():
        policy.cuda()
        for model in models:
            model.cuda()

    optimizers = [optim.Adam(model.parameters(), lr=learning_rate)
                  for model in models]
    policy_optimizer = optim.Adam(policy.parameters(), lr=learning_rate)

    episode_durations = [[] for _ in range(num_envs)]
    episode_rewards = [[] for _ in range(num_envs)]
    episodes_done = np.zeros(num_envs, dtype=int)
    current_time = np.zeros(num_envs, dtype=int)  # steps taken in the current episode

    states = [_obs_to_tensor(env.reset(), input_size) for env in list_of_envs]

    while np.min(episodes_done) < num_episodes:
        # TODO: enforce max_num_steps_per_episode.
        for i_env, env in enumerate(list_of_envs):
            action = select_action(states[i_env], policy, models[i_env], num_actions,
                                   eps_start, eps_end, eps_decay,
                                   episodes_done[i_env], alpha[i_env], beta[i_env])

            current_time[i_env] += 1
            next_obs, reward, done, _ = env.step(action[0, 0])
            reward = torch.FloatTensor([reward])

            # Terminal transitions are stored with next_state=None.
            next_state = None if done else _obs_to_tensor(next_obs, input_size)

            timestep = torch.FloatTensor([current_time[i_env]])
            memories[i_env].push(states[i_env], action, next_state, reward, timestep)

            if len(memories[i_env]) >= batch_size:
                optimize_model(policy, models[i_env], optimizers[i_env],
                               memories[i_env], batch_size, alpha[i_env], beta[i_env], gamma)

            states[i_env] = next_state

            if done:
                # Read before reset() so the logged value matches the printed
                # one (reset() may clear the running total).
                total_reward = env.episode_total_reward
                print("ENV: {}, iter: {}\treward: {:.4f}\tit:{}\texp_factor:{:.6f}".format(
                    i_env, episodes_done[i_env], total_reward, current_time[i_env],
                    eps_end + (eps_start - eps_end) * math.exp(-1. * episodes_done[i_env] / eps_decay)
                ))
                states[i_env] = _obs_to_tensor(env.reset(), input_size)
                episodes_done[i_env] += 1
                episode_durations[i_env].append(current_time[i_env])
                current_time[i_env] = 0
                episode_rewards[i_env].append(total_reward)
                if is_plot:
                    plot_rewards(episode_rewards, i_env)

        optimize_policy(policy, policy_optimizer, memories, batch_size,
                        num_envs, gamma)

    print('Complete')
    for env in list_of_envs:
        env.close()

    # Logs are ragged across envs; see _stack_logs for the saved layout.
    np.save(file_name + '-distral-2col-rewards', _stack_logs(episode_rewards))
    np.save(file_name + '-distral-2col-durations', _stack_logs(episode_durations))

    return models, policy, episode_rewards, episode_durations
