import torch
from torch import nn
import numpy as np
from policies import NeuralNetwork, CustomID, NNwLinear
import copy
import time
import estimators
import pdb
from utils import soft_target_update, get_phi_stats,\
    load_mini_batches, compute_inverse, reset_optimizer, PhiMetricTracker, get_err, clip_target

# Globally enable autograd anomaly detection: backward() raises a RuntimeError
# pointing at the forward op that produced a NaN/inf gradient.
# NOTE(review): this is a debugging aid with noticeable runtime overhead, and it
# applies process-wide as soon as this module is imported.
torch.autograd.set_detect_anomaly(mode = True)

class ContinuousPrePhiFQE:
    """Fitted Q evaluation (FQE) of policy ``pie`` on top of a fixed encoder ``phi``.

    An MLP ``Q`` is regressed onto the TD target
    ``r + gamma * terminal_mask * Q_target(phi(next_sa))``. When no ``phi`` is given,
    an identity map (``CustomID``) is used instead. Only ``Q``'s parameters are handed
    to the optimizer, so ``phi`` itself is never updated by this class.
    """
    def __init__(self,
        state_dims,
        action_dims,
        gamma,
        pie,
        abs_state_dims = None,
        abs_state_action_dim = None,
        q_hidden_dim = 32,
        q_hidden_layers = 1,
        activation = 'relu',
        Q_lr = 1e-5,
        q_reg = 0,
        phi = None,
        image_state = False,
        sa_phi = False,
        clip_target = False,
        tabular = False,
        loss_function = 'mse',
        reset_opt_freq = -1,
        adam_beta = -1,
        use_target_net = True,
        soft_update_tau = 5e-3,
        hard_update_freq = 5000,
        norm_type = 'none',
        target_update_type = 'soft'):
        """Build the Q network, its target copy and the Adam optimizer.

        Args:
            state_dims: dimensionality of ground states.
            action_dims: dimensionality of actions.
            gamma: discount factor.
            pie: evaluation policy; must expose ``batch_sample`` (continuous case)
                or ``sample_sa_features`` (tabular case).
            abs_state_dims: output width of a state-only ``phi`` (``sa_phi=False``).
            abs_state_action_dim: output width of a state-action ``phi`` (``sa_phi=True``).
            q_hidden_dim, q_hidden_layers, activation, norm_type: Q-network MLP settings.
            Q_lr: Adam learning rate.
            q_reg: stored only; not used by this class.
            phi: optional pre-trained encoder, switched to eval mode via ``train(False)``.
            image_state: if True and ``phi`` is None, construction stops early and
                no Q network or optimizer is built.
            clip_target: clip TD targets to ``[min_reward, max_reward] / (1 - gamma)``.
            tabular: inputs are tabular state-action features.
            loss_function: 'mse' or 'huber'.
            reset_opt_freq: reset Adam every this many epochs; -1 disables.
            adam_beta: -1 uses betas (0.9, 0.999); otherwise both betas = adam_beta.
            use_target_net: bootstrap from a separate target network.
            soft_update_tau, hard_update_freq, target_update_type: target-network
                settings. NOTE: ``train`` currently always applies a soft update with
                ``soft_update_tau``; ``hard_update_freq``/``target_update_type`` are
                only reported in the config printout.
        """

        self.state_dims = state_dims
        self.action_dims = action_dims
        self.abs_state_dims = abs_state_dims
        self.pie = pie
        self.gamma = gamma
        self.q_hidden_dim = q_hidden_dim
        self.q_hidden_layers = q_hidden_layers
        self.activation = activation
        self.Q_lr = Q_lr
        self.q_reg = q_reg
        self.sa_phi = sa_phi
        self.clip_target = clip_target
        self.tabular = tabular

        self.phi_defined = phi is not None

        # Q input width: state_dims * action_dims features in the tabular case,
        # [state, action] concatenation otherwise; replaced by phi's output width.
        q_input_dim = state_dims * action_dims if self.tabular else state_dims + action_dims
        if self.phi_defined:
            phi = phi.train(False)
            q_input_dim = abs_state_action_dim if self.sa_phi else abs_state_dims + action_dims
        q_output_dim = 1

        self.phi = phi if self.phi_defined else CustomID()

        if phi is None and image_state:
            # raw images without an encoder: nothing else is built
            return
        else:
            # if phi is defined, the MLP takes phi's output as input;
            # if phi is not defined (non-image inputs), the MLP takes raw features
            self.Q = NeuralNetwork(input_dims = q_input_dim,
                                output_dims = q_output_dim,
                                hidden_dim = q_hidden_dim,
                                hidden_layers = q_hidden_layers,
                                activation = self.activation,
                                norm_type = norm_type)

        self.reset_opt_freq = reset_opt_freq
        if adam_beta == -1:
            self.adam_beta1 = 0.9
            self.adam_beta2 = 0.999
        else:
            self.adam_beta1 = self.adam_beta2 = adam_beta
        self.Q_optimizer = torch.optim.Adam(self.Q.parameters(), lr = self.Q_lr, betas = (self.adam_beta1, self.adam_beta2))
        self.target_Q = copy.deepcopy(self.Q)
        self.loss_function = loss_function
        self.use_target_net = use_target_net
        self.soft_update_tau = soft_update_tau
        self.hard_update_freq = hard_update_freq
        self.target_update_type = target_update_type

        # HACK: without a target network, target_Q is still what train() returns
        # as best_Q, so make every update copy Q into it completely. target_Q is
        # then not used for bootstrapping, only as a holder for the returned Q.
        if not self.use_target_net:
            self.soft_update_tau = 1.
            self.hard_update_freq = 1.

        q_config = f"""
        Q network config\n
        loss function: {self.loss_function}\n
        use target net: {self.use_target_net}\n
        target update type: {target_update_type}\n
        soft target update: {self.soft_update_tau}\n
        hard target update freq: {self.hard_update_freq}\n
        reset adam: {self.reset_opt_freq != -1}\n
        adam betas: {self.adam_beta1, self.adam_beta2}\n
        normalization type: {norm_type}\n
        """
        print (q_config, self.Q)

    def train(self, data, epochs = 2000, print_log = True):
        """Run FQE for ``epochs`` passes over ``data``.

        Epoch 0 only evaluates the untrained network (no gradient steps). Epochs
        1..``epochs`` take one Adam step per mini-batch, each followed by a soft
        target update. After every epoch the mean Q value over initial state-action
        inputs is stored in ``self.disc_ret_ests[epoch]``.

        Args:
            data: torch ``Dataset`` whose items are dicts with keys 'curr_sa',
                'next_sa', 'rewards', 'terminal_masks' and 'next_state'.
            epochs: number of training epochs after the evaluation-only epoch 0.
            print_log: unused; kept for interface compatibility.
        """
        mini_batch_size = 2048
        num_workers = 6
        self.pmt = PhiMetricTracker(self.pie, self.gamma)
        self.best_Q = copy.deepcopy(self.target_Q)

        params = {'batch_size': mini_batch_size, 'shuffle': True, 'num_workers': num_workers}
        dataloader = torch.utils.data.DataLoader(data, **params)

        disc_ret_ests = {}

        for epoch in range(0, epochs + 1):
            for mb in dataloader:
                curr_sa = mb['curr_sa']
                next_sa = mb['next_sa']
                rewards = mb['rewards']
                terminal_masks = mb['terminal_masks']
                next_states = mb['next_state']

                obj, frac = self._td_loss(data, curr_sa, next_sa, next_states, rewards, terminal_masks)

                objs = {'total_obj': obj}
                if epoch > 0:
                    # clear gradients
                    self.Q_optimizer.zero_grad()

                    # compute gradients; autograd failures (e.g. from anomaly
                    # detection) abort the rest of this epoch
                    try:
                        obj.backward()
                    except RuntimeError as err:
                        print ('error on backprop, breaking', err)
                        break

                    # processing gradients
                    nn.utils.clip_grad_value_(self.Q.parameters(), clip_value = 1.0)

                    # gradient step
                    self.Q_optimizer.step()

                    soft_target_update(self.Q, self.target_Q, tau = self.soft_update_tau)
            if self.reset_opt_freq != -1 and epoch % self.reset_opt_freq == 0:
                self.Q_optimizer = reset_optimizer(self.Q, self.Q_lr, self.adam_beta1, self.adam_beta2)

            # NOTE: target_update_type == 'hard' (copy Q into target_Q every
            # hard_update_freq epochs) is not implemented; the soft update above
            # is always applied.

            # best_Q aliases target_Q (no copy)
            self.best_Q = self.target_Q

            q_init_inputs = self._get_init_state_actions(data)
            # stats use the last mini-batch of this epoch
            trans_feats_info = self._get_penultimate_feat_stats(epoch, objs, data, curr_sa, next_states)

            ret = np.mean(self.best_Q(q_init_inputs))
            disc_ret_ests[epoch] = ret

            print ('epoch {}, disc ret: {}, feats info: {}, frac: {}'\
                .format(epoch, ret, trans_feats_info, frac))

        self.disc_ret_ests = disc_ret_ests

    def _td_loss(self, data, curr_sa, next_sa, next_s, rews, term_masks):
        """Return ``(loss, frac)`` for one mini-batch of transitions.

        ``frac`` is the fraction of samples whose absolute TD error is below 1.
        In the tabular case, next features are ``phi`` averaged over 10 samples of
        ``pie.sample_sa_features(next_s)``; otherwise ``next_sa`` is encoded
        directly. ``data`` is only read when ``clip_target`` is set.

        Raises:
            ValueError: if ``self.loss_function`` is neither 'huber' nor 'mse'.
        """
        curr_sa = self.phi(curr_sa)
        if self.tabular:
            next_sa = 0
            num_samps = 10
            for _ in range(num_samps):
                res = self.pie.sample_sa_features(next_s)
                next_sa += self.phi(res)
            next_sa = next_sa / num_samps
        else:
            next_sa = self.phi(next_sa)

        if self.use_target_net:
            q_curr_outputs = self.Q.forward(curr_sa).reshape(-1)
            with torch.no_grad():
                q_next_outputs = self.target_Q(next_sa).reshape(-1)
        else:
            # Single forward pass over current and next inputs so batch statistics
            # (if Q uses batch norm) cover both. torch.cat keeps this in torch:
            # np.vstack fails on tensors that require grad and hands Q an ndarray.
            combined_sa_nsa = torch.cat((torch.as_tensor(curr_sa), torch.as_tensor(next_sa)), dim = 0)
            comb_output = self.Q.forward(combined_sa_nsa)
            q_curr_outputs = comb_output[:curr_sa.shape[0], :].reshape(-1)
            q_next_outputs = comb_output[curr_sa.shape[0]:, :].detach().reshape(-1) # stop gradient

        target = torch.Tensor(rews + self.gamma * term_masks * q_next_outputs)

        if self.clip_target:
            target = torch.clip(target,\
                min = data.min_reward / (1. - self.gamma),\
                max = data.max_reward / (1. - self.gamma))

        frac = torch.count_nonzero(torch.abs(q_curr_outputs - target) < 1) / q_curr_outputs.shape[0]

        if self.loss_function == 'huber':
            obj = torch.nn.functional.huber_loss(q_curr_outputs, target)
        elif self.loss_function == 'mse':
            obj = torch.nn.functional.mse_loss(q_curr_outputs, target)
        else:
            raise ValueError(
                f"unsupported loss_function '{self.loss_function}'; expected 'huber' or 'mse'")
        return obj, frac

    def _get_init_state_actions(self, data):
        """Return Q-network inputs for the initial-state distribution.

        Tabular: ``phi`` averaged over 100 draws of
        ``pie.sample_sa_features(data.initial_states)``. Continuous: initial states
        from ``data.get_initial_states_samples(-1)`` concatenated with actions that
        ``pie`` samples on the unnormalized states, then encoded by ``phi``.
        """
        if self.tabular:
            q_init_inputs = 0
            num_samps = 100
            for _ in range(num_samps):
                init_sa = self.pie.sample_sa_features(data.initial_states)
                q_init_inputs += self.phi(init_sa)
            q_init_inputs = q_init_inputs / num_samps
        else:
            init_ground_states = torch.Tensor(data.get_initial_states_samples(-1))
            sampled_actions = torch.Tensor(self.pie.batch_sample(data.unnormalize_states(init_ground_states)))
            q_init_inputs = torch.concat((init_ground_states, sampled_actions), axis = 1)
            # self.phi is never None (CustomID fallback), so this always encodes
            if self.phi is not None:
                q_init_inputs = torch.Tensor(self.phi(q_init_inputs))
        return q_init_inputs

    def _get_penultimate_feat_stats(self, epoch, obj, data, curr_sa, next_states):
        """Return ``self.pmt.track_phi_training_stats`` for a mini-batch (no grad).

        ``curr_sa``/``next_states`` are raw mini-batch inputs; next state-action
        pairs are formed with actions sampled from ``pie``.
        """
        with torch.no_grad():
            curr_sa = ground_curr_sa = torch.Tensor(curr_sa)
            if self.tabular:
                curr_sa = torch.Tensor(self.phi(curr_sa))
                next_sa = 0
                num_samps = 10
                for _ in range(num_samps):
                    res = self.pie.sample_sa_features(next_states)
                    next_sa += self.phi(res)
                next_sa = ground_next_sa = torch.Tensor(next_sa / num_samps)
            else:
                next_sampled_acts = torch.Tensor(self.pie.batch_sample(data.unnormalize_states(next_states)))
                next_sa = ground_next_sa = torch.concat((next_states, next_sampled_acts), axis = 1)
                if self.phi is not None:
                    curr_sa = torch.Tensor(self.phi(curr_sa))
                    next_sa = torch.Tensor(self.phi(next_sa))

            # NOTE(review): these penultimate features are computed but not used below
            curr_sa_feats = self.best_Q.penultimate(curr_sa).numpy()
            next_sa_feats = self.best_Q.penultimate(next_sa).numpy()

            trans_feats_info = self.pmt.track_phi_training_stats(epoch, self.phi,\
                obj, ground_curr_sa.numpy(), ground_next_sa.numpy(), pie = self.pie, fqe = True, fqe_net = self.target_Q)
        return trans_feats_info

    def get_metrics(self):
        """Return per-epoch return estimates merged with the tracker's metrics."""
        metrics = {
            'r_ests': self.disc_ret_ests,
        }
        metrics.update(self.pmt.metrics)
        return metrics

    def get_Q(self):
        """Return the final Q network (an alias of the target network)."""
        return self.best_Q

    def get_phi(self):
        """Return the encoder (``CustomID`` if none was given)."""
        return self.phi

class ContinuousAuxPhiFQE:
    """FQE whose Q network is trained jointly with an auxiliary KSME representation loss.

    The objective is ``0.5 * huber(TD error) + 0.5 * rep_loss``. ``rep_loss``
    regresses ``|<rep(x), rep(y)>|`` for pairs of transitions onto
    ``1 - |r_x - r_y| / rew_range + gamma * |<rep(x'), rep(y')>|``. Here ``rep`` is
    the Q network's own ``get_representation``.
    """
    def __init__(self,
        state_dims,
        action_dims,
        gamma,
        pie,
        abs_state_dims = None,
        abs_state_action_dim = None,
        q_hidden_dim = 32,
        q_hidden_layers = 1,
        activation = 'relu',
        Q_lr = 1e-5,
        q_reg = 0,
        image_state = False,
        sa_phi = False,
        clip_target = False):
        """Build the Q network, its target copy and an AdamW optimizer.

        If ``image_state`` is True, construction stops early and no network is built.
        ``abs_state_dims``, ``abs_state_action_dim``, ``q_reg`` and ``sa_phi`` are
        stored or accepted but not otherwise used here.
        """

        self.state_dims = state_dims
        self.action_dims = action_dims
        self.abs_state_dims = abs_state_dims
        self.pie = pie
        self.gamma = gamma
        self.q_hidden_dim = q_hidden_dim
        self.q_hidden_layers = q_hidden_layers
        self.activation = activation
        self.Q_lr = Q_lr
        self.q_reg = q_reg
        self.sa_phi = sa_phi
        self.clip_target = clip_target

        self.tabular = False
        # The representation lives inside Q (get_representation); there is no
        # separate encoder, so get_phi() returns None.
        self.phi = None

        q_input_dim = state_dims * action_dims if self.tabular else state_dims + action_dims
        q_output_dim = 1

        if image_state:
            return
        else:
            self.Q = NeuralNetwork(input_dims = q_input_dim,
                                output_dims = q_output_dim,
                                hidden_dim = q_hidden_dim,
                                hidden_layers = q_hidden_layers,
                                activation = self.activation)

        self.Q_optimizer = torch.optim.AdamW(self.Q.parameters(), lr = self.Q_lr)#, weight_decay = 1e-5)
        self.target_Q = copy.deepcopy(self.Q)

    def train(self, data, epochs = 2000, print_log = True):
        """Take ``epochs`` gradient steps on sampled mini-batches of size 512.

        Every 1000 epochs (and on the first and last) the mean Q value at initial
        state-action pairs is logged into ``self.disc_ret_ests`` and the loss into
        ``self.tr_losses``.

        Args:
            data: dataset exposing ``max_abs_reward_diff``, ``get_samples``,
                ``get_initial_states_samples`` and ``unnormalize_states``; also
                read by ``load_mini_batches``.
            epochs: number of gradient steps.
            print_log: unused; kept for interface compatibility.
        """
        mini_batch_size = 512
        min_obj = float('inf')
        best_epoch = -1

        self.rew_range = data.max_abs_reward_diff
        self.best_Q = copy.deepcopy(self.target_Q)
        disc_ret_ests = {}
        tr_losses = {}
        feat_dotprods = {}
        feat_dotprod_diffs = {}

        for epoch in range(1, epochs + 1):
            # NOTE(review): result unused; kept in case get_samples advances the
            # dataset's sampling state — confirm and drop if it does not.
            sub_data = data.get_samples(mini_batch_size)
            mini_batch = load_mini_batches(data, False, self.pie, mini_batch_size)
            curr_sa = mini_batch['curr_sa']
            next_sa = mini_batch['next_sa']
            rewards = mini_batch['rewards']
            terminal_masks = mini_batch['terminal_masks']

            other_curr_sa = mini_batch['other_curr_sa']
            other_next_sa = mini_batch['other_next_sa']
            other_rewards = mini_batch['other_rewards']
            other_terminal_masks = mini_batch['other_terminal_masks']

            objs = self._fqe_loss(curr_sa, next_sa, rewards, terminal_masks,\
                                    other_curr_sa, other_next_sa, other_rewards, other_terminal_masks,\
                                    data = data)
            total_obj = objs['total_obj']

            # clear gradients
            self.Q_optimizer.zero_grad()

            # compute gradients
            total_obj.backward()

            # processing gradients
            nn.utils.clip_grad_value_(self.Q.parameters(), clip_value = 1.0)

            # gradient step
            self.Q_optimizer.step()

            # soft target update
            soft_target_update(self.Q, self.target_Q)

            total_obj = objs['total_obj'].item()
            fqe_obj = objs['obj'].item()
            rep_obj = objs['rep_obj'].item()

            # best_Q aliases target_Q (no copy)
            self.best_Q = self.target_Q

            if epoch % 1000 == 0 or epoch == epochs or epoch == 1:
                init_ground_states = torch.Tensor(data.get_initial_states_samples(-1))
                sampled_actions = torch.Tensor(self.pie.batch_sample(data.unnormalize_states(init_ground_states)))
                q_init_inputs = torch.concat((init_ground_states, sampled_actions), axis = 1)

                with torch.no_grad():
                    th_q_init_inputs = torch.Tensor(q_init_inputs)
                    q_feats = self.best_Q.penultimate(th_q_init_inputs).numpy()
                    feats_info = get_phi_stats(q_feats)

                    curr_sa = torch.Tensor(curr_sa)
                    next_sa = torch.Tensor(next_sa)

                    curr_sa_feats = self.best_Q.penultimate(curr_sa).numpy()
                    next_sa_feats = self.best_Q.penultimate(next_sa).numpy()
                    trans_feats_info = get_phi_stats(curr_sa_feats, next_sa_feats, gamma = self.gamma)

                ret = np.mean(self.best_Q(q_init_inputs))
                disc_ret_ests[epoch] = ret
                tr_losses[epoch] = total_obj

                print ('epoch {}, loss {}, disc ret: {}, feats info: {}'\
                    .format(epoch, total_obj, ret, trans_feats_info))

        print ('best epoch {}, obj {}'.format(best_epoch, min_obj))
        self.disc_ret_ests = disc_ret_ests
        self.tr_losses = tr_losses
        # read by get_metrics()
        self.feat_dotprods = feat_dotprods
        self.feat_dotprod_diffs = feat_dotprod_diffs

    def _fqe_loss(self, curr_sa, next_sa, rewards, terminal_masks,\
                other_curr_sa, other_next_sa, other_rewards, other_terminal_masks,\
                data = None):
        """Return a dict with the TD loss ('obj'), KSME loss ('rep_obj') and their mix.

        Args:
            data: only needed when ``clip_target`` is set; must expose
                ``min_reward`` and ``max_reward``.

        Raises:
            ValueError: if ``clip_target`` is set but ``data`` is None.
        """

        pen_sa = self.Q.get_representation(curr_sa)
        q_curr_outputs = self.Q.output[-1].forward(pen_sa).reshape(-1)

        with torch.no_grad():
            pen_nextsa = self.target_Q.get_representation(next_sa)
            q_next_outputs = self.target_Q.output[-1](pen_nextsa).reshape(-1)

        target = torch.Tensor(torch.Tensor(rewards) \
            + self.gamma * torch.Tensor(terminal_masks) * q_next_outputs)

        if self.clip_target:
            if data is None:
                raise ValueError('clip_target requires data (min_reward / max_reward)')
            target = torch.clip(target,\
                min = data.min_reward / (1. - self.gamma),\
                max = data.max_reward / (1. - self.gamma))

        obj = torch.nn.functional.huber_loss(q_curr_outputs, target)

        rep_obj = self._ksme_loss(pen_sa, pen_nextsa, rewards,\
                                other_curr_sa, other_next_sa, other_rewards)
        alpha = 0.5
        total_obj = (1. - alpha) * obj + alpha * rep_obj
        objs = {
            'obj': obj,
            'rep_obj': rep_obj,
            'total_obj': total_obj
        }
        return objs

    def _ksme_loss(self, pen_sa, pen_nextsa, rewards,\
                other_curr_sa, other_next_sa, other_rewards):
        """Huber loss between |<rep(x), rep(y)>| and its bootstrapped KSME target.

        The other transitions are encoded by the target network without gradients.

        Raises:
            ValueError: if a reward distance exceeds ``self.rew_range`` or a target
                is negative (both indicate an inconsistent ``rew_range``).
        """

        with torch.no_grad():
            pen_other_sa = self.target_Q.get_representation(other_curr_sa)
            pen_other_nextsa = self.target_Q.get_representation(other_next_sa)

        reward_dist = torch.Tensor(np.abs(rewards - other_rewards))
        phi_x = pen_sa
        phi_y = pen_other_sa
        curr_dotprod = torch.sum(phi_x * phi_y, axis = 1)
        curr_Kxy = torch.abs(curr_dotprod)

        next_phi_x = pen_nextsa
        next_phi_y = pen_other_nextsa
        next_dotprod = torch.sum(next_phi_x * next_phi_y, axis = 1)
        next_Kxy = torch.abs(next_dotprod)
        target = 1. - (reward_dist / self.rew_range) + self.gamma * next_Kxy

        # the comparison must sit inside torch.any: `torch.any(x) < 0` is always False
        if torch.any(target < 0) or torch.any(self.rew_range - reward_dist < 0):
            raise ValueError('invalid KSME target: reward distance exceeds rew_range '
                             'or target is negative')

        rep_obj = torch.nn.functional.huber_loss(curr_Kxy, target)
        return rep_obj

    def get_metrics(self):
        """Return logged return estimates, losses and feature dot-product stats."""
        metrics = {
            'r_ests': self.disc_ret_ests,
            'tr_losses': self.tr_losses,
            'feat_dotprods': self.feat_dotprods,
            'feat_dotprod_diffs': self.feat_dotprod_diffs
        }
        return metrics

    def get_Q(self):
        """Return the final Q network (an alias of the target network)."""
        return self.best_Q

    def get_phi(self):
        """Return ``self.phi`` (None: the representation is internal to Q)."""
        return self.phi

class ContinuousPrePhiLinearFQE:
    """Linear FQE: trains only a linear head on top of a fixed encoder ``phi``.

    ``train`` wraps ``phi`` in ``NNwLinear`` (used through ``.backbone`` and
    ``.linear``) and regresses the linear head onto the TD target with AdamW. After
    each epoch it calls ``soft_target_update`` with ``tau = 1`` on the head, and it
    finally exposes the head's first parameter tensor as ``theta``.
    """
    def __init__(self,
        state_dims,
        action_dims,
        gamma,
        pie,
        abs_state_dims = None,
        abs_state_action_dim = None,
        phi = None,
        image_state = False,
        sa_phi = False,
        clip_target = False,
        tabular = False,
        lr = 1e-3):
        """Store configuration; networks are built in ``train``.

        ``abs_state_action_dim`` and ``image_state`` are accepted but unused.
        """

        self.state_dims = state_dims
        self.action_dims = action_dims
        self.abs_state_dims = abs_state_dims
        self.pie = pie
        self.gamma = gamma
        self.sa_phi = sa_phi
        self.clip_target = clip_target
        self.tabular = tabular
        self.lr = lr
        # returned by get_phi(); replaced by the encoder actually passed to train()
        self.phi = phi

    def train(self, data, epochs = 2, phi = None, print_log = True):
        """Fit the linear head for ``epochs + 1`` epochs.

        Every epoch, including epoch 0, takes one AdamW step per mini-batch. The
        mean mini-batch loss of each epoch is printed.

        Args:
            data: torch ``Dataset`` whose items are dicts with keys 'curr_sa',
                'next_sa', 'rewards' and 'terminal_masks'.
            epochs: index of the last epoch (epochs 0..``epochs`` are run).
            phi: encoder wrapped by ``NNwLinear``.
            print_log: unused; kept for interface compatibility.
        """

        self.Q = NNwLinear(phi)
        self.phi = phi
        self.target_Q = copy.deepcopy(self.Q)
        self.optimizer = torch.optim.AdamW(self.Q.parameters(), lr = self.lr, betas = (0.9, 0.9))

        num_workers = 6
        mini_batch_size = 2048
        self.best_Q = copy.deepcopy(self.target_Q)

        params = {'batch_size': mini_batch_size, 'shuffle': True, 'num_workers': num_workers}
        dataloader = torch.utils.data.DataLoader(data, **params)

        for epoch in range(0, epochs + 1):
            losses = []
            for mb in dataloader:
                curr_sa = mb['curr_sa']
                next_sa = mb['next_sa']
                rewards = mb['rewards']
                terminal_masks = mb['terminal_masks']

                objs = self._td_loss(data, rewards, curr_sa, next_sa, terminal_masks)
                total_obj = objs['total_obj']
                # keep a plain float so each mini-batch's autograd graph can be freed
                losses.append(total_obj.item())

                # clear gradients
                self.optimizer.zero_grad()

                # compute gradients; autograd failures abort the rest of this epoch
                try:
                    total_obj.backward()
                except RuntimeError as err:
                    print ('error on backprop, breaking', err)
                    break

                # processing gradients
                nn.utils.clip_grad_value_(self.Q.parameters(), clip_value = 1.0)

                # gradient step
                self.optimizer.step()
            avg_loss = torch.Tensor(losses).mean()

            soft_target_update(self.Q.linear, self.target_Q.linear, tau = 1)

            # best_Q aliases the live Q network (no copy)
            self.best_Q = self.Q
            print (f'{epoch}, {avg_loss}')

        self.theta = [param.detach() for param in self.best_Q.linear.parameters()][0].numpy()

    def _td_loss(self, data, rews, curr_sa, next_sa, term_masks):
        """Return ``{'obj', 'total_obj'}``: the MSE between the head's Q and the TD target.

        The backbone forward runs under ``torch.no_grad()`` so only the linear head
        receives gradients. (``torch.Tensor(tensor)`` alone is an autograd alias and
        would let gradients reach the encoder.)
        """

        with torch.no_grad():
            csa_phi = torch.Tensor(self.Q.backbone(curr_sa)) # fixed phi: no gradients
        q_curr_outputs = self.Q.linear.forward(csa_phi).reshape(-1) # gradients only for linear layer
        with torch.no_grad():
            q_next_outputs = self.target_Q(next_sa).reshape(-1)

        target = torch.Tensor(rews + self.gamma * term_masks * q_next_outputs)

        obj = torch.nn.functional.mse_loss(q_curr_outputs, target, reduction = 'none')
        obj = obj.mean()
        total_obj = obj

        objs = {
            'obj': obj,
            'total_obj': total_obj
        }
        return objs

    def get_metrics(self):
        """Return the learned head weights."""
        metrics = {
            'weights': self.theta
        }
        return metrics

    def get_theta(self):
        """Return the learned head weights flattened to 1-D."""
        return self.theta.reshape(-1)

    def get_phi(self):
        """Return the encoder last given to ``__init__`` or ``train``."""
        return self.phi

class ContinuousPrePhiLSPE:
    """Least-squares policy evaluation (LSPE) with a fixed state-action encoder ``phi``.

    ``train`` encodes the whole dataset once. With ``Phi`` = phi(s, a) and ``Phi'``
    = phi averaged over actions sampled from ``pie`` at s', it iterates
    ``theta <- C^{-1} (b + gamma * A theta)``, where ``C = Phi^T Phi / n``,
    ``A = Phi^T (mask * Phi') / n`` and ``b = Phi^T r / n``.
    """
    def __init__(self,
        state_dims,
        action_dims,
        gamma,
        pie,
        abs_state_dims = None,
        abs_state_action_dim = None,
        phi = None,
        image_state = False,
        sa_phi = False,
        clip_target = False,
        tabular = False):
        """Store configuration and zero-initialize ``theta`` at the feature width.

        If ``phi`` is None and ``image_state`` is True, ``theta`` is not created here
        (``train`` sizes it from the data anyway).
        """

        self.state_dims = state_dims
        self.action_dims = action_dims
        self.abs_state_dims = abs_state_dims
        self.pie = pie
        self.gamma = gamma
        self.sa_phi = sa_phi
        self.clip_target = clip_target
        self.tabular = tabular

        self.phi_defined = phi is not None

        q_input_dim = state_dims * action_dims if self.tabular else state_dims + action_dims
        if self.phi_defined:
            phi = phi.train(False)
            q_input_dim = abs_state_action_dim if self.sa_phi else abs_state_dims + action_dims
        q_output_dim = 1
        self.q_input_dim = q_input_dim

        self.phi = phi if self.phi_defined else CustomID()

        if phi is None and image_state:
            return
        else:
            self.theta = np.zeros(q_input_dim)

    def train(self, data, epochs = 2000, num_action_samples = 20, print_log = True, phi = None):
        """Run ``epochs`` LSPE iterations and log diagnostics.

        Every 100 iterations (and on the first and last) this records the mean
        predicted value at initial state-action features in ``disc_ret_ests``, the
        squared TD residual in ``bellman_residual`` and ``||theta||`` in
        ``weight_norms``.

        Args:
            data: torch ``Dataset`` whose items are dicts with keys 'curr_sa',
                'rewards', 'terminal_masks' and 'next_state'; also used for initial
                states and ``unnormalize_states``.
            epochs: number of LSPE iterations.
            num_action_samples: actions drawn from ``pie`` per next/initial state.
            print_log: unused; kept for interface compatibility.
            phi: optional encoder replacing ``self.phi`` (switched to eval mode).

        Raises:
            NotImplementedError: if ``sa_phi`` is False (state-only encoders are not
                supported by this path).
            ValueError: if ``data`` yields no transitions.
        """
        if not self.sa_phi:
            # The state-only encoder path was never wired up (it referenced undefined
            # curr_states / curr_actions); fail clearly instead of with a NameError.
            raise NotImplementedError('ContinuousPrePhiLSPE.train only supports sa_phi = True')

        if phi is not None:
            phi = phi.train(False)
            self.phi = phi
        weight_norms = {}
        disc_ret_ests = {}
        bellman_residual = {}
        ope_init_errs = {}
        ope_sa_errs = {}
        mini_batch_size = 2048
        num_workers = 6
        params = {'batch_size': mini_batch_size, 'shuffle': False, 'num_workers': num_workers}
        dataloader = torch.utils.data.DataLoader(data, **params)
        n = 0
        phi_t, phi_tp1, reward, terminal = [], [], [], []
        for sub_data in dataloader:
            curr_sa = torch.Tensor(sub_data['curr_sa'])
            rewards = torch.Tensor(sub_data['rewards'])
            terminal_masks = torch.Tensor(sub_data['terminal_masks'])
            next_states = None if self.tabular else torch.Tensor(sub_data['next_state'])

            # encode (s, a); next features are phi averaged over actions sampled from pie
            curr_sa = torch.Tensor(self.phi(curr_sa))
            res = 0
            for _ in range(num_action_samples):
                if self.tabular:
                    temp_next_sa = self.pie.sample_sa_features(sub_data['next_state'])
                else:
                    next_sampled_acts = torch.Tensor(self.pie.batch_sample(data.unnormalize_states(next_states)))
                    temp_next_sa = torch.concat((next_states, next_sampled_acts), axis = 1)
                res += torch.Tensor(self.phi(temp_next_sa))
            next_sa = res / num_action_samples

            phi_t.append(curr_sa)
            phi_tp1.append(next_sa)
            reward.append(rewards)
            terminal.append(terminal_masks)
            n += curr_sa.shape[0]

        if n == 0:
            raise ValueError('ContinuousPrePhiLSPE.train received an empty dataset')

        phi_t = torch.cat(phi_t, dim=0)
        phi_tp1 = torch.cat(phi_tp1, dim=0)
        reward = torch.cat(reward, dim=0)
        terminal = torch.cat(terminal, dim=0)
        reward = reward.reshape((-1, 1))
        terminal = terminal.reshape((-1, 1))
        assert reward.dim() == 2 and reward.shape[-1] == 1

        # standardize by converting to numpy
        phi_t = phi_t.numpy()
        phi_tp1 = phi_tp1.numpy()
        reward = reward.numpy()
        terminal = terminal.numpy()
        cov = np.matmul(phi_t.T, phi_t) / n
        inv_cov = compute_inverse(cov)

        self.phi_stats = get_phi_stats(phi_t, phi_tp1, gamma = self.gamma, terminals = terminal)
        print ('in lspe train method ', self.phi_stats)

        # random init, sized from the features so it matches whichever phi is in use
        self.theta = np.random.rand(phi_t.shape[1]).reshape(-1, 1)
        ncov = self.gamma * np.matmul(phi_t.T, terminal * phi_tp1) / n
        rew = np.matmul(phi_t.T, reward) / n
        for epoch in range(1, epochs + 1):
            self.theta = inv_cov @ (rew + ncov @ self.theta)
            # guard against overflow if the iteration diverges
            self.theta = np.clip(self.theta, -1e10, 1e10)

            if epoch % 100 == 0 or epoch == epochs or epoch == 1:
                q_init_inputs = 0
                for _ in range(num_action_samples):
                    if self.tabular:
                        init_sa = self.pie.sample_sa_features(data.initial_states)
                    else:
                        init_ground_states = torch.Tensor(data.get_initial_states_samples(mini_batch_size))
                        next_sampled_acts = torch.Tensor(self.pie.batch_sample(data.unnormalize_states(init_ground_states)))
                        init_sa = torch.concat((init_ground_states, next_sampled_acts), axis = 1)
                    q_init_inputs += torch.Tensor(self.phi(init_sa))

                q_init_inputs = q_init_inputs / num_action_samples
                q_init_inputs = q_init_inputs.numpy()

                prev_q = (phi_tp1 @ self.theta)
                y = reward + self.gamma * terminal * prev_q
                residual = np.mean(np.square(phi_t @ self.theta - y))
                theta_norm = np.linalg.norm(self.theta)
                ret_init = np.mean(q_init_inputs @ self.theta)

                disc_ret_ests[epoch] = ret_init

                # placeholders: ground-truth comparison is not computed here
                ope_init_error = 0
                ope_sa_error = 0
                ope_init_errs[epoch] = ope_init_error
                ope_sa_errs[epoch] = ope_sa_error

                bellman_residual[epoch] = residual
                weight_norms[epoch] = theta_norm
                print ('epoch {}, residual {}, disc ret: {}, theta norm {}'.format(epoch, residual, ret_init, theta_norm))

        self.disc_ret_ests = disc_ret_ests
        self.bellman_residual = bellman_residual
        self.weight_norms = weight_norms
        self.ope_init_errs = ope_init_errs
        self.ope_sa_errs = ope_sa_errs

    def get_metrics(self):
        """Return logged estimates, residuals, weights and feature statistics."""
        metrics = {
            'r_ests': self.disc_ret_ests,
            'bellman_residual': self.bellman_residual,
            'weights': self.theta.reshape(-1),
            'phi_lspe_stats': self.phi_stats,
            'weight_norms': self.weight_norms,
            'init_errs': self.ope_init_errs,
            'sa_errs': self.ope_sa_errs
        }
        return metrics

    def get_theta(self):
        """Return ``theta`` flattened to 1-D."""
        return self.theta.reshape(-1)

    def get_phi(self):
        """Return the encoder in use (``CustomID`` if none was ever given)."""
        return self.phi

class ContinuousPrePhiTD:
    """Linear TD policy evaluation on top of a fixed feature map ``phi``.

    Q(s, a) is approximated as ``features(s, a) @ theta``. ``train`` makes one
    pass over the dataset to accumulate the empirical moments

        A = mean(f f^T) - gamma * mean(f (m * f')^T),    b = mean(f * r)

    where ``f`` are the current features, ``f'`` the next-state features
    averaged over actions sampled from ``pie``, and ``m`` the per-transition
    ``terminal_masks`` value. It then iterates the expected TD update
    ``theta <- theta - lr * (A theta - b)`` with a decaying step size.
    """

    def __init__(self,
        state_dims,
        action_dims,
        gamma,
        pie,
        abs_state_dims = None,
        abs_state_action_dim = None,
        phi = None,
        image_state = False,
        sa_phi = False,
        clip_target = False,
        tabular = False):
        """Store the configuration and allocate a zero weight vector.

        Args:
            state_dims: ground state dimensionality.
            action_dims: action dimensionality.
            gamma: discount factor.
            pie: evaluation policy; must provide ``batch_sample`` (continuous
                data) or ``sample_sa_features`` (tabular data).
            abs_state_dims: output size of a state-only ``phi``.
            abs_state_action_dim: output size of a state-action ``phi``
                (used when ``sa_phi`` is True).
            phi: pre-trained feature map, switched to eval mode via
                ``train(False)``. When None, ``CustomID()`` is used instead.
            image_state: together with ``phi=None``, skips allocating theta.
            sa_phi: whether ``phi`` consumes the concatenated (state, action).
            clip_target: stored only; not used by this class.
            tabular: whether the data holds tabular state-action features.
        """
        self.state_dims = state_dims
        self.action_dims = action_dims
        self.abs_state_dims = abs_state_dims
        self.pie = pie
        self.gamma = gamma
        self.sa_phi = sa_phi
        self.clip_target = clip_target
        self.tabular = tabular

        self.phi_defined = phi is not None

        q_input_dim = state_dims * action_dims if self.tabular else state_dims + action_dims
        if self.phi_defined:
            phi = phi.train(False)
            q_input_dim = abs_state_action_dim if self.sa_phi else abs_state_dims + action_dims
        self.q_input_dim = q_input_dim

        self.phi = phi if self.phi_defined else CustomID()

        # With no feature map and image states no weight vector is allocated
        # here; train() sizes theta from the observed features anyway.
        if not (phi is None and image_state):
            self.theta = torch.zeros(q_input_dim)

    def train(self, data, epochs = 2000, num_action_samples = 20, print_log = True, phi = None,
              mini_batch_size = 2048, num_workers = 6):
        """Fit ``theta`` by iterating the expected TD update.

        Args:
            data: transition dataset consumed through a ``DataLoader``; items
                provide 'curr_sa', 'rewards', 'terminal_masks' and (non-tabular)
                'next_state'. Also queried for initial states when estimating
                the return.
            epochs: number of TD iterations.
            num_action_samples: actions drawn from ``pie`` per next/initial
                state to estimate the expected next features.
            print_log: print diagnostics at epoch 1, every 100 epochs and at
                the last epoch.
            phi: optional replacement feature map (switched to eval mode).
            mini_batch_size: DataLoader batch size; also the number of initial
                states sampled for the continuous-state return estimate.
            num_workers: DataLoader worker processes.

        Raises:
            ValueError: if ``data`` yields no transitions.
            NotImplementedError: for tabular data with ``sa_phi=False``.
        """
        if phi is not None:
            phi = phi.train(False)
            self.phi = phi
        disc_ret_ests = {}
        tr_losses = {}
        dataloader = torch.utils.data.DataLoader(
            data, batch_size = mini_batch_size, shuffle = False, num_workers = num_workers)

        cov_sum = None
        ncov_sum = None
        r_sum = None
        n = 0
        phi_t, phi_tp1 = [], []  # kept only for the printed diagnostics
        for sub_data in dataloader:
            rewards = torch.Tensor(sub_data['rewards']).reshape(-1, 1)
            terminal_masks = torch.Tensor(sub_data['terminal_masks']).reshape(-1, 1)
            curr_sa, next_sa = self._batch_features(sub_data, data, num_action_samples)

            if cov_sum is None:
                cov_sum = torch.zeros(curr_sa.size(-1), curr_sa.size(-1))
                ncov_sum = torch.zeros(curr_sa.size(-1), next_sa.size(-1))
                r_sum = torch.zeros(curr_sa.size(-1), 1)

            # Sums of per-sample outer products: sum_i x_i y_i^T == X^T Y.
            cov_sum += curr_sa.T @ curr_sa
            # A zero mask drops the bootstrapped next-state term for that
            # transition (same masking as the LSTD estimator's TD difference).
            ncov_sum += curr_sa.T @ (terminal_masks * next_sa)
            r_sum += curr_sa.T @ rewards

            if print_log:
                phi_t.append(curr_sa)
                phi_tp1.append(next_sa)
            n += curr_sa.shape[0]

        if n == 0:
            raise ValueError('ContinuousPrePhiTD.train: dataset is empty')

        A = (cov_sum - self.gamma * ncov_sum) / n
        rsum = (r_sum / n).reshape(-1)

        feat_dotprod = None
        if print_log:
            phi_t = torch.cat(phi_t, dim=0)
            phi_tp1 = torch.cat(phi_tp1, dim=0)
            phi_stats = get_phi_stats(phi_t.numpy(), phi_tp1.numpy(), gamma = self.gamma)
            print('in lspe train method ', phi_stats)
            # Constant across epochs, so computed once.
            feat_dotprod = torch.sum(phi_tp1 * phi_t, axis = 1).mean().item()

        # Start from theta = 0; the step size decays by 0.99 per epoch down
        # to a floor of 1e-3.
        self.theta = torch.zeros(A.shape[-1])
        lr = 1
        for epoch in range(1, epochs + 1):
            td_err = A @ self.theta - rsum
            self.theta = self.theta - lr * td_err
            lr = max(0.99 * lr, 1e-3)

            if epoch % 100 == 0 or epoch == epochs or epoch == 1:
                q_init_inputs = self._initial_features(data, num_action_samples, mini_batch_size)
                ret = torch.mean(q_init_inputs @ self.theta).item()
                residual = -1  # the Bellman residual is not computed by this estimator
                disc_ret_ests[epoch] = ret
                tr_losses[epoch] = residual
                if print_log:
                    init_stats = get_phi_stats(q_init_inputs.numpy())
                    theta_norm = torch.norm(self.theta)
                    print('epoch {}, loss {}, disc ret: {}, rank of init states: {}, feat dotprod: {}, theta norm {}'.format(epoch, residual, ret, init_stats['srank'], feat_dotprod, theta_norm))

        self.disc_ret_ests = disc_ret_ests
        self.tr_losses = tr_losses

    def _batch_features(self, sub_data, data, num_action_samples):
        """Return (current, expected-next) feature matrices for one minibatch."""
        curr_sa = torch.Tensor(sub_data['curr_sa'])
        next_states = None if self.tabular else torch.Tensor(sub_data['next_state'])

        if self.sa_phi:
            curr_feats = torch.Tensor(self.phi(curr_sa))
            next_feats = 0
            for _ in range(num_action_samples):
                if self.tabular:
                    next_sa = self.pie.sample_sa_features(sub_data['next_state'])
                else:
                    next_acts = torch.Tensor(self.pie.batch_sample(data.unnormalize_states(next_states)))
                    next_sa = torch.cat((next_states, next_acts), dim = 1)
                next_feats += torch.Tensor(self.phi(next_sa))
            return curr_feats, next_feats / num_action_samples

        if self.tabular:
            raise NotImplementedError('ContinuousPrePhiTD: tabular data requires sa_phi=True')
        # NOTE(review): assumes 'curr_sa' is laid out as [state, action] with
        # the first state_dims columns holding the state (the same layout this
        # class uses when it concatenates states and actions) -- confirm.
        curr_states = curr_sa[:, :self.state_dims]
        curr_actions = curr_sa[:, self.state_dims:]
        curr_feats = torch.cat((torch.Tensor(self.phi(curr_states)), curr_actions), dim = 1)
        # phi(s') does not depend on the action, so averaging [phi(s'), a']
        # over sampled actions only averages the action columns.
        mean_next_acts = 0
        for _ in range(num_action_samples):
            mean_next_acts += torch.Tensor(self.pie.batch_sample(data.unnormalize_states(next_states)))
        mean_next_acts = mean_next_acts / num_action_samples
        next_feats = torch.cat((torch.Tensor(self.phi(next_states)), mean_next_acts), dim = 1)
        return curr_feats, next_feats

    def _initial_features(self, data, num_action_samples, num_states):
        """Return features of initial states paired with actions from ``pie``."""
        if self.tabular:
            feats = 0
            for _ in range(num_action_samples):
                init_sa = self.pie.sample_sa_features(data.initial_states)
                feats += torch.Tensor(self.phi(init_sa))
            return feats / num_action_samples
        init_states = torch.Tensor(data.get_initial_states_samples(num_states))
        init_acts = torch.Tensor(self.pie.batch_sample(data.unnormalize_states(init_states)))
        if self.sa_phi:
            return torch.Tensor(self.phi(torch.cat((init_states, init_acts), dim = 1)))
        return torch.cat((torch.Tensor(self.phi(init_states)), init_acts), dim = 1)

    def get_metrics(self):
        """Return the per-epoch return estimates, losses and final weights."""
        metrics = {
            'r_ests': self.disc_ret_ests,
            'tr_losses': self.tr_losses,
            'weights': self.theta.numpy()
        }
        return metrics

    def get_theta(self):
        """Return the learned weight vector."""
        return self.theta

    def get_phi(self):
        """Return the feature map in use."""
        return self.phi

class ContinuousPrePhiLSTD:
    """Least-squares TD (LSTD) policy evaluation on top of a fixed feature map.

    Q(s, a) is approximated as ``features(s, a) @ theta`` with ``theta``
    solved in closed form as ``pinv(A) @ b``, where

        A = mean(f (f - gamma * m * f')^T),    b = mean(f * r)

    ``f`` are the current features, ``f'`` the next-state features averaged
    over actions sampled from ``pie``, and ``m`` the per-transition
    ``terminal_masks`` value.
    """

    def __init__(self,
        state_dims,
        action_dims,
        gamma,
        pie,
        abs_state_dims = None,
        abs_state_action_dim = None,
        phi = None,
        image_state = False,
        sa_phi = False,
        clip_target = False,
        tabular = False):
        """Store the configuration and allocate a zero weight vector.

        Args:
            state_dims: ground state dimensionality.
            action_dims: action dimensionality.
            gamma: discount factor.
            pie: evaluation policy; must provide ``batch_sample`` (continuous
                data) or ``sample_sa_features`` (tabular data).
            abs_state_dims: output size of a state-only ``phi``.
            abs_state_action_dim: output size of a state-action ``phi``
                (used when ``sa_phi`` is True).
            phi: pre-trained feature map, switched to eval mode via
                ``train(False)``. When None, ``CustomID()`` is used instead.
            image_state: together with ``phi=None``, skips allocating theta.
            sa_phi: whether ``phi`` consumes the concatenated (state, action).
            clip_target: stored only; not used by this class.
            tabular: whether the data holds tabular state-action features.
        """
        self.state_dims = state_dims
        self.action_dims = action_dims
        self.abs_state_dims = abs_state_dims
        self.pie = pie
        self.gamma = gamma
        self.sa_phi = sa_phi
        self.clip_target = clip_target
        self.tabular = tabular

        self.phi_defined = phi is not None

        q_input_dim = state_dims * action_dims if self.tabular else state_dims + action_dims
        if self.phi_defined:
            phi = phi.train(False)
            q_input_dim = abs_state_action_dim if self.sa_phi else abs_state_dims + action_dims
        self.q_input_dim = q_input_dim

        self.phi = phi if self.phi_defined else CustomID()

        # With no feature map and image states no weight vector is allocated
        # here; train() overwrites theta with the closed-form solution anyway.
        if not (phi is None and image_state):
            self.theta = torch.zeros(q_input_dim)

    def train(self, data, epochs = 2000, print_log = True, num_action_samples = 100,
              num_init_samples = 50, mini_batch_size = 2048, num_workers = 6):
        """Solve for ``theta`` in closed form and estimate the discounted return.

        Args:
            data: transition dataset consumed through a ``DataLoader``; items
                provide 'curr_sa' (tabular: 'curr_state_actions'), 'rewards',
                'terminal_masks' and (non-tabular) 'next_state' (tabular
                'next_states'). Also queried for initial states.
            epochs: unused; LSTD has no iterations. Kept so the signature
                matches the iterative estimators.
            print_log: print the return estimate and the trace of ``A``.
            num_action_samples: actions drawn from ``pie`` per next state to
                estimate the expected next features.
            num_init_samples: draws averaged for the initial-state features
                (when ``sa_phi`` is True).
            mini_batch_size: DataLoader batch size; also the number of initial
                states sampled per draw (continuous states).
            num_workers: DataLoader worker processes.

        Raises:
            ValueError: if ``data`` yields no transitions.
            NotImplementedError: for tabular data with ``sa_phi=False``.
        """
        dataloader = torch.utils.data.DataLoader(
            data, batch_size = mini_batch_size, shuffle = True, num_workers = num_workers)
        A, b = None, None
        n = 0
        for sub_data in dataloader:
            rewards = torch.Tensor(sub_data['rewards']).reshape(-1, 1)
            terminal_masks = torch.Tensor(sub_data['terminal_masks']).reshape(-1, 1)
            curr_sa, next_sa = self._batch_features(sub_data, data, num_action_samples)

            if A is None:
                A = torch.zeros(curr_sa.size(-1), curr_sa.size(-1))
                b = torch.zeros(curr_sa.size(-1))

            # A zero mask drops the bootstrapped next-state term.
            td_diff = curr_sa - self.gamma * terminal_masks * next_sa
            # Sum of per-sample outer products: sum_i x_i y_i^T == X^T Y.
            A += curr_sa.T @ td_diff
            b += (curr_sa * rewards).sum(dim=0)
            n += curr_sa.shape[0]

        if n == 0:
            raise ValueError('ContinuousPrePhiLSTD.train: dataset is empty')

        A = A / n
        b = b / n
        # Pseudo-inverse tolerates a singular A (e.g. unused feature columns).
        self.theta = torch.matmul(torch.linalg.pinv(A), b)

        q_init_inputs = self._initial_features(data, num_init_samples, mini_batch_size)
        ret = torch.mean(q_init_inputs @ self.theta).item()
        a_trace = torch.trace(A).item()

        if print_log:
            print('disc ret: {}, A matrix trace: {}'.format(ret, a_trace))
        # Single closed-form estimate (previously computed but discarded).
        self.disc_ret_ests = [ret]
        self.tr_losses = []

    def _batch_features(self, sub_data, data, num_action_samples):
        """Return (current, expected-next) feature matrices for one minibatch."""
        if self.tabular:
            curr_sa = torch.Tensor(sub_data['curr_state_actions'])
            next_states = None
        else:
            curr_sa = torch.Tensor(sub_data['curr_sa'])
            next_states = torch.Tensor(sub_data['next_state'])

        if self.sa_phi:
            curr_feats = torch.Tensor(self.phi(curr_sa))
            next_feats = 0
            for _ in range(num_action_samples):
                if self.tabular:
                    next_sa = self.pie.sample_sa_features(sub_data['next_states'])
                else:
                    next_acts = torch.Tensor(self.pie.batch_sample(data.unnormalize_states(next_states)))
                    next_sa = torch.cat((next_states, next_acts), dim = 1)
                next_feats += torch.Tensor(self.phi(next_sa))
            return curr_feats, next_feats / num_action_samples

        if self.tabular:
            raise NotImplementedError('ContinuousPrePhiLSTD: tabular data requires sa_phi=True')
        # NOTE(review): assumes 'curr_sa' is laid out as [state, action] with
        # the first state_dims columns holding the state (the same layout this
        # class uses when it concatenates states and actions) -- confirm.
        curr_states = curr_sa[:, :self.state_dims]
        curr_actions = curr_sa[:, self.state_dims:]
        curr_feats = torch.cat((torch.Tensor(self.phi(curr_states)), curr_actions), dim = 1)
        # phi(s') does not depend on the action, so averaging [phi(s'), a']
        # over sampled actions only averages the action columns.
        mean_next_acts = 0
        for _ in range(num_action_samples):
            mean_next_acts += torch.Tensor(self.pie.batch_sample(data.unnormalize_states(next_states)))
        mean_next_acts = mean_next_acts / num_action_samples
        next_feats = torch.cat((torch.Tensor(self.phi(next_states)), mean_next_acts), dim = 1)
        return curr_feats, next_feats

    def _initial_features(self, data, num_samples, num_states):
        """Return features of initial states paired with actions from ``pie``."""
        if self.sa_phi:
            feats = 0
            for _ in range(num_samples):
                if self.tabular:
                    init_sa = self.pie.sample_sa_features(data.initial_states)
                else:
                    init_states = torch.Tensor(data.get_initial_states_samples(num_states))
                    init_acts = torch.Tensor(self.pie.batch_sample(data.unnormalize_states(init_states)))
                    init_sa = torch.cat((init_states, init_acts), dim = 1)
                feats += torch.Tensor(self.phi(init_sa))
            return feats / num_samples
        if self.tabular:
            return torch.Tensor(data.init_state_actions)
        init_states = torch.Tensor(data.get_initial_states_samples(num_states))
        init_acts = torch.Tensor(self.pie.batch_sample(data.unnormalize_states(init_states)))
        return torch.cat((torch.Tensor(self.phi(init_states)), init_acts), dim = 1)

    def get_metrics(self):
        """Return the return estimate(s), losses and final weights."""
        metrics = {
            'r_ests': self.disc_ret_ests,
            'tr_losses': self.tr_losses,
            'weights': self.theta.numpy()
        }
        return metrics

    def get_theta(self):
        """Return the learned weight vector."""
        return self.theta

    def get_phi(self):
        """Return the feature map in use."""
        return self.phi
