import os
import time
import torch
import SMOS

import numpy as np
import core.ctree.cytree as cytree

from tqdm.auto import tqdm
from torch.cuda.amp import autocast as autocast
from core.mcts import MCTS
from core.game import GameHistory
from core.utils import select_action, prepare_observation_lst

from core.storage_config import StorageConfig
from core.shared_storage import get_shared_storage, read_weights
from core.meta_data_manager import MetaDataSharedMemoryManager


def _test(config, shared_storage, final_test=False, seeds=None, test_label=None, use_pb=False, test_episodes=None, smos_client=None, storage_config=None):
    """Evaluation loop run by the test worker alongside training.

    Every 30 s the training step counter is polled from ``shared_storage``.
    When the counter reaches ``config.test_interval * n`` (n = number of
    evaluations already run), or in ``final_test`` mode while the counter is
    still 0, the latest weights are fetched with ``read_weights`` and evaluated
    with ``test``. Weights whose mean score is >= the best seen so far are saved
    to ``config.model_path``; every result is pushed to the shared storage test
    log under ``test_label``.

    The loop exits (after a final 30 s sleep) once the counter reaches
    ``config.training_steps + config.last_steps``, or, in ``final_test`` mode,
    as soon as the counter is positive.

    NOTE(review): in ``final_test`` mode the evaluation is repeated on every poll
    for as long as the counter stays at 0 — confirm this is intended.
    """
    test_model = config.get_uniform_network(is_data_worker=True)
    meta_data_manager = MetaDataSharedMemoryManager(config, storage_config)
    best_test_score = float('-inf')
    episodes = 0  # number of evaluation rounds run so far (not env episodes)
    while True:
        counter = shared_storage.get_counter()
        if (counter >= config.training_steps + config.last_steps) or (final_test and counter > 0):
            time.sleep(30)
            break
        if (counter >= config.test_interval * episodes) or (final_test and counter == 0):
            if episodes == 0:
                # NOTE(review): presumably gives the trainer time to publish the
                # initial weights before the first read — confirm.
                time.sleep(10)
            episodes += 1
            test_model.set_weights(read_weights(meta_data_manager, smos_client, shared_storage, storage_config)[0])
            test_model.eval()

            # Announce the evaluation before running it so the log reflects
            # when it actually started.
            print('Start evaluation at step {}.'.format(counter))
            test_score, _ = test(config, test_model, counter, test_episodes or config.test_episodes, config.device, False,
                                 final_test=final_test, use_pb=use_pb, save_video=False, seeds=seeds)
            mean_score = test_score.mean()
            std_score = test_score.std()
            if mean_score >= best_test_score:
                best_test_score = mean_score
                torch.save(test_model.state_dict(), config.model_path)

            test_log = {
                'mean_score': mean_score,
                'std_score': std_score,
                'max_score': test_score.max(),
                'min_score': test_score.min(),
                'all_score': test_score,
            }

            shared_storage.add_test_log(counter, test_log, test_label)
            print('Step {}, test scores: \n{}'.format(counter, test_score))

        time.sleep(30)


def start_test(config, storage_config: StorageConfig, final_test=False, seeds=None, test_label=None, use_pb=False, test_episodes=None, test_visible_device=0):
    """Entry point of the evaluation worker used during training; call it remotely.

    Connects to the shared storage and to the SMOS server described by
    ``storage_config``, pins evaluation to ``cuda:<test_visible_device>`` by
    setting ``config.device``, then blocks in ``_test`` until it returns.
    """
    storage = get_shared_storage(storage_config=storage_config)
    client = SMOS.Client(connection=storage_config.smos_connection)

    # The GPU is chosen through config.device; CUDA_VISIBLE_DEVICES is left untouched.
    gpu_index = test_visible_device
    config.device = torch.device(f"cuda:{gpu_index}")
    print(f"[Test] Start test function at process {os.getpid()} on gpu {test_visible_device}")

    _test(
        config=config,
        shared_storage=storage,
        final_test=final_test,
        seeds=seeds,
        test_label=test_label,
        use_pb=use_pb,
        test_episodes=test_episodes,
        smos_client=client,
        storage_config=storage_config,
    )


def test(config, model, counter, test_episodes, device, render, save_video=False, final_test=False, use_pb=False, seeds=None):
    """evaluation test

    Runs ``test_episodes`` games in lock-step. Each iteration performs one
    batched ``initial_inference`` plus GPU MCTS over all envs, then takes the
    argmax (deterministic) action in every env that is not yet done.

    Parameters
    ----------
    config: any
        experiment config (env factory, network / MCTS settings, paths)
    model: any
        models for evaluation
    counter: int
        current training step counter (used in the recordings path and progress logs)
    test_episodes: int
        number of test episodes (one env per episode, all run together)
    device: str
        'cuda' or 'cpu'
    render: bool
        True -> render the image during evaluation
    save_video: bool
        True -> save the videos during evaluation
    final_test: bool
        True -> this test is the final test, and the max moves would be 108k/skip
    use_pb: bool
        True -> print a progress line (mean/max/min score, alive envs) every 10 steps
    seeds: sequence of int or None
        one seed per episode; defaults to 1..test_episodes

    Returns
    -------
    (np.ndarray, str)
        per-episode sum of the raw (unclipped) rewards, and the recordings path

    Raises
    ------
    ValueError
        if ``len(seeds) != test_episodes``
    """
    model.to(device)
    model.eval()
    save_path = os.path.join(config.exp_path, 'recordings', 'step_{}'.format(counter))

    if seeds is None:
        seeds = np.arange(1, test_episodes + 1)
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if len(seeds) != test_episodes:
        raise ValueError('expected {} seeds (one per test episode), got {}: {}'.format(
            test_episodes, len(seeds), seeds))

    seeds = [int(x) for x in seeds]

    envs = []
    try:
        with torch.no_grad():
            # new games; built one by one so that envs created before a failing
            # constructor are still closed in `finally`
            for i in range(test_episodes):
                envs.append(config.new_game(seed=seeds[i], save_video=save_video, save_path=save_path, test=True,
                                            final_test=final_test, video_callable=lambda episode_id: True, uid=i))
            # initializations: fill each history with copies of its first observation
            init_obses = [env.reset() for env in envs]
            dones = np.zeros(test_episodes, dtype=bool)
            game_histories = [GameHistory(max_length=config.max_moves, config=config) for _ in range(test_episodes)]
            for game_history, init_obs in zip(game_histories, init_obses):
                game_history.init([init_obs for _ in range(config.stacked_observations)])

            step = 0
            ep_ori_rewards = np.zeros(test_episodes)
            # loop until every env has reported done
            while not dones.all():
                if render:
                    for env in envs:
                        env.render()

                if config.image_based:
                    stack_obs = [game_history.step_obs() for game_history in game_histories]
                    stack_obs = prepare_observation_lst(stack_obs)
                    stack_obs = torch.from_numpy(stack_obs).to(device).float() / 255.0
                else:
                    stack_obs = [game_history.step_obs() for game_history in game_histories]
                    stack_obs = torch.from_numpy(np.array(stack_obs)).to(device)

                with autocast():
                    network_output = model.initial_inference(stack_obs.float(), to_numpy=False)
                hidden_state_roots = network_output.hidden_state
                reward_hidden_roots = network_output.reward_hidden
                value_prefix_pool = network_output.value_prefix
                # numerically stable softmax over the action dimension (last axis)
                policy_logits_pool = network_output.policy_logits
                policy_logits_pool = policy_logits_pool - policy_logits_pool.max(axis=-1).reshape(-1, 1)
                policy_logits_pool = np.exp(policy_logits_pool)
                policy_logits_pool = policy_logits_pool / policy_logits_pool.sum(axis=-1).reshape(-1, 1)
                policy_logits_pool = policy_logits_pool.tolist()

                roots = cytree.Roots(test_episodes, config.action_space_size, config.num_simulations)
                roots.prepare_no_noise(value_prefix_pool, policy_logits_pool)
                # do MCTS for a policy (argmax in testing)
                MCTS(config, config.num_simulations).gpu_search(roots, model, hidden_state_roots, reward_hidden_roots, to_numpy=False)

                roots_distributions = roots.get_distributions()
                roots_values = roots.get_values()
                if step % 10 == 0:
                    print(f"[test] step {step}")
                for i in range(test_episodes):
                    # MCTS runs for every env, but finished envs are not stepped
                    if dones[i]:
                        continue

                    distributions, value, env = roots_distributions[i], roots_values[i], envs[i]
                    # select the argmax, not sampling
                    action, _ = select_action(distributions, temperature=1, deterministic=True)

                    obs, ori_reward, done, info = env.step(action)
                    # the history stores the (optionally) sign-clipped reward;
                    # the returned score uses the raw reward
                    clip_reward = np.sign(ori_reward) if config.clip_reward else ori_reward

                    game_histories[i].store_search_stats(distributions, value)
                    game_histories[i].append(action, obs, clip_reward)

                    dones[i] = done
                    ep_ori_rewards[i] += ori_reward

                step += 1
                if use_pb and step % 10 == 0:
                    print('{} In step {}, scores: {}(max: {}, min: {}) currently, alive: {}/{}.'
                          ''.format(config.env_name, counter,
                                    ep_ori_rewards.mean(), ep_ori_rewards.max(), ep_ori_rewards.min(),
                                    test_episodes - dones.astype(np.int32).sum(), test_episodes), flush=True)
    finally:
        # always release envs (and any video recorders), even on error
        for env in envs:
            env.close()

    return ep_ori_rewards, save_path
