from __future__ import print_function
from __future__ import division

import torch
import numpy as np
import pdb
import argparse
import random
import yaml
import os

import estimators
import utils
import utils_algo
from gridworld import Gridworld
from random_mdps import RandomMDP
# from taxi import Taxi
from chainmdp import ChainMDP
#TaxiPolicy
from utils_gw import collect_data_discrete, GridworldPolicy, compute_Q_values, OPEEvaluator,\
    ToyMDPPolicy, RandomMDPPolicy, random_mdp_Q_vals, generate_one_hots,\
    estimated_Q_values, realizability_measure, bc_measure,\
    td_soln_measure, induced_trans_measure, ortho_vs_qval
    #, random_mdp_krope

def env_setup(FLAGS):
    """Build the evaluation environment selected by ``FLAGS.env_name``.

    Recognised names: ``'Gridworld'``, ``'ToyMDP'`` and ``'RandomMDP'``
    (matched exactly) and any name containing ``'ChainMDP'``. For RandomMDP
    and ChainMDP the environment's ``gamma`` attribute is set from
    ``FLAGS.gamma``.

    Args:
        FLAGS: experiment configuration; reads ``env_name`` and ``gamma``.

    Returns:
        The constructed environment object.

    Raises:
        ValueError: if ``FLAGS.env_name`` matches none of the supported
            environments (previously this surfaced as an UnboundLocalError).
    """
    env_name = FLAGS.env_name
    if env_name == 'Gridworld':
        env = Gridworld(length = 3)
    elif env_name == 'ToyMDP':
        # NOTE(review): ToyMDP is not imported in this file, so this branch
        # raises NameError as written -- import it from its module to use it.
        num_states = 7
        num_actions = 4
        env = ToyMDP(num_states, num_actions, init_state = 0)

        # From the start state 0, action a (0..3) moves deterministically to
        # state a + 1.
        for action in range(num_actions):
            env.set_transition_prob(0, action, action + 1, 1.0)

        # States 1-4 all move to state 5 under action 0.
        for state in range(1, 5):
            env.set_transition_prob(state, 0, 5, 1.0)

        # From state 5, return to the start with prob 1 - term_prob, otherwise
        # move to state 6 (presumably terminal -- confirm in ToyMDP).
        term_prob = 0.1
        env.set_transition_prob(5, 0, 0, 1 - term_prob)
        env.set_transition_prob(5, 0, 6, term_prob)

        # Rewards are set only for the actions taken in the start state.
        for action, reward in enumerate((1, 5, 5, -15)):
            env.set_reward(0, action, reward)
    elif env_name == 'RandomMDP':
        num_states = 8
        num_actions = 5
        rew_variance = 4.
        env = RandomMDP(num_states, num_actions, reward_variance = rew_variance, use_terminal_state = False)
        env.gamma = FLAGS.gamma
    # Taxi support is disabled (its import is commented out at the top of the file).
    elif 'ChainMDP' in env_name:
        length = 5
        env = ChainMDP(length, stoch_prob = 0.1)
        env.gamma = FLAGS.gamma
    else:
        raise ValueError('Unknown env_name: {}'.format(env_name))
    return env

def policies_setup(FLAGS, env):
    """Create the (evaluation, behavior) policy pair for ``env``.

    ``'Gridworld'`` is matched exactly; ``'ToyMDP'``, ``'RandomMDP'`` and
    ``'ChainMDP'`` are matched as substrings of ``FLAGS.env_name``.

    Args:
        FLAGS: experiment configuration; reads ``env_name`` and, for ChainMDP,
            ``mix_ratio``.
        env: environment returned by ``env_setup``.

    Returns:
        Tuple ``(pie, pib)``: the evaluation policy and the behavior policy.

    Raises:
        ValueError: if ``FLAGS.env_name`` matches no supported environment
            (previously this surfaced as an UnboundLocalError).
    """
    env_name = FLAGS.env_name
    if env_name == 'Gridworld':
        # Other pairs tried before: pie in {'CW', 'det-L', file-based policy},
        # pib in {'det-L', 'CW', 'random', file-based policy}.
        pie = GridworldPolicy(env, typ = 'random')
        pib = GridworldPolicy(env, typ = 'CCW')
    elif 'ToyMDP' in env_name:
        pie = ToyMDPPolicy(env, pi_num = 0)
        pib = ToyMDPPolicy(env, pi_num = 1)
    elif 'RandomMDP' in env_name:
        pie = RandomMDPPolicy(env, policy_type = 'stochastic')
        pib = RandomMDPPolicy(env, policy_type = 'uniform')
    # Taxi support is disabled (TaxiPolicy's import is commented out at the top of the file).
    elif 'ChainMDP' in env_name:
        pie = RandomMDPPolicy(env, policy_type = 'left')
        # The behavior policy is a 'mix' policy built from pie, weighted by FLAGS.mix_ratio.
        pib = RandomMDPPolicy(env, policy_type = 'mix', mix_ratio = FLAGS.mix_ratio, other_pi = pie)
    else:
        raise ValueError('Unknown env_name: {}'.format(env_name))
    return pie, pib

def on_policy_estimate(batch_size, truncated_horizon, gamma, mdp, pi = None, random_pi = False):
    """Estimate the value of ``pi`` from freshly collected on-policy rollouts.

    Args:
        batch_size: number of rollouts requested from ``collect_data_discrete``.
        truncated_horizon: horizon passed to ``collect_data_discrete``.
        gamma: discount factor for ``estimators.OnPolicy``.
        mdp: environment to roll out in.
        pi: policy to evaluate (may be None when ``random_pi`` is True).
        random_pi: forwarded to ``collect_data_discrete``.

    Returns:
        The result of ``estimators.OnPolicy(pi, gamma).estimate`` on the rollouts.
    """
    rollouts, _unused = collect_data_discrete(
        mdp, pi, batch_size, truncated_horizon, random_pi = random_pi)
    estimator = estimators.OnPolicy(pi, gamma)
    return estimator.estimate(rollouts)

def get_phi_stats(data, phi = None, tabular = False, gamma = None):
    """Compute feature statistics over consecutive state-action pairs.

    Args:
        data: dataset object exposing ``curr_state_actions`` and
            ``next_state_actions``.
        phi: optional encoder applied to both arrays before computing stats;
            ``None`` uses the raw state-action representations.
        tabular: must be True; only the tabular data path is implemented.
        gamma: discount factor forwarded to ``utils.get_phi_stats``.

    Returns:
        The result of ``utils.get_phi_stats`` on the (optionally encoded) pairs.

    Raises:
        NotImplementedError: if ``tabular`` is False. Previously this path
            crashed with an UnboundLocalError because no inputs were selected.
    """
    if not tabular:
        raise NotImplementedError(
            'get_phi_stats only supports tabular data (tabular=True)')

    curr_sa = data.curr_state_actions
    next_sa = data.next_state_actions
    if phi is not None:
        curr_sa = phi(curr_sa)
        next_sa = phi(next_sa)
    return utils.get_phi_stats(curr_sa, next_sa, gamma)

def main(FLAGS):  # noqa
    """Run one tabular OPE experiment and save a summary to ``FLAGS.outfile``.

    Steps:
      1. seed everything, build the environment and the (pie, pib) policies;
      2. compute Q/V values of pie and pib with ``compute_Q_values``;
      3. collect behavior-policy data and build the OPE dataset objects;
      4. load or train the state-action encoder ``phi`` (``None`` when the
         encoder name contains 'identity');
      5. compute feature diagnostics, run OPE with ``phi`` and score it with
         ``OPEEvaluator``;
      6. assemble a nested summary dict and write it with ``np.save``.

    Args:
        FLAGS: parsed experiment configuration (argparse-style namespace).
            ``FLAGS.normalize_states`` is overwritten for 'target-phi-sa'.
    """
    seed = FLAGS.seed
    utils.set_seed_everywhere(seed)

    mdp, gamma = env_setup(FLAGS), FLAGS.gamma
    print ('gamma for evaluation: {}'.format(gamma))

    ope_method = FLAGS.ope_method
    enc_name = FLAGS.encoder_name
    # Any encoder name containing 'identity' means "no learned encoder":
    # phi stays None and the raw state-action features are used. This single
    # flag drives both the encoder choice and the feature dimension below.
    use_identity = 'identity' in enc_name

    if enc_name == 'target-phi-sa':
        FLAGS.normalize_states = False

    pie, pib = policies_setup(FLAGS, mdp)

    q_values, v_values, successor_sa = compute_Q_values(FLAGS.env_name, mdp, pie, gamma, epochs = 500)
    print (q_values, v_values)

    pib_q_values, _, _ = compute_Q_values(FLAGS.env_name, mdp, pib, gamma, epochs = 500)

    # Off-policy data for OPE, collected with the behavior policy pib
    # (truncated horizon 200).
    paths, _ = collect_data_discrete(mdp, pib, FLAGS.batch_size, 200)
    one_hot_data = generate_one_hots(mdp, paths, pie, q_values)

    ope_data, tr_ope_data, test_ope_data = utils_algo.generate_dataset_objs(FLAGS, one_hot_data, pie, tabular = True)

    # Ground-truth on-policy estimates are disabled to save time; -1 is a
    # placeholder. To re-enable them, use e.g.
    #   on_policy_estimate(500, 100, gamma, mdp, pie)                  (pie)
    #   on_policy_estimate(500, 100, gamma, mdp, pib)                  (pib)
    #   on_policy_estimate(100, 100, gamma, mdp, random_pi = True)     (random)
    # and np.mean(ope_data.rewards) / (1. - gamma) for the dataset value.
    oracle_est_ret = -1
    pib_est = -1
    rand_est_ret = -1
    data_est_ret = -1

    print ('pie true: {} {}'.format(oracle_est_ret , oracle_est_ret * (1. - gamma)))
    print ('pib true: {} {}'.format(pib_est , pib_est * (1. - gamma)))
    print ('rand true: {} {}'.format(rand_est_ret , rand_est_ret * (1. - gamma)))
    print ('dataset true (ret, rew): {} {}'.format(data_est_ret , data_est_ret * (1. - gamma)))
    ope_data.store_pie_info(oracle_est_ret, None, None, None)

    ope_evaluator = OPEEvaluator(mdp, pie, ope_data.curr_state_actions,\
        ope_data.initial_states, q_values, pib_q_values,\
        ope_data.min_reward, ope_data.max_reward, FLAGS.gamma, successor_sa = successor_sa, rewards = ope_data.rewards,\
        init_sa = ope_data.init_state_actions_curr, sa_visitation = ope_data.sa_visitation)

    # Train (or load) the encoder.
    phi = None
    phi_metrics = {}
    if not use_identity:
        path = '{}_{}_{}_{}_{}'.format(FLAGS.env_name, FLAGS.encoder_name,\
            FLAGS.seed, FLAGS.phi_epochs, FLAGS.phi_outdim)
        if os.path.exists(path):
            # NOTE(review): torch.load unpickles arbitrary objects; only load
            # checkpoints produced by this project.
            phi = torch.load(path)
            print ('loaded pre-trained model')
        else:
            phi, phi_metrics = utils_algo.train_encoder(FLAGS, tr_ope_data, test_ope_data,\
                mdp, gamma, enc_name, pie, tabular = True,\
                q_pie_values = q_values, ope_evaluator = ope_evaluator)
            # Checkpoint caching (torch.save(phi, path)) is disabled, so the
            # load branch above only fires for files created elsewhere.

    # Weights of an identity-feature OPE baseline; that baseline run is
    # disabled, so no weights are available.
    identity_weights = None

    realize_stats = realizability_measure(mdp, q_values, phi)

    # Feature dimension: observation dim * action dim for raw features,
    # otherwise the encoder's output size.
    if use_identity:
        feat_dim = mdp.observation_space.shape[0] * mdp.action_space.shape[0]
    else:
        feat_dim = FLAGS.phi_outdim
    bc_stats = utils.bc_measure(phi, feat_dim, tr_ope_data, gamma)

    induced_trans_error = induced_trans_measure(mdp, pie, phi, ope_data.curr_state_actions, FLAGS.phi_outdim)
    ortho_qval_stat = ortho_vs_qval(mdp, q_values, phi, successor_sa)

    print (f'OPE with {enc_name} features')
    _, ope_metrics = utils_algo.run_experiment_ope(FLAGS, ope_method, ope_data, gamma, mdp, pie, phi = phi, enc_name = enc_name, tabular = True)

    td_solution_stats = td_soln_measure(mdp, phi, realize_stats['realizable_w'], ope_metrics['weights'])

    final_ope_err = ope_evaluator.evaluate(phi, ope_metrics['weights'])

    estimated_Q_values(FLAGS.env_name, mdp, q_values, phi, identity_weights, ope_metrics['weights'])

    # Note: this dict is taken from ope_metrics and extended in place.
    final_phi_stats = ope_metrics['phi_lspe_stats']

    # 1. ideal solution assuming the true labels (Q values) were known
    # 2. norm of the solution in the ideal case
    final_phi_stats.update(realize_stats)

    # 1. norm distance between the projected fixed point and the true fixed point
    # 2. same as above, but scaled by Q values
    final_phi_stats.update(bc_stats)

    # 1. norm between TD's final solution and the linear-regression solution
    # 2. norm between TD's final weight vector and the linear-regression weights
    final_phi_stats.update(td_solution_stats)

    # distance between eigenvalues of the induced transition matrix and the
    # original transition matrix
    final_phi_stats['induced_trans_error'] = induced_trans_error

    # 1. orthogonality between each unique (s,a) and every other
    # 2. correlation between the orthogonality of two state-actions and the
    #    difference of their Q values
    final_phi_stats.update(ortho_qval_stat)

    # Summary contents:
    #   final_phi_stats: final stats on the full dataset
    #   phi_metrics: stats on mini-batches during encoder training
    #   ope_metrics: OPE-related stats
    algo_name = '{}_{}'.format(FLAGS.ope_method, FLAGS.encoder_name)
    summary = {
        'env': FLAGS.env_name,
        'results': {
            FLAGS.dataset_name: {} # no name for the dataset
        },
        'encoder_name': FLAGS.encoder_name,
        'seed': seed,
        'hp': {
            'Q_lr': FLAGS.Q_lr,
            'phi_lr': FLAGS.phi_lr,
            'phi_outdim': FLAGS.phi_outdim,
            'beta': FLAGS.beta,
            'M_lr': FLAGS.M_lr,
            'phi_hard_update_freq': FLAGS.phi_hard_update_freq,
            'mix_ratio': FLAGS.mix_ratio,
            'batch_size': FLAGS.batch_size
        },
        'gamma': FLAGS.gamma,
        'normalize_states': FLAGS.normalize_states,
        'normalize_rewards': FLAGS.normalize_rewards
    }

    algo_results = final_phi_stats
    summary['results'][FLAGS.dataset_name][algo_name] = algo_results
    algo_results['ope_error'] = final_ope_err
    algo_results['bellman_error'] = ope_metrics['bellman_residual'][FLAGS.epochs]
    algo_results['tr_losses'] = phi_metrics['tr_losses'][FLAGS.phi_epochs] if 'tr_losses' in phi_metrics else 0
    algo_results['phi_tr_metrics'] = phi_metrics
    print (summary)
    np.save(FLAGS.outfile, summary)
