import argparse
import random
import os
import warnings
from typing import Dict, Callable, Iterable
from functools import partial

import yaml
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

from environment_generator import EnvironmentDataset
from train_models import seed_everything
from model.reward_model import model_to_reward_function, load_all_saved_models
from model.head import head_state_to_tensor
from policy.policy import RegretMatchingPolicy, EpsGreedyPolicy
from policy.policy_iterator_callbacks import ReturnGraphCallback
from policy.policy_iteration import PolicyIteration, K_Of_N_PolicyIteration
from config import Config

import sys
# pytorch module loading workaround (pickle limitation)
sys.path.append('../model')

def _load_reward_functions(models_dir, env):
    """Load every saved model in ``models_dir`` (sorted by filename) as a reward function.

    NOTE(review): ``torch.load`` unpickles arbitrary objects; only point
    ``models_dir`` at trusted, locally produced checkpoints.
    """
    return [
        model_to_reward_function(torch.load(os.path.join(models_dir, name)).eval(), env)
        for name in sorted(os.listdir(models_dir))
    ]


def _make_averaged_reward_function(reward_functions):
    """Return a memoized reward function equal to the mean of ``reward_functions``.

    Each ``(state_index, action)`` pair is evaluated against every model only once;
    later queries for the same pair are served from a cache.
    """
    if not reward_functions:
        raise ValueError('cannot average an empty list of reward functions (no models found)')
    memo = {}
    count = len(reward_functions)

    def reward_function(state_index, action):
        key = (state_index, action)
        if key not in memo:
            memo[key] = sum(fcn(state_index, action) for fcn in reward_functions) / count
        return memo[key]

    return reward_function


def _run_policy_iteration(policy_iterator, num_iterations):
    """Call ``policy_iterator`` ``num_iterations`` times, then call its ``cleanup``."""
    for _ in tqdm(range(num_iterations)):
        policy_iterator()
    policy_iterator.cleanup()


def _save_sampled_policies(config, env, prob_trans_mat, callbacks):
    """Run greedy PolicyIteration once per saved reward model and save each policy.

    Policies are written to ``<config.policies_dir>/sampled/<i>.npy``. Indices that
    already have a file are skipped (e.g. when resuming an interrupted run).
    """
    print('creating sampled policies...')
    sampled_dir = os.path.join(config.policies_dir, 'sampled')
    os.makedirs(sampled_dir, exist_ok=True)
    reward_functions = _load_reward_functions(config.models_dir, env)
    for i, reward_function in tqdm(enumerate(reward_functions), total=len(reward_functions)):
        out_path = os.path.join(sampled_dir, f'{i}.npy')
        if os.path.exists(out_path):
            print('skipping', i)
            continue
        policy_iterator = PolicyIteration(
            EpsGreedyPolicy(
                prob_trans_mat.shape[0],
                prob_trans_mat.shape[1],
                epsilon=0,
            ),
            prob_trans_mat,
            reward_function,
            after_iteration_callbacks=callbacks,
        )
        _run_policy_iteration(policy_iterator, config.num_iterations)
        np.save(out_path, policy_iterator.policy.pi)


def obtain_policy_and_save(
        k: int,
        n: int,
        seed: int = None,
        test: bool = False,
        use_head_models: bool = False,
        use_true_reward: bool = False,
        sampled: bool = False,
        averaged: bool = False,
):
    """
    Obtain a policy for the environment and save it (plus its return plots) to disk.

    Exactly one of three modes runs:

    * k-of-N (default): K_Of_N_PolicyIteration with a RegretMatchingPolicy. The policy
      is saved as ``<policies_dir>/<policy_type>_<seed>.npy``.
    * ``sampled``: for every saved reward model, a full run of PolicyIteration with a
      greedy (epsilon=0) policy against that single model. One policy per model is
      saved under ``<config.policies_dir>/sampled/``.
    * ``averaged``: PolicyIteration with a greedy policy against the mean of all saved
      reward models. The policy is saved as ``<policies_dir>/averaged.npy``.

    Return plots are saved for the k-of-N and averaged modes.

    :param k: The k in k-of-N. It is passed to K_Of_N_PolicyIteration and used in output names.
    :param n: The N in k-of-N.
    :param seed: What seed to use. It only affects the ordering of sampled models when
        config has `resample_models` set to True. It is also used in k-of-N output
        filenames and in plot titles.
    :param test: Obtain the policy for the test environment instead of the train environment.
    :param use_head_models: DEPRECATED. Load models from ``config.heads_dir``, and write
        outputs to ``config.head_policies_dir`` / ``config.head_figs_dir``.
    :param use_true_reward: Use the environment's true reward function instead of the
        reward function approximators.
    :param sampled: Run the "sampled" mode described above.
    :param averaged: Run the "averaged" mode described above.
    :raises ValueError: if both ``sampled`` and ``averaged`` are set, or if ``averaged``
        is set and no saved models are found.
    """
    if sampled and averaged:
        raise ValueError('`sampled` and `averaged` are mutually exclusive')
    if seed is not None:
        seed_everything(seed)
    config = Config()
    env = (EnvironmentDataset.obtain_test_env()
           if test else EnvironmentDataset.obtain_train_env())
    prob_trans_mat = np.array(env.prob_trans_mat)
    fig_return = plt.figure()
    ax_return = fig_return.add_subplot(1, 1, 1)
    fig_return_kofn = plt.figure()
    ax_return_kofn = fig_return_kofn.add_subplot(1, 1, 1)

    callbacks = [
        ReturnGraphCallback(
            f"Expected Return Over Time On True Reward Function (Seed {seed})",
            "Iteration",
            "Return",
            ax_return,
            env.true_reward_callable,
        ),
        ReturnGraphCallback(
            f"Expected Return Over Time On k-of-N Reward Function (Seed {seed})",
            "Iteration",
            "Return",
            ax_return_kofn,
        ),
    ]
    try:
        if sampled:
            # Sampled mode saves one policy per model and no figures.
            _save_sampled_policies(config, env, prob_trans_mat, callbacks)
            return

        if averaged:
            reward_function = _make_averaged_reward_function(
                _load_reward_functions(config.models_dir, env))
            print('creating averaged policy...')
            policy_iterator = PolicyIteration(
                EpsGreedyPolicy(
                    prob_trans_mat.shape[0],
                    prob_trans_mat.shape[1],
                    epsilon=0,
                ),
                prob_trans_mat,
                reward_function,
                after_iteration_callbacks=callbacks,
            )
        else:
            policy_iterator = K_Of_N_PolicyIteration(
                k,
                n,
                RegretMatchingPolicy(
                    prob_trans_mat.shape[0],
                    prob_trans_mat.shape[1]
                ),
                prob_trans_mat,
                config.heads_dir if use_head_models else config.models_dir,
                # With the true reward, every loaded model maps to the same true reward.
                ((lambda _model: env.true_reward_callable) if use_true_reward
                 else partial(model_to_reward_function, env=env)),
                gamma=config.gamma,
                theta=config.theta,
                after_iteration_callbacks=callbacks,
                resample_models=config.resample_models
            )
            print(f'creating {k}-of-{n} policy')

        _run_policy_iteration(policy_iterator, config.num_iterations)

        # Use the parameter, not a global `args`, so this works when imported.
        # Head-model runs go to separate dirs so they don't overwrite the regular outputs.
        policies_dir = config.head_policies_dir if use_head_models else config.policies_dir
        figs_dir = config.head_figs_dir if use_head_models else config.model_figs_dir
        policy_type = ('true_reward' if use_true_reward else f'{k}_of_{n}') \
            + ("" if test else "_train")
        if averaged:
            np.save(f'{policies_dir}/averaged.npy', policy_iterator.policy.pi)
        else:
            np.save(f'{policies_dir}/{policy_type}_{seed}.npy', policy_iterator.policy.pi)
        # NOTE(review): averaged runs reuse the k-of-N `policy_type` in figure names,
        # so they overwrite k-of-N figures with the same seed.
        fig_return.savefig(f'{figs_dir}/policy/{policy_type}_returns_true_seed_{seed}.png')
        fig_return_kofn.savefig(f'{figs_dir}/policy/{policy_type}_returns_seed_{seed}.png')
    finally:
        # Close the figures; otherwise repeated calls (one per k in __main__) leak them.
        plt.close(fig_return)
        plt.close(fig_return_kofn)


if __name__ == "__main__":
    def _str2bool(value):
        """Parse a command-line boolean strictly.

        ``type=bool`` cannot be used because ``bool('False')`` is True, so any
        explicit value would enable the flag.
        """
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
            return False
        raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')

    parser = argparse.ArgumentParser()
    # Boolean flags accept a bare form (`--test`) and an explicit value
    # (`--test True` / `--test False`), which keeps old invocations working.
    parser.add_argument('--head', type=_str2bool, nargs='?', const=True, default=False)
    parser.add_argument('--test', type=_str2bool, nargs='?', const=True, default=False)
    parser.add_argument('--use-true-reward', type=_str2bool, nargs='?', const=True, default=False)
    parser.add_argument('--k', type=int, default=None)
    parser.add_argument('--n', type=int, default=None)
    parser.add_argument('--seed', type=int, default=None)
    parser.add_argument('--sampled', type=_str2bool, nargs='?', const=True, default=False)
    parser.add_argument('--averaged', type=_str2bool, nargs='?', const=True, default=False)
    # `args` stays module-level: older versions of obtain_policy_and_save read it.
    args = parser.parse_args()
    if args.sampled and args.averaged:
        parser.error('--sampled and --averaged are mutually exclusive')

    config = Config()
    # Command-line k/n override the configured values.
    if args.k is not None:
        config.k = [args.k]
    if args.n is not None:
        config.n = args.n

    state_to_tensor = None
    if args.head:
        # NOTE(review): `state_to_tensor` is never used below; head models are deprecated.
        models = load_all_saved_models(amount_to_load=config.head_model_seed_amount)
        state_to_tensor = head_state_to_tensor(models)

    if args.use_true_reward:
        # The true reward needs no ensemble, so run a single 1-of-1 pass.
        config.k, config.n = [1], 1

    for k in config.k:
        obtain_policy_and_save(
            k,
            config.n,
            seed=args.seed,
            test=args.test,
            use_head_models=args.head,
            use_true_reward=args.use_true_reward,
            sampled=args.sampled,
            averaged=args.averaged,
        )
        # Sampled/averaged modes do not depend on k, so one run is enough.
        if args.sampled or args.averaged:
            break
