from __future__ import print_function
from __future__ import division

import torch
import numpy as np
import pdb
import random

import gymnasium as g
import estimators
import utils
import utils_algo

def env_setup(FLAGS):
    """Load the environment described by the experiment flags.

    Thin wrapper over ``utils.load_env``; forwards ``FLAGS.env_name``,
    ``FLAGS.gamma`` and ``FLAGS.env_type`` and returns its result unchanged.
    """
    return utils.load_env(FLAGS.env_name, FLAGS.gamma, env_type = FLAGS.env_type)

def policies_setup(FLAGS, config, env, pi_num):
    """Load policy number ``pi_num`` for ``env`` (non-D4RL environments).

    Thin wrapper over ``utils.load_policies``; forwards ``FLAGS.env_name`` and
    ``FLAGS.env_type`` alongside the given config/env and returns its result.
    """
    return utils.load_policies(config, env, FLAGS.env_name, pi_num, env_type = FLAGS.env_type)

def get_dataset_stats(dataset, pie, gamma, phi = None):
    """Compute feature statistics of ``dataset`` under evaluation policy ``pie``.

    Current state-action pairs come straight from the dataset. Next-state
    actions are sampled from ``pie`` on the *unnormalized* next states, then
    paired with the dataset's (stored) next states. When ``phi`` is given, both
    matrices are mapped through it before computing statistics.

    Args:
        dataset: object exposing ``curr_states``, ``curr_actions``,
            ``next_states`` and ``unnormalize_states``.
        pie: policy exposing ``batch_sample``.
        gamma: discount factor forwarded to ``utils.get_phi_stats``.
        phi: optional callable encoder applied to concatenated (s, a) rows.

    Returns:
        dict: ``utils.get_phi_stats`` output merged with
        ``utils.get_phi_capacity_stats`` output (the latter wins on key clashes).
    """
    # (s, a) rows are built by concatenating along the feature axis.
    state_action = np.concatenate((dataset.curr_states, dataset.curr_actions), axis = 1)

    succ_states = dataset.next_states
    succ_actions = pie.batch_sample(dataset.unnormalize_states(succ_states))
    succ_state_action = np.concatenate((succ_states, succ_actions), axis = 1)

    if phi is None:
        feature_stats = utils.get_phi_stats(state_action, succ_state_action, gamma)
        capacity_stats = utils.get_phi_capacity_stats(state_action, pie)
    else:
        encoded = phi(state_action)
        encoded_succ = phi(succ_state_action)
        feature_stats = utils.get_phi_stats(encoded, encoded_succ, gamma)
        capacity_stats = utils.get_phi_capacity_stats(state_action, pie, encoded)

    return feature_stats | capacity_stats

def main(FLAGS, config):  # noqa
    """Run a single off-policy evaluation (OPE) experiment and save its summary.

    Pipeline: seed everything, load the environment, the offline dataset and
    the evaluation policy ``pie``, optionally train a state-action encoder
    ``phi``, run ``FLAGS.ope_method`` on top of it, and write a nested summary
    dict to ``FLAGS.outfile`` with ``np.save``.

    Args:
        FLAGS: experiment flags (attribute access). ``FLAGS.normalize_states``
            is overwritten to ``False`` for the ``'target-phi-sa'`` encoder.
        config: configuration forwarded to the policy loaders.
    """
    seed = FLAGS.seed
    utils.set_seed_everywhere(seed)

    mdp, gamma = env_setup(FLAGS), FLAGS.gamma
    print ('gamma for evaluation: {}'.format(gamma))

    ope_method = FLAGS.ope_method
    enc_name = FLAGS.encoder_name

    if enc_name == 'target-phi-sa':
        FLAGS.normalize_states = False

    # Offline dataset file: datasets/<env>_name_<d4rl|custom>-<dataset>.npy
    directory = 'datasets'
    dataset_name_part = '{}-{}'.format('d4rl' if FLAGS.env_type == 'd4rl' else 'custom', FLAGS.dataset_name)
    dataset_name = '{}_name_{}'.format(FLAGS.env_name, dataset_name_part)
    # NOTE: allow_pickle unpickles the file contents; only load trusted files.
    dataset_info = np.load('{}/{}.npy'.format(directory, dataset_name), allow_pickle = True).item()
    if FLAGS.env_type == 'd4rl':
        pie = utils.load_d4rl_policy(config, mdp, FLAGS.env_name, FLAGS.pie_num)
    else:
        pie = policies_setup(FLAGS, config, mdp, pi_num = FLAGS.pie_num)

    ope_data, tr_ope_data, test_ope_data = utils_algo.generate_dataset_objs(FLAGS, dataset_info, pie)

    # Precomputed reference returns for pie (and presumably a random policy).
    pie_directory = 'pie_configs'
    pie_config_file = '{}_{}_{}'.format(FLAGS.env_name, FLAGS.gamma, FLAGS.pie_num)
    pie_info = np.load('{}/{}.npy'.format(pie_directory, pie_config_file), allow_pickle = True).item()
    oracle_est_ret = pie_info['pie_avg_est_ret']
    print ('pie true: {}'.format(oracle_est_ret))
    rand_est_ret = pie_info['rand_avg_est_ret']
    print ('rand true: {}'.format(rand_est_ret))

    # NOTE(review): these are currently only consumed by the disabled
    # store_pie_info call below; kept so missing keys still fail loudly.
    pie_path_sa_vals = pie_info['pie_path_sa_vals']
    pie_path_states = pie_info['pie_path_states']
    pie_path_acts = pie_info['pie_path_acts']
    #ope_data.store_pie_info(oracle_est_ret, pie_path_sa_vals, pie_path_states, pie_path_acts)

    ope_evaluator = utils.DeepOPEEvaluator(mdp, pie,
        ope_data, oracle_est_ret, rand_est_ret,
        ope_data.min_reward, ope_data.max_reward, FLAGS.gamma)

    # Train the encoder. With aux_task or an identity encoder no encoder is
    # trained: phi stays None, phi_metrics stays empty and there is no critic.
    phi = None
    phi_metrics = {}
    critic = None
    if not FLAGS.aux_task and 'identity' not in enc_name:
        phi, phi_metrics, critic = utils_algo.train_encoder(FLAGS, tr_ope_data, test_ope_data,
            mdp, gamma, enc_name, pie, ope_evaluator = ope_evaluator)

    # Without an encoder the features are the raw concatenated (state, action)
    # vector, so the feature dimension is obs_dim + act_dim.
    if phi is None:
        feat_dim = mdp.observation_space.shape[0] + mdp.action_space.shape[0]
    else:
        feat_dim = FLAGS.phi_outdim
    bc_stats = utils.bc_measure(phi, feat_dim, tr_ope_data, gamma, lam = FLAGS.bc_measure_logdet)

    _, ope_metrics = utils_algo.run_experiment_ope(FLAGS, ope_method, ope_data, gamma, mdp, pie, phi = phi, enc_name = enc_name)

    # phi_metrics is empty when no encoder was trained; record None rather
    # than crashing with KeyError.
    final_ope_err = phi_metrics['phi_ope_error'][FLAGS.phi_epochs] if 'phi_ope_error' in phi_metrics else None
    final_ope_est = phi_metrics['phi_ope_ret'][FLAGS.phi_epochs] if 'phi_ope_ret' in phi_metrics else None

    final_phi_stats = ope_metrics['phi_lspe_stats']
    final_phi_stats.update(bc_stats)

    # Q-values are only available when an FQE-style encoder produced a critic.
    q_csa, q_target = None, None
    if 'fqe' in FLAGS.encoder_name and critic is not None:
        q_csa, q_target = utils.get_q_values(tr_ope_data, pie, critic, gamma = FLAGS.gamma)

    algo_name = '{}_{}'.format(FLAGS.ope_method, FLAGS.encoder_name)
    summary = {
        'env': FLAGS.env_name,
        'env_type': FLAGS.env_type,
        'results': {
            FLAGS.dataset_name: {} # no name for the dataset
        },
        'encoder_name': FLAGS.encoder_name,
        'seed': seed,
        'hp': {
            'Q_lr': FLAGS.Q_lr,
            'phi_lr': FLAGS.phi_lr,
            'phi_hidden_dim': FLAGS.phi_hidden_dim,
            'phi_num_hidden_layers': FLAGS.phi_num_hidden_layers,
            'phi_outdim': FLAGS.phi_outdim,
            'beta': FLAGS.beta,
            'M_lr': FLAGS.M_lr,
            'Q_num_hidden_layers': FLAGS.Q_num_hidden_layers,
            'Q_hidden_dim': FLAGS.Q_hidden_dim,
            'Q_act_function': FLAGS.Q_act_function,
            'Q_loss_function': FLAGS.Q_loss_function,
            'Q_reset_opt_freq': FLAGS.Q_reset_opt_freq,
            'Q_adam_beta': FLAGS.Q_adam_beta,
            'Q_use_target_net': FLAGS.Q_use_target_net,
            'Q_soft_update_tau': FLAGS.Q_soft_update_tau,
            'Q_hard_update_freq': FLAGS.Q_hard_update_freq,
            'Q_norm_type': FLAGS.Q_norm_type,
            'Q_target_update_type': FLAGS.Q_target_update_type,
            'bcrl_norm_selfpred': FLAGS.bcrl_norm_selfpred,
            'phi_hard_update_freq': FLAGS.phi_hard_update_freq,
            'aux_alpha': FLAGS.aux_alpha,
            'bc_measure_logdet': FLAGS.bc_measure_logdet
        },
        'oracle_est': oracle_est_ret,
        'rand_est': rand_est_ret,
        'gamma': FLAGS.gamma,
        'normalize_states': FLAGS.normalize_states,
        'normalize_rewards': FLAGS.normalize_rewards
    }

    # The per-algorithm entry is final_phi_stats itself, extended in place.
    algo_result = final_phi_stats
    summary['results'][FLAGS.dataset_name][algo_name] = algo_result
    algo_result.update({
        'oracle_est_ret': oracle_est_ret,
        'rand_est_ret': rand_est_ret,
        'ope_error': final_ope_err,
        'ope_est': final_ope_est,
        'tr_losses': phi_metrics['phi_tr_loss'][FLAGS.phi_epochs] if 'phi_tr_loss' in phi_metrics else 0,
        # Previously `len(...) in phi_metrics`, which tested an int against the
        # dict keys and raised KeyError when 'phi_brm_loss' was absent.
        'brm_losses': phi_metrics['phi_brm_loss'][FLAGS.phi_epochs] if 'phi_brm_loss' in phi_metrics else 0,
        'phi_tr_metrics': phi_metrics,
        'q_values': {
            'q_csa': q_csa,
            'q_target': q_target
        },
    })
    print (summary)
    np.save(FLAGS.outfile, summary)
