import torch
import numpy as np
import pdb
import argparse
import random
import yaml

import gymnasium as g
import estimators
import utils
from learn_phi import OffPolicySA, BCRL, KernelROPE, FQE, FQEAux, RandomFeatures, Krylov,\
    Schur, ConstantFeatures, KernelLihong, QpieBisim, PCAFeatures, TCL, Recon, RandomMatrixProjection
from fqe import ContinuousPrePhiFQE, ContinuousPrePhiLSPE, ContinuousAuxPhiFQE,\
    ContinuousPrePhiLSTD, ContinuousPrePhiTD, ContinuousPrePhiLinearFQE
from behavior_dataset import Dataset, PWDataset

def generate_dataset_objs(FLAGS, dataset, pie, tabular = False):
    """Wrap raw behavior data into the dataset objects used for OPE and encoder training.

    Args:
        FLAGS: experiment configuration; reads ``pw_dataset``, ``normalize_states``,
            ``normalize_rewards``, ``normalize_state_actions``, ``skip_rate`` and
            (non-tabular only) ``tr_set_fraction``.
        dataset: raw behavior data, passed unchanged to ``PWDataset``/``Dataset`` and
            to ``utils.split_dataset``.
        pie: evaluation policy, forwarded to every dataset object.
        tabular: when True, no train/test split is made.

    Returns:
        ``(ope_data, tr_ope_data, test_ope_data)``:
        ``ope_data`` wraps all of ``dataset`` (a ``PWDataset`` if ``FLAGS.pw_dataset``,
        otherwise a ``Dataset``). In the tabular case ``tr_ope_data`` is that same
        object and ``test_ope_data`` is None. Otherwise both come from
        ``utils.split_dataset``, and ``test_ope_data`` is None when the split returns
        no test portion.
    """
    # All of the off-policy data, used for the final OPE run.
    full_kwargs = dict(normalize_states = FLAGS.normalize_states,
                       normalize_rewards = FLAGS.normalize_rewards,
                       normalize_state_actions = FLAGS.normalize_state_actions,
                       pie = pie,
                       tabular = tabular)
    if not FLAGS.pw_dataset:
        ope_data = Dataset(dataset, skip_rate = FLAGS.skip_rate, **full_kwargs)
    else:
        ope_data = PWDataset(dataset, **full_kwargs)

    if tabular:
        # Tabular setting: the encoder trains on everything, there is no held-out set.
        return ope_data, ope_data, None

    # Split the raw data for encoder training / validation.
    # NOTE(review): these splits are always plain `Dataset`s and are not given
    # `normalize_state_actions`, unlike `ope_data` above — confirm this is intended.
    raw_train, raw_test = utils.split_dataset(dataset, tr_set_fraction = FLAGS.tr_set_fraction)
    split_kwargs = dict(normalize_states = FLAGS.normalize_states,
                        normalize_rewards = FLAGS.normalize_rewards,
                        pie = pie,
                        skip_rate = FLAGS.skip_rate)
    tr_ope_data = Dataset(raw_train, **split_kwargs)
    test_ope_data = None if raw_test is None else Dataset(raw_test, **split_kwargs)
    return ope_data, tr_ope_data, test_ope_data

def train_encoder(FLAGS, tr_data, test_data, mdp, gamma, encoder_name = None, pie = None, tabular = False, q_pie_values = None, ope_evaluator = None):
    """Build the feature map (phi) named by ``encoder_name`` and train it when it is learned.

    Learned encoders (``rope``, ``krope``, ``bcrl*`` without ``fqe``, ``fqe``,
    ``ktdloss``, ``fqeaux-*``, ``tcl``, ``recon``, ``rmp``) are constructed from
    ``FLAGS`` hyper-parameters and trained via ``enc.train`` on ``tr_data`` /
    ``test_data``. Fixed encoders (``target-phi*``, ``random``, ``constant``,
    ``krylov*``, ``schur``, ``qpie``, ``pca``) are constructed and returned without
    training.

    Args:
        FLAGS: experiment configuration (``phi_*``, ``M_lr``, ``bcrl_*``, ``beta``,
            ``rep_loss_function``, ``image_state``, ``aux_alpha``, ``krope_*``,
            ``ope_method``, ``mini_batch_size`` ... are read).
        tr_data, test_data: datasets passed to ``enc.train`` (learned encoders only).
        mdp: environment; ``mdp.observation_space.shape[0]`` and
            ``mdp.action_space.shape[0]`` size the learned encoders.
        gamma: discount factor.
        encoder_name: which encoder to build (see above). ``bcrl-X`` and
            ``fqeaux-X`` encode a sub-variant X after the prefix.
        pie: evaluation policy.
        tabular: forwarded to the encoder constructors that accept it.
        q_pie_values: only used by the ``qpie`` encoder.
        ope_evaluator: forwarded to ``enc.train``.

    Returns:
        ``(phi, metrics, critic)``. For fixed encoders ``phi`` is the encoder object
        itself, ``metrics`` is ``{}`` and ``critic`` is None. For learned encoders
        they come from ``enc.get_phi()``, ``enc.get_metrics()`` and ``enc.get_critic()``.

    Raises:
        ValueError: if ``encoder_name`` is None or matches no known encoder.
    """
    if encoder_name is None:
        # Previously surfaced as a TypeError from `'bcrl' in None`.
        raise ValueError('encoder_name must be provided')

    # Each branch decides whether its encoder is learned. This replaces a separate
    # substring skip-list that could disagree with the branch actually taken
    # (e.g. 'fqeaux-<x>' where <x> contains 'random'), which left `phi` unbound.
    trainable = True
    if 'rope' == encoder_name:
        enc = OffPolicySA(ground_state_dims = mdp.observation_space.shape[0], abs_state_action_dims = FLAGS.phi_outdim, action_dims = mdp.action_space.shape[0],\
                                hidden_dim = FLAGS.phi_hidden_dim, hidden_layers = FLAGS.phi_num_hidden_layers, activation = FLAGS.phi_act_function, final_activation = None,\
                                lr = FLAGS.phi_lr, gamma = gamma, pie = pie, image_state = FLAGS.image_state, tabular = tabular,\
                                mdp = mdp, beta = FLAGS.beta, loss_function = FLAGS.rep_loss_function, norm_type = FLAGS.phi_norm_type)
    elif 'krope' == encoder_name:
        enc = KernelROPE(ground_state_dims = mdp.observation_space.shape[0], abs_state_action_dims = FLAGS.phi_outdim, action_dims = mdp.action_space.shape[0],\
                                hidden_dim = FLAGS.phi_hidden_dim, hidden_layers = FLAGS.phi_num_hidden_layers, activation = FLAGS.phi_act_function, final_activation = None,\
                                lr = FLAGS.phi_lr, gamma = gamma, pie = pie, image_state = FLAGS.image_state, mdp = mdp, tabular = tabular,\
                                norm_type = FLAGS.phi_norm_type, soft_update_tau = FLAGS.phi_soft_update_tau, hard_update_freq = FLAGS.phi_hard_update_freq)
    elif 'bcrl' in encoder_name and 'fqe' not in encoder_name:
        bcrl_type = 'both'
        if 'rew' in encoder_name or 'lat' in encoder_name:
            bcrl_type = encoder_name[5:] # assumes naming is bcrl-X
        enc = BCRL(ground_state_dims = mdp.observation_space.shape[0], abs_state_action_dims = FLAGS.phi_outdim, action_dims = mdp.action_space.shape[0], bcrl_type = bcrl_type,\
                                hidden_dim = FLAGS.phi_hidden_dim, hidden_layers = FLAGS.phi_num_hidden_layers, activation = FLAGS.phi_act_function, final_activation = None,\
                                phi_lr = FLAGS.phi_lr, M_lr = FLAGS.M_lr, gamma = gamma, pie = pie, mdp = mdp, tabular = tabular,\
                                norm_type = FLAGS.phi_norm_type, logdet_coeff = FLAGS.bcrl_logdet, norm_selfpred = FLAGS.bcrl_norm_selfpred, soft_update_tau = FLAGS.phi_soft_update_tau,\
                                hard_update_freq = FLAGS.phi_hard_update_freq)
    elif 'fqe' == encoder_name:
        enc = FQE(ground_state_dims = mdp.observation_space.shape[0], abs_state_action_dims = FLAGS.phi_outdim, action_dims = mdp.action_space.shape[0],\
                                hidden_dim = FLAGS.phi_hidden_dim, hidden_layers = FLAGS.phi_num_hidden_layers, activation = FLAGS.phi_act_function, final_activation = None,\
                                lr = FLAGS.phi_lr, gamma = gamma, pie = pie, image_state = FLAGS.image_state, mdp = mdp, tabular = tabular,\
                                norm_type = FLAGS.phi_norm_type, soft_update_tau = FLAGS.phi_soft_update_tau, hard_update_freq = FLAGS.phi_hard_update_freq,\
                                use_penultimate = FLAGS.phi_use_penultimate, ope_method = FLAGS.ope_method)
    elif 'ktdloss' == encoder_name:
        enc = KernelLihong(ground_state_dims = mdp.observation_space.shape[0], abs_state_action_dims = FLAGS.phi_outdim, action_dims = mdp.action_space.shape[0],\
                                hidden_dim = FLAGS.phi_hidden_dim, hidden_layers = FLAGS.phi_num_hidden_layers, activation = FLAGS.phi_act_function, final_activation = None,\
                                lr = FLAGS.phi_lr, gamma = gamma, pie = pie, image_state = FLAGS.image_state, mdp = mdp, tabular = tabular,\
                                norm_type = FLAGS.phi_norm_type, soft_update_tau = FLAGS.phi_soft_update_tau, hard_update_freq = FLAGS.phi_hard_update_freq,\
                                use_penultimate = FLAGS.phi_use_penultimate)
    elif 'fqeaux' in encoder_name:
        aux_task = encoder_name[7:] # assumes naming is fqeaux-X
        enc = FQEAux(ground_state_dims = mdp.observation_space.shape[0], abs_state_action_dims = FLAGS.phi_outdim, action_dims = mdp.action_space.shape[0],\
                                hidden_dim = FLAGS.phi_hidden_dim, hidden_layers = FLAGS.phi_num_hidden_layers, activation = FLAGS.phi_act_function, final_activation = None,\
                                lr = FLAGS.phi_lr, M_lr = FLAGS.M_lr, logdet_coeff = FLAGS.bcrl_logdet, gamma = gamma, pie = pie, image_state = FLAGS.image_state, mdp = mdp, tabular = tabular,\
                                norm_type = FLAGS.phi_norm_type, soft_update_tau = FLAGS.phi_soft_update_tau, hard_update_freq = FLAGS.phi_hard_update_freq,\
                                use_penultimate = FLAGS.phi_use_penultimate, aux_task = aux_task, aux_alpha = FLAGS.aux_alpha, norm_selfpred = FLAGS.bcrl_norm_selfpred,\
                                krope_kernel = FLAGS.krope_kernel, krope_sigma = FLAGS.krope_sigma, ope_method = FLAGS.ope_method)
    elif 'tcl' == encoder_name:
        enc = TCL(ground_state_dims = mdp.observation_space.shape[0], abs_state_action_dims = FLAGS.phi_outdim, action_dims = mdp.action_space.shape[0],\
                                hidden_dim = FLAGS.phi_hidden_dim, hidden_layers = FLAGS.phi_num_hidden_layers, activation = FLAGS.phi_act_function, final_activation = None,\
                                lr = FLAGS.phi_lr, M_lr = FLAGS.M_lr, gamma = gamma, pie = pie, image_state = FLAGS.image_state, mdp = mdp, tabular = tabular,\
                                norm_type = FLAGS.phi_norm_type, soft_update_tau = FLAGS.phi_soft_update_tau, hard_update_freq = FLAGS.phi_hard_update_freq)
    elif 'recon' == encoder_name:
        enc = Recon(ground_state_dims = mdp.observation_space.shape[0], abs_state_action_dims = FLAGS.phi_outdim, action_dims = mdp.action_space.shape[0],\
                                hidden_dim = FLAGS.phi_hidden_dim, hidden_layers = FLAGS.phi_num_hidden_layers, activation = FLAGS.phi_act_function, final_activation = None,\
                                lr = FLAGS.phi_lr, M_lr = FLAGS.M_lr, gamma = gamma, pie = pie, image_state = FLAGS.image_state, mdp = mdp, tabular = tabular,\
                                norm_type = FLAGS.phi_norm_type, soft_update_tau = FLAGS.phi_soft_update_tau, hard_update_freq = FLAGS.phi_hard_update_freq)
    elif 'rmp' == encoder_name:
        enc = RandomMatrixProjection(ground_state_dims = mdp.observation_space.shape[0], abs_state_action_dims = FLAGS.phi_outdim, action_dims = mdp.action_space.shape[0],\
                                gamma = gamma, pie = pie, mdp = mdp, tabular = tabular)
    elif 'target-phi' in encoder_name:
        trainable = False
        if 'sa' in encoder_name:
            enc = pie.pi.critic_target.qf0[0:-1] # skip last layer
        else:
            enc = pie.policy.value_net # TODO
    elif 'random' in encoder_name:
        trainable = False
        enc = RandomFeatures(mdp = mdp,  phi_outdim = FLAGS.phi_outdim)
    elif 'constant' in encoder_name:
        trainable = False
        enc = ConstantFeatures(mdp = mdp,  phi_outdim = FLAGS.phi_outdim)
    elif 'krylov' in encoder_name:
        trainable = False
        enc = Krylov(mdp = mdp, phi_outdim = FLAGS.phi_outdim, pi = pie, orthogonal = 'ortho' in encoder_name)
    elif 'schur' in encoder_name:
        trainable = False
        enc = Schur(mdp = mdp, phi_outdim = FLAGS.phi_outdim, pi = pie)
    elif 'qpie' in encoder_name:
        trainable = False
        enc = QpieBisim(mdp = mdp, q_pie_values = q_pie_values, pi = pie, eps = FLAGS.phi_qpie_eps)
    elif 'pca' in encoder_name:
        trainable = False
        enc = PCAFeatures(mdp = mdp,  phi_outdim = FLAGS.phi_outdim)
    else:
        raise ValueError('unknown encoder_name: {!r}'.format(encoder_name))

    if not trainable:
        # Fixed encoders are used directly as the feature map: no metrics, no critic.
        return enc, {}, None

    ope_algo = get_ope_algo(FLAGS, gamma, mdp, pie, phi = enc.phi, enc_name = encoder_name, tabular = tabular)
    # train encoder
    print ('training encoder')
    enc.train(tr_data, test_data, epochs = FLAGS.phi_epochs,\
        mini_batch_size = FLAGS.mini_batch_size,\
        ope_algo = ope_algo, ope_evaluator = ope_evaluator)
    phi = enc.get_phi()
    metrics = enc.get_metrics()
    critic = enc.get_critic()
    return phi, metrics, critic

def get_ope_algo(FLAGS, gamma, mdp, pie, phi = None, enc_name = None, tabular = False):
    """Build the OPE algorithm that an encoder uses during its own training.

    Side effect: for ``enc_name == 'target-phi-sa'`` ``FLAGS.phi_outdim`` is set to
    ``phi[-2].out_features``; for ``enc_name == 'qpie'`` it is set to
    ``phi.phi_outdim``. The LSPE dimensions below therefore match ``phi``.

    Returns:
        A ``ContinuousPrePhiLSPE`` over state-action features when
        ``FLAGS.ope_method == 'lspe'``, otherwise ``None`` (the linear-FQE
        alternative is currently disabled).
    """
    if enc_name == 'target-phi-sa':
        FLAGS.phi_outdim = phi[-2].out_features
    elif enc_name == 'qpie':
        FLAGS.phi_outdim = phi.phi_outdim

    if FLAGS.ope_method != 'lspe':
        return None

    n_state_dims = mdp.observation_space.shape[0]
    n_action_dims = mdp.action_space.shape[0]
    feature_dim = FLAGS.phi_outdim
    return ContinuousPrePhiLSPE(state_dims = n_state_dims,
                                action_dims = n_action_dims,
                                gamma = gamma,
                                pie = pie,
                                phi = phi,
                                abs_state_dims = feature_dim,
                                image_state = FLAGS.image_state,
                                abs_state_action_dim = feature_dim,
                                sa_phi = True,
                                clip_target = FLAGS.clip_target,
                                tabular = tabular)

def run_experiment_ope(FLAGS, ope_method, ope_data, gamma, mdp, pie, phi = None, enc_name = None, tabular = False):
    """Fit LSPE on top of a fixed feature map ``phi`` and return its training metrics.

    Args:
        FLAGS: experiment configuration; reads ``phi_outdim``, ``image_state``,
            ``clip_target``, ``epochs`` and ``print_log``. ``FLAGS.phi_outdim`` may be
            overwritten (see below).
        ope_method: currently unused. LSPE is always run; the argument is kept for
            API compatibility.
        ope_data: dataset passed to ``ContinuousPrePhiLSPE.train``.
        gamma: discount factor.
        mdp: environment; ``observation_space.shape[0]`` and
            ``action_space.shape[0]`` give the ground dimensions.
        pie: evaluation policy.
        phi: feature map that LSPE is fitted on (treated as a state-action feature map).
        enc_name: encoder name. For ``'target-phi-sa'`` and ``'qpie'``,
            ``FLAGS.phi_outdim`` is synced to ``phi``'s output width, as in
            ``get_ope_algo``.
        tabular: forwarded to ``ContinuousPrePhiLSPE``.

    Returns:
        ``(-1, metrics)``. The first element is a placeholder ``-1`` because the
        LSPE value estimate is not computed here. ``metrics`` is
        ``abs_lspe.get_metrics()``.
    """
    # Keep the LSPE feature width equal to phi's output width. The target-phi-sa
    # case had been commented out here, although get_ope_algo performs the same sync.
    if enc_name == 'target-phi-sa':
        FLAGS.phi_outdim = phi[-2].out_features
    elif enc_name == 'qpie':
        FLAGS.phi_outdim = phi.phi_outdim

    sa_phi = True
    abs_lspe = ContinuousPrePhiLSPE(state_dims = mdp.observation_space.shape[0], action_dims = mdp.action_space.shape[0],
                            gamma = gamma, pie = pie,\
                            phi = phi, abs_state_dims = FLAGS.phi_outdim,\
                            image_state = FLAGS.image_state,
                            abs_state_action_dim = FLAGS.phi_outdim,
                            sa_phi = sa_phi,
                            clip_target = FLAGS.clip_target, tabular = tabular)
    abs_lspe.train(ope_data, epochs = FLAGS.epochs, print_log = FLAGS.print_log)
    metrics = abs_lspe.get_metrics()
    # NOTE(review): the estimate is not computed; -1 is a sentinel, not a value estimate.
    return -1, metrics