from __future__ import print_function
from __future__ import division

import torch
import numpy as np
import pdb
import argparse
import random
import yaml
import os

import estimators
import utils
import utils_algo
from toymdp import ModifiedRoy, Bairds
from utils_gw import collect_data_discrete, RoyMDPPolicy, BairdsPolicy, generate_bairds_features, generate_roy_features, OPEEvaluator,\
    estimated_Q_values, realizability_measure, bc_measure, td_soln_measure, induced_trans_measure, ortho_vs_qval

def env_setup(FLAGS):
    """Build the toy MDP selected by ``FLAGS.env_name``.

    Args:
        FLAGS: object exposing an ``env_name`` attribute.

    Returns:
        A ``ModifiedRoy`` instance when ``env_name`` is ``'Roy'``, or a
        ``Bairds`` instance when it is ``'Bairds'``.

    Raises:
        ValueError: if ``FLAGS.env_name`` is neither ``'Roy'`` nor ``'Bairds'``.
    """
    env_name = FLAGS.env_name
    if env_name == 'Roy':
        return ModifiedRoy()
    if env_name == 'Bairds':
        return Bairds()
    # Fail loudly with the offending value instead of falling through to an
    # unbound local variable.
    raise ValueError(
        "unsupported env_name {!r}; expected 'Roy' or 'Bairds'".format(env_name))

def policies_setup(FLAGS, env):
    """Create the evaluation and behavior policies for ``env``.

    Args:
        FLAGS: object exposing an ``env_name`` attribute.
        env: environment instance passed to the policy constructor.

    Returns:
        Tuple ``(pie, pib)``: two separately constructed ``RoyMDPPolicy``
        objects (evaluation policy first, behavior policy second).

    Raises:
        ValueError: if ``FLAGS.env_name`` is neither ``'Roy'`` nor ``'Bairds'``.
    """
    env_name = FLAGS.env_name
    if env_name not in ('Roy', 'Bairds'):
        # Fail loudly with the offending value instead of falling through to
        # unbound local variables.
        raise ValueError(
            "unsupported env_name {!r}; expected 'Roy' or 'Bairds'".format(env_name))

    # NOTE(review): both environments use RoyMDPPolicy; the file imports
    # BairdsPolicy but it is not used here -- confirm this is intended.
    pie = RoyMDPPolicy(env)
    pib = RoyMDPPolicy(env)
    return pie, pib

def main(FLAGS):  # noqa
    """Run one tabular OPE experiment on a toy MDP and save a summary.

    Steps, in order:
      1. Seed the RNGs and build the MDP and the evaluation/behavior policies.
      2. Pick hard-coded Q-values for the chosen environment.
      3. Build a transition dataset (pairwise variant when
         ``FLAGS.pw_dataset`` is truthy, otherwise an on-/off-policy mix) and
         turn it into feature data.
      4. Build the OPE dataset objects and an ``OPEEvaluator``.
      5. Obtain the encoder ``phi``: ``None`` for identity encoders, otherwise
         loaded from a cached file if one exists, else trained.
      6. Run the OPE method, evaluate the final weights, and collect stats.
      7. Save the summary dict with ``np.save`` to ``FLAGS.outfile``.

    Args:
        FLAGS: configuration object; the attributes read here include
            ``seed``, ``gamma``, ``env_name``, ``ope_method``,
            ``encoder_name``, ``pw_dataset``, ``mix_ratio``,
            ``roy_sample_num``, ``roy_off_type``, ``phi_epochs``,
            ``phi_outdim``, ``epochs``, ``dataset_name``, ``outfile`` and the
            hyper-parameters copied into the summary.

    Returns:
        None. The result is written to ``FLAGS.outfile``.
    """
    seed = FLAGS.seed
    utils.set_seed_everywhere(seed)

    mdp, gamma = env_setup(FLAGS), FLAGS.gamma
    print ('gamma for evaluation: {}'.format(gamma))

    ope_method = FLAGS.ope_method
    enc_name = FLAGS.encoder_name

    pie, pib = policies_setup(FLAGS, mdp)

    # Hard-coded Q-values per environment; they are passed to feature
    # generation, encoder training and the OPE evaluator below.
    if FLAGS.env_name == 'Roy':
        q_values = np.array([[0.8], [1.0], [0], [0], [0]])
    elif FLAGS.env_name == 'Bairds':
        # NOTE(review): the array is built from Python ints, so it has an
        # integer dtype (unlike the float array in the 'Roy' branch).
        q_values = np.array([0 for _ in range(mdp.num_states * mdp.num_actions)]) 
        # First num_states // 2 entries are set to 1.
        q_values[:mdp.num_states // 2] = 1.

    if FLAGS.pw_dataset:
        # Pairwise dataset: derived from the counts of on-policy data.
        num_trans = 2000
        # presumably 5000 and 10 control the amount/length of collected
        # paths -- TODO confirm against collect_data_discrete.
        on_policy_paths, _ = collect_data_discrete(mdp, pib, 5000, 10)
        on_policy_data = mdp.parse_dataset(on_policy_paths, num_trans = num_trans)
        # Only p1 is used below; a1 is unused in this branch.
        a1, p1 = mdp.count(on_policy_data)
        dataset, other_dataset = mdp.generate_pairwise_dataset(p1, num_trans = num_trans)

        # Contamination step: add "bad" transitions to both datasets.
        dataset, other_dataset = mdp.add_bad_transitions(p1, dataset, other_dataset,\
            sample_num = FLAGS.roy_sample_num,
            mix_ratio = FLAGS.mix_ratio)

        # contamination process

        # NOTE(review): generate_roy_features is used regardless of
        # FLAGS.env_name -- confirm this is intended for 'Bairds'.
        features_data = generate_roy_features(mdp, dataset, pie, q_values)
        other_features_data = generate_roy_features(mdp, other_dataset, pie, q_values)

        features_data = mdp.merge_feature_datasets(features_data, other_features_data)
    else:
        # original code
        # Mix off-policy data from the MDP with parsed on-policy rollouts.
        num_trans = 2000
        off_policy_data = mdp.get_dataset(num_trans, off_type = FLAGS.roy_off_type)
        on_policy_paths, _ = collect_data_discrete(mdp, pib, 3000, 10)
        on_policy_data = mdp.parse_dataset(on_policy_paths, num_trans = num_trans)
        dataset = mdp.merge_data(on_policy_data, off_policy_data, mix_ratio = FLAGS.mix_ratio)

        # The returned counts (a1, p1, a2, p2) are only consumed by the
        # commented-out ratio calls below; the calls themselves are kept.
        a1, p1 = mdp.count(on_policy_data)
        mdp.count(off_policy_data)
        a2, p2 = mdp.count(dataset)
        # mdp.ratios(a1, a2)
        # mdp.ratios(p1, p2)

        print (off_policy_data['states'].shape, on_policy_data['states'].shape)
        # NOTE(review): generate_roy_features is used regardless of
        # FLAGS.env_name -- confirm this is intended for 'Bairds'.
        features_data = generate_roy_features(mdp, dataset, pie, q_values)

    # Full dataset object plus train/test objects used for encoder training.
    ope_data, tr_ope_data, test_ope_data = utils_algo.generate_dataset_objs(FLAGS, features_data, pie, tabular = True)

    # Ground-truth return estimates are disabled (calls commented out), so
    # the values printed and stored below are the placeholder -1.
    pie_est = -1#on_policy_estimate(500, 100, gamma, mdp, pie)
    pib_est = -1#on_policy_estimate(500, 100, gamma, mdp, pib)
    rand_est = -1#on_policy_estimate(100, 100, gamma, mdp, random_pi = True)

    oracle_est_ret = pie_est
    print ('pie true: {} {}'.format(oracle_est_ret , oracle_est_ret * (1. - gamma)))
    print ('pib true: {} {}'.format(pib_est , pib_est * (1. - gamma)))
    rand_est_ret = rand_est
    print ('rand true: {} {}'.format(rand_est_ret , rand_est_ret * (1. - gamma)))
    data_est_ret = -1#np.mean(ope_data.rewards) / (1. - gamma)
    print ('dataset true (ret, rew): {} {}'.format(data_est_ret , data_est_ret * (1. - gamma)))
    # Stores the placeholder oracle value (-1) on the dataset object.
    ope_data.store_pie_info(oracle_est_ret, None, None, None)

    ope_evaluator = OPEEvaluator(mdp, pie, ope_data.curr_state_actions,\
        ope_data.initial_states, q_values, None,\
        ope_data.min_reward, ope_data.max_reward, FLAGS.gamma,\
        raw_state = ope_data.curr_raw_state)

    # train the encoder
    # phi stays None for any encoder whose name contains 'identity';
    # phi_metrics stays empty unless the encoder is trained in this run.
    phi = None
    phi_metrics = {}
    if 'identity' in enc_name:
        phi = None
    else:
        # Cache file name in the current working directory, keyed by
        # environment, encoder, seed, epochs and output dimension.
        path = '{}_{}_{}_{}_{}'.format(FLAGS.env_name, FLAGS.encoder_name,\
            FLAGS.seed, FLAGS.phi_epochs, FLAGS.phi_outdim)
        if os.path.exists(path):
            # NOTE(review): torch.load unpickles the file by default; only
            # load cache files from a trusted source.
            phi = torch.load(path)
            print ('loaded pre-trained model')
        else:
            phi, phi_metrics = utils_algo.train_encoder(FLAGS, tr_ope_data, test_ope_data,\
                mdp, gamma, enc_name, pie, tabular = True, q_pie_values = q_values,
                ope_evaluator = ope_evaluator)
            # Saving is disabled, so this function never writes the cache
            # file that the branch above looks for.
            #torch.save(phi, path)

    #print ('OPE with identity features')
    # Only referenced by the commented-out calls below.
    temp_metrics = {'weights': None}
    #_, temp_metrics = utils_algo.run_experiment_ope(FLAGS, ope_method, ope_data, gamma, mdp, pie, phi = None, enc_name = 'identity', tabular = True)

    #utils.plot_covariance_heatmap(ope_data, pie, gamma, phi, enc_name = FLAGS.encoder_name, tabular = True)

    #utils.plot_feature_matrix(ope_data, pie = pie, phi = phi, fname = enc_name, ground_weights = temp_metrics['weights'])

    #utils.high_d_plot(ope_data, pie = pie, phi = phi, fname = enc_name, typ = FLAGS.visual_type, ground_weights = temp_metrics['weights'])

    # The diagnostic measures are disabled (calls commented out); each
    # placeholder is an empty dict, so the matching update() calls below
    # add nothing.
    realize_stats = {}#realizability_measure(mdp, q_values, phi)
    bc_stats = {}#bc_measure(ope_data, q_values, phi, mdp.observation_space.shape[0] * mdp.action_space.shape[0] if FLAGS.encoder_name == 'identity' else FLAGS.phi_outdim, pie, FLAGS.gamma)
    induced_trans_error = {}#induced_trans_measure(mdp, pie, phi, ope_data.curr_state_actions, FLAGS.phi_outdim)
    ortho_qval_stat = {}#ortho_vs_qval(mdp, q_values, phi, successor_sa)

    print (f'OPE with {enc_name} features')
    r_est, ope_metrics = utils_algo.run_experiment_ope(FLAGS, ope_method, ope_data, gamma, mdp, pie, phi = phi, enc_name = enc_name, tabular = True)

    td_solution_stats = {}#td_soln_measure(mdp, phi, realize_stats['realizable_w'], ope_metrics['weights'])
    
    # Evaluate the encoder together with the weights returned by the OPE run.
    final_ope_err = ope_evaluator.evaluate(phi, ope_metrics['weights'])
    
    #estimated_Q_values(FLAGS.env_name, mdp, q_values, phi, temp_metrics['weights'], ope_metrics['weights'])

    # No copy is made: final_phi_stats is the same object as
    # ope_metrics['phi_lspe_stats'], so the updates below modify it in place.
    final_phi_stats = ope_metrics['phi_lspe_stats']

    # 1. ideal solution assuming true labels (q values) were known
    # 2. norm of the solution in the ideal case
    final_phi_stats.update(realize_stats)

    # 1. norm distance between projected fixed point and true fixed point
    # same as above, but scaled by Q values
    final_phi_stats.update(bc_stats)

    # 1. norm between TD's final solution and solution by linear regression
    # 2. norm between TD's final weight vector and weight vector by linear regression
    final_phi_stats.update(td_solution_stats)

    # 1. distance between eigenvalues of induced transition matrix and original trans matrix
    # (currently stores the empty placeholder dict)
    final_phi_stats['induced_trans_error'] = induced_trans_error

    # 1. orthogonality between each unique (s,a) and every other
    # 2. correlation coefficient between orthogonality between two state-actions and their q value diff
    final_phi_stats.update(ortho_qval_stat)

    # main stats
    # phi_stats: final stats on full dataset
    # phi_metrics: stats on mini batch during training
    # init_phi_stats: stats on all init states
    # ope_metrics: ope related stats (nothing about phi, for now)

    # Results are nested as summary['results'][dataset_name][algo_name].
    algo_name = '{}_{}'.format(FLAGS.ope_method, FLAGS.encoder_name)
    summary = {
        'env': FLAGS.env_name,
        'results': {
            FLAGS.dataset_name: {} # no name for the dataset
        },
        'encoder_name': FLAGS.encoder_name,
        'seed': seed,
        'hp': {
            'Q_lr': FLAGS.Q_lr,
            'phi_lr': FLAGS.phi_lr,
            'phi_outdim': FLAGS.phi_outdim,
            'beta': FLAGS.beta,
            'M_lr': FLAGS.M_lr,
            'phi_hard_update_freq': FLAGS.phi_hard_update_freq,
            'mix_ratio': FLAGS.mix_ratio,
            'aux_alpha': FLAGS.aux_alpha
        },
        # 'oracle_est': oracle_est_ret,
        # 'rand_est': rand_est_ret,
        'gamma': FLAGS.gamma,
        'normalize_states': FLAGS.normalize_states,
        'normalize_rewards': FLAGS.normalize_rewards
    }

    # The stats object is stored by reference; the keys added on the next
    # lines therefore also appear in final_phi_stats.
    summary['results'][FLAGS.dataset_name][algo_name] = final_phi_stats
    summary['results'][FLAGS.dataset_name][algo_name]['ope_error'] = final_ope_err
    # assumes 'bellman_residual' has an entry at index/key FLAGS.epochs -- TODO confirm
    summary['results'][FLAGS.dataset_name][algo_name]['bellman_error'] = ope_metrics['bellman_residual'][FLAGS.epochs]
    # Falls back to 0 when no training losses were recorded (identity
    # encoder or encoder loaded from the cache file).
    summary['results'][FLAGS.dataset_name][algo_name]['tr_losses'] = phi_metrics['tr_losses'][FLAGS.phi_epochs] if 'tr_losses' in phi_metrics else 0
    summary['results'][FLAGS.dataset_name][algo_name]['phi_tr_metrics'] = phi_metrics
    #print (summary)
    # NOTE(review): np.save on a dict stores a pickled object array, so
    # reading it back needs np.load(..., allow_pickle=True); np.save also
    # appends '.npy' when the file name lacks that extension.
    np.save(FLAGS.outfile, summary) 