
from numpy.lib.function_base import select
import acquisitions
import ipdb
from d3rlpy.algos import DiscreteCQL

import numpy as np
import torch
import torch.nn.functional as F
from typing import Any, Dict, Optional, Sequence

from d3rlpy.iterators import RandomIterator, RoundIterator, TransitionIterator

from d3rlpy.dataset import TransitionMiniBatch

from d3rlpy.constants import ActionSpace
from d3rlpy.models.optimizers import AdamFactory, OptimizerFactory
from d3rlpy.algos.torch.cql_impl import DiscreteCQLImpl

from collections import defaultdict
from d3rlpy.logger import LOG, D3RLPyLogger
# from d3rlpy.algos.dqn import DoubleDQN
from tqdm.auto import tqdm
import gym
from utils import set_seed_everywhere
import seaborn as sns
from typing import (
    Any,
    Callable,
    DefaultDict,
    Dict,
    Generator,
    List,
    Optional,
    Sequence,
    Tuple,
    Union,
    cast,
)
from d3rlpy.constants import (
    CONTINUOUS_ACTION_SPACE_MISMATCH_ERROR,
    DISCRETE_ACTION_SPACE_MISMATCH_ERROR,
    IMPL_NOT_INITIALIZED_ERROR,
    ActionSpace,
)

from d3rlpy.dataset import Episode, MDPDataset, Transition, TransitionMiniBatch

from pathlib import Path

import wandb
import time, os, fnmatch, shutil

import pandas as pd
import matplotlib.patches as mpatches
from matplotlib.colors import ListedColormap
import matplotlib.pyplot as plt
import torch.nn as nn

import d3rlpy
from d3rlpy.algos import DiscreteBC
from d3rlpy.metrics.scorer import evaluate_on_environment
from utils import load_models

from d3rlpy.torch_utility import TorchMiniBatch, torch_api, train_api

from torch.optim.lr_scheduler import CosineAnnealingLR
from PIL import Image
from d3rlpy.preprocessing.stack import StackedObservation

class ActiveCQL(DiscreteCQL):

    def fit(
        self,
        dataset: Union[List[Episode], MDPDataset],
        n_epochs: Optional[int] = None,
        n_steps: Optional[int] = None,
        n_steps_per_epoch: int = 10000,
        save_metrics: bool = True,
        experiment_name: Optional[str] = None,
        with_timestamp: bool = True,
        logdir: str = "d3rlpy_logs",
        verbose: bool = True,
        show_progress: bool = True,
        tensorboard_dir: Optional[str] = None,
        eval_episodes: Optional[List[Episode]] = None,
        save_interval: int = 10,
        scorers: Optional[
            Dict[str, Callable[[Any, List[Episode]], float]]
        ] = None,
        shuffle: bool = True,
        callback: Optional[Callable[["LearnableBase", int, int], None]] = None,
        net_args=None,
        env=None,
        num_norand_eps=-1,
        sampled_idxs=None,
    ) -> List[Tuple[int, Dict[str, float]]]:
        """Run the active-sampling training loop to completion.

        Every argument is forwarded unchanged to :meth:`fitter_active`; the
        generator it returns is drained here so that the caller receives the
        full training history at once.

        Returns:
            One ``(epoch, metrics)`` tuple per epoch, in the order they were
            yielded by :meth:`fitter_active`.
        """
        epoch_stream = self.fitter_active(
            dataset=dataset,
            n_epochs=n_epochs,
            n_steps=n_steps,
            n_steps_per_epoch=n_steps_per_epoch,
            save_metrics=save_metrics,
            experiment_name=experiment_name,
            with_timestamp=with_timestamp,
            logdir=logdir,
            verbose=verbose,
            show_progress=show_progress,
            tensorboard_dir=tensorboard_dir,
            eval_episodes=eval_episodes,
            save_interval=save_interval,
            scorers=scorers,
            shuffle=shuffle,
            callback=callback,
            net_args=net_args,
            env=env,
            num_norand_eps=num_norand_eps,
            sampled_idxs=sampled_idxs,
        )
        # Exhausting the generator is what actually performs the training.
        return list(epoch_stream)

    def manual_curr(self, trans, net_args=None):
        batch = trans.observations.shape[0]
        vals_list = [92, 108, 125, 133, 192, 239, 447, 466, 285, 303, 317, 350, 364, 400, 410, 48, 61, 467, 478, 524, 544, 129, 599,]
        mask = (np.isin(trans.ep_ids, vals_list)).reshape(-1)
        scores = torch.ones((batch), dtype=torch.float64)
        scores[mask] = self.net_args.manual_curr_coeff
        return scores, None, None

    def relo_curr(self, transitions, tdper_with_cons=False):
        """Score every transition with the online Q-function's ``compute_error``.

        The dataset is processed in chunks of 16384 transitions. For each
        chunk the bootstrap target is taken from ``self.impl._targ_q_func``,
        using the action that maximises the target network's own output on
        the next observation.

        Args:
            transitions: indexable sequence of transitions supporting ``len``.
            tdper_with_cons: when true, ``self.impl._alpha`` times the
                conservative loss is added to the returned scores.

        Returns:
            ``(scores, targets, None, conservative_loss)``, each non-``None``
            entry being the concatenation of the per-chunk results.

        Note:
            An empty ``transitions`` leaves nothing to concatenate, so
            ``torch.cat`` is called on an empty list.
        """
        fwd_batch_size = 16384
        n_transitions = len(transitions)

        score_list = []
        real_score_list = []
        cons_list = []
        targets = []

        # Stepping by chunk start (instead of ``len // size + 1`` chunks)
        # avoids building an empty trailing chunk when the dataset size is an
        # exact multiple of the chunk size.
        for start in range(0, n_transitions, fwd_batch_size):
            stop = min(start + fwd_batch_size, n_transitions)
            batch = TransitionMiniBatch(
                [transitions[i] for i in range(start, stop)]
            )

            torch_batch = TorchMiniBatch(
                batch,
                self.impl.device,
                scaler=self.scaler,
                action_scaler=self.action_scaler,
                reward_scaler=self.reward_scaler,
            )

            with torch.no_grad():
                next_values = self.impl._targ_q_func(
                    torch_batch.next_observations
                )
                max_action = next_values.argmax(dim=1)
                q_tpn = self.impl._targ_q_func.compute_target(
                    torch_batch.next_observations,
                    max_action,
                    reduction="mean",
                )
                targets.append(q_tpn.reshape(-1))

            # Both networks are scored against the same target.
            error_kwargs = dict(
                observations=torch_batch.observations,
                actions=torch_batch.actions.long(),
                rewards=torch_batch.rewards,
                target=q_tpn,
                terminals=torch_batch.terminals,
                gamma=self._gamma ** torch_batch.n_steps,
            )
            scores = self.impl._q_func.compute_error(
                **error_kwargs
            ).view(-1).double()
            irr_scores = self.impl._targ_q_func.compute_error(
                **error_kwargs
            ).view(-1).double()

            # Online error minus target-network error, clipped at zero.
            real_score_list.append(F.relu(scores - irr_scores))

            conservative_loss = self.impl._compute_conservative_loss(
                torch_batch.observations, torch_batch.actions.long()
            )
            if tdper_with_cons:
                scores = (
                    scores
                    + self.impl._alpha * conservative_loss.view(-1).double()
                )

            score_list.append(scores)
            cons_list.append(conservative_loss)

        targets = torch.cat(targets)
        scores = torch.cat(score_list)
        # NOTE(review): the clipped difference is computed but not returned;
        # the plain online-network scores are returned instead. Confirm
        # whether `real_scores` was meant to be the first return value.
        real_scores = torch.cat(real_score_list)
        conservative_loss = torch.cat(cons_list)
        return scores, targets, None, conservative_loss

    def PER_curr(self, transitions, tdper_with_cons=False, target_action_target=False):
        """Score every transition with the online Q-function's ``compute_error``.

        The dataset is processed in chunks of 16384 transitions.

        Args:
            transitions: indexable sequence of transitions supporting ``len``.
            tdper_with_cons: when true, ``self.impl._alpha`` times the
                conservative loss is added to the returned scores.
            target_action_target: selects how the bootstrap action is chosen.
                ``False`` (the default, and the previous hard-coded behaviour)
                uses ``self.impl._predict_best_action``; ``True`` uses the
                argmax of ``self.impl._targ_q_func`` on the next observation.

        Returns:
            ``(scores, targets, None, conservative_loss)``, each non-``None``
            entry being the concatenation of the per-chunk results.

        Note:
            An empty ``transitions`` leaves nothing to concatenate, so
            ``torch.cat`` is called on an empty list.
        """
        fwd_batch_size = 16384
        n_transitions = len(transitions)

        score_list = []
        cons_list = []
        targets = []

        # Stepping by chunk start (instead of ``len // size + 1`` chunks)
        # avoids building an empty trailing chunk when the dataset size is an
        # exact multiple of the chunk size.
        for start in range(0, n_transitions, fwd_batch_size):
            stop = min(start + fwd_batch_size, n_transitions)
            batch = TransitionMiniBatch(
                [transitions[i] for i in range(start, stop)]
            )

            torch_batch = TorchMiniBatch(
                batch,
                self.impl.device,
                scaler=self.scaler,
                action_scaler=self.action_scaler,
                reward_scaler=self.reward_scaler,
            )

            with torch.no_grad():
                if target_action_target:
                    next_values = self.impl._targ_q_func(
                        torch_batch.next_observations
                    )
                    max_action = next_values.argmax(dim=1)
                else:
                    max_action = self.impl._predict_best_action(
                        torch_batch.next_observations
                    )

                q_tpn = self.impl._targ_q_func.compute_target(
                    torch_batch.next_observations,
                    max_action,
                    reduction="mean",
                )
                targets.append(q_tpn.reshape(-1))

            scores = self.impl._q_func.compute_error(
                observations=torch_batch.observations,
                actions=torch_batch.actions.long(),
                rewards=torch_batch.rewards,
                target=q_tpn,
                terminals=torch_batch.terminals,
                gamma=self._gamma ** torch_batch.n_steps,
            ).view(-1).double()

            conservative_loss = self.impl._compute_conservative_loss(
                torch_batch.observations, torch_batch.actions.long()
            )
            if tdper_with_cons:
                scores = (
                    scores
                    + self.impl._alpha * conservative_loss.view(-1).double()
                )

            score_list.append(scores)
            cons_list.append(conservative_loss)

        targets = torch.cat(targets)
        scores = torch.cat(score_list)
        conservative_loss = torch.cat(cons_list)
        return scores, targets, None, conservative_loss

    def process_scores(self, scores, meta_data, dataset_batch, select_eps):
        # ipdb.set_trace()
        terminals = dataset_batch.terminals
        terminal_indices = (np.arange(terminals.shape[0])[terminals.reshape(-1).astype(int)==1]+1)[:-1]
        scores_eps = np.split(scores, terminal_indices)
        batch = terminals.shape[0]
        new_scores = np.ones((batch), dtype=np.float64)

        # ep_score_sums = np.array([scores_ep.sum() for scores_ep in scores_eps])  ## not invariant to len
        # ep_score_sums = np.array([scores_ep.mean() for scores_ep in scores_eps])
        # max_set_scores = np.concatenate([np.array([scores_ep.max()]*len(scores_ep)) for scores_ep in scores_eps])

        ep_score_sums = np.array([scores_ep.max() for scores_ep in scores_eps])
        high_score_ep_ind = np.argpartition(ep_score_sums, -select_eps)[-select_eps:]

        mask = (np.isin(dataset_batch.ep_ids, high_score_ep_ind)).reshape(-1)
        new_scores[mask] = self.net_args.manual_curr_coeff

        return new_scores # max_set_scores # #  

    def fitter_active(
        self,
        dataset: Union[List[Episode], List[Transition], MDPDataset],
        n_epochs: Optional[int] = None,
        n_steps: Optional[int] = None,
        n_steps_per_epoch: int = 10000,
        save_metrics: bool = True,
        experiment_name: Optional[str] = None,
        with_timestamp: bool = True,
        logdir: str = "d3rlpy_logs",
        verbose: bool = True,
        show_progress: bool = True,
        tensorboard_dir: Optional[str] = None,
        eval_episodes: Optional[List[Episode]] = None,
        save_interval: int = 10,
        scorers: Optional[
            Dict[str, Callable[[Any, List[Episode]], float]]
        ] = None,
        shuffle: bool = True,
        callback: Optional[Callable[["LearnableBase", int, int], None]] = None,
        net_args=None,
        env=None,
        num_norand_eps=-1,
        sampled_idxs=None,
    ) -> Generator[Tuple[int, Dict[str, float]], None, None]:
        """Train with score-weighted transition sampling, yielding once per epoch.

        At every iteration a sampling distribution over the whole dataset is
        built from per-transition scores raised to a power, and
        ``net_args.num_bkwds`` mini-batches are drawn from it and passed to
        ``self.update``. The score source is chosen by ``net_args.aqfunc``
        (``'tdper'``, ``'relo_tdper'``, or a key of
        ``acquisitions.FUNCTIONS``) or by ``net_args.sampler_manual_curr``.
        While ``epoch < net_args.gradual_period`` the acquisition function is
        forced to ``'random'`` and the power to ``1.0``.

        Args:
            dataset: sequence whose elements are ``d3rlpy.dataset.Episode``.
            n_epochs: number of epochs; derived from ``n_steps`` when omitted.
            n_steps: total step budget, used only to derive ``n_epochs``.
            n_steps_per_epoch: per-epoch budget; the number of iterations per
                epoch is ``int(n_steps_per_epoch / net_args.num_bkwds)``.
            save_metrics, experiment_name, with_timestamp, verbose,
            tensorboard_dir: forwarded to ``self._prepare_logger``.
            logdir: appended to ``net_args.datapath`` by plain string
                concatenation to form the logger directory.
            show_progress: enables the tqdm progress bar.
            eval_episodes: passed to every scorer.
            save_interval: the model is saved every ``save_interval`` epochs.
            scorers: mapping of metric name to scorer. The scorer named
                ``'environment'`` must return a dict of values; its entries
                are logged individually and as ``reward_<key>`` on wandb.
            callback: called as ``callback(self, epoch, total_step)`` after
                every iteration.
            net_args: required configuration object. Attributes read here:
                ``datapath``, ``batch_size``, ``blur_traj``, ``aqfunc``,
                ``num_bkwds``, ``sampler_manual_curr``, ``indep_ensemble``,
                ``bootstrap_ens``, ``flooding_loss``, ``clip_grad``,
                ``traffic``, ``augment``, ``down_factor``,
                ``active_weights_power``, ``gradual_period``,
                ``tdper_with_cons``, ``episodic``, ``select_eps``.
            shuffle, env, num_norand_eps, sampled_idxs: accepted for
                signature compatibility; not used by this method.

        Yields:
            ``(epoch, metrics)`` where ``metrics`` is the value returned by
            ``logger.commit`` for that epoch.

        Raises:
            ValueError: if ``net_args`` is missing, if neither ``n_epochs``
                nor ``n_steps`` is given, or if the first dataset element is
                not an ``Episode``.
        """
        if net_args is None:
            raise ValueError("net_args is required by fitter_active")
        if n_epochs is None:
            if n_steps is None:
                raise ValueError("either n_epochs or n_steps must be given")
            n_epochs = int(n_steps / n_steps_per_epoch)

        logdir = net_args.datapath + logdir

        if not isinstance(dataset[0], d3rlpy.dataset.Episode):
            raise ValueError(f"invalid dataset type: {type(dataset)}")
        transitions = []
        for episode in dataset:
            transitions += episode.transitions

        # setup logger
        logger = self._prepare_logger(
            save_metrics,
            experiment_name,
            with_timestamp,
            logdir,
            verbose,
            tensorboard_dir,
        )

        # add reference to active logger to algo class during fit
        self._active_logger = logger

        # initialize scaler
        if self._scaler:
            LOG.debug("Fitting scaler...", scaler=self._scaler.get_type())
            self._scaler.fit(transitions)

        # initialize action scaler
        if self._action_scaler:
            LOG.debug(
                "Fitting action scaler...",
                action_scaler=self._action_scaler.get_type(),
            )
            self._action_scaler.fit(transitions)

        # initialize reward scaler
        if self._reward_scaler:
            LOG.debug(
                "Fitting reward scaler...",
                reward_scaler=self._reward_scaler.get_type(),
            )
            self._reward_scaler.fit(transitions)

        action_size = transitions[0].get_action_size()
        print('action sz: ', action_size)

        # instantiate implementation
        if self._impl is None:
            LOG.debug("Building models...")
            observation_shape = tuple(transitions[0].get_observation_shape())
            self.create_impl(
                self._process_observation_shape(observation_shape), action_size
            )
            LOG.debug("Models have been built.")
        else:
            LOG.warning("Skip building models since they're already built.")

        # save hyperparameters
        self.save_params(logger)

        # refresh evaluation metrics and loss history
        self._eval_results = defaultdict(list)
        self._loss_history = defaultdict(list)

        self.net_args = net_args

        # With trajectory blurring each drawn index contributes itself and its
        # two neighbours, hence the tripled nominal batch size.
        batch_size = net_args.batch_size * 3 if net_args.blur_traj else net_args.batch_size

        props = None
        num_bkwds = net_args.num_bkwds

        dataset_size = len(transitions)
        # Keep helper tensors on the same device as the model rather than
        # assuming CUDA.
        device = self._impl.device

        fwd_batch_size = 16384
        # Chunk starts; stepping this way never produces an empty last chunk
        # when the dataset size is a multiple of the chunk size.
        chunk_starts = range(0, dataset_size, fwd_batch_size)

        action_chunks = []
        pick_meta_data = np.zeros((dataset_size, 3))
        for start in chunk_starts:
            stop = min(start + fwd_batch_size, dataset_size)
            chunk = TransitionMiniBatch(transitions[start:stop])
            action_chunks.append(
                torch.tensor(chunk.actions, dtype=torch.int64).to(device)
            )
            pick_meta_data[start:stop, 0] = chunk.ep_ids.reshape(-1)
            pick_meta_data[start:stop, 1] = chunk.tr_ids.reshape(-1)

        dataset_batch = TransitionMiniBatch(transitions)

        if net_args.sampler_manual_curr:
            man_scores, _, _ = self.manual_curr(dataset_batch, net_args)

        all_actions = torch.cat(action_chunks)

        # Define the custom x axis metric
        wandb.define_metric("custom_step")
        # Define which metrics to plot against that x-axis
        wandb.define_metric("chosen_rewards", step_metric='custom_step')
        wandb.define_metric("mu", step_metric='custom_step')
        wandb.define_metric("pi", step_metric='custom_step')
        wandb.define_metric("chosen_mu", step_metric='custom_step')
        wandb.define_metric("chosen_pi", step_metric='custom_step')

        print('Saved to: ', logger.logdir)

        # The result of this draw is not used; the call is kept so the NumPy
        # RNG stream consumed before training starts is unchanged.
        np.random.choice(range(dataset_size), replace=False, size=batch_size)

        self.impl.eps_list = [92, 108, 125, 133, 192, 239, 447, 466, 285, 303, 317, 350, 364, 400, 410, 48, 61, 467, 478, 524, 544, 599, 129]
        self.eps_list = self.impl.eps_list
        self.indep_ensemble = net_args.indep_ensemble
        self.impl.indep_ensemble = net_args.indep_ensemble
        self.impl.bootstrap_ens = net_args.bootstrap_ens
        self.impl.flooding_loss = net_args.flooding_loss
        self.impl.clip_grad = net_args.clip_grad
        self.impl.traffic = net_args.traffic
        self.impl.augment = net_args.augment
        self.impl.down_factor = net_args.down_factor

        self.mean_consistency = []
        total_step = 0
        self.max_env_reward_so_far = -10000.

        n_ensemble = len(self._impl.q_function.q_funcs)
        iters_per_epoch = int((n_steps_per_epoch * 1.0) / num_bkwds)
        n_draw = int(batch_size / 3) if net_args.blur_traj else batch_size

        for epoch in range(1, n_epochs + 1):

            self.epoch = epoch
            self.impl.epoch = epoch

            # Warm-up: sample uniformly until the gradual period is over.
            if epoch < net_args.gradual_period:
                aq_func = 'random'
                active_alpha = 1.0
            else:
                aq_func = net_args.aqfunc
                active_alpha = net_args.active_weights_power

            # dict to add incremental mean losses to epoch
            epoch_loss = defaultdict(list)
            range_gen = tqdm(
                range(iters_per_epoch),
                disable=not show_progress,
                desc=f"Epoch {int(epoch)}/{n_epochs}",
            )

            for itr in range_gen:

                with logger.measure_time("step"):

                    with torch.no_grad():
                        if aq_func == 'random' or 'tdper' in aq_func:
                            # These modes do not score from Q-values; a zero
                            # placeholder keeps the statistics below defined.
                            mu = torch.zeros(
                                n_ensemble, dataset_size, action_size
                            ).to(device)
                        else:
                            outputs = []
                            for start in chunk_starts:
                                stop = min(start + fwd_batch_size, dataset_size)
                                chunk = TransitionMiniBatch(transitions[start:stop])
                                obs = torch.Tensor(chunk.observations).to(device)
                                outputs.append(
                                    self._impl._q_func(
                                        self._scaler.transform(obs) if self._scaler else obs,
                                        reduction='none',
                                    )
                                )
                            mu = torch.cat(outputs, dim=1)

                        mean_pred_actions = mu.mean(axis=0).argmax(axis=-1)
                        max_ens_wise_action = mu.argmax(dim=2).transpose(0, 1)
                        all_actions_rp = all_actions.reshape(-1, 1).repeat(1, mu.shape[0])

                        # Fraction of transitions whose dataset action equals
                        # the argmax of the ensemble-mean output.
                        data_uneq_qmean = (
                            all_actions.reshape(-1) == mean_pred_actions.reshape(-1)
                        ).sum() / (1.0 * mean_pred_actions.shape[0])

                        data_ens_rebellion = (
                            all_actions_rp == max_ens_wise_action
                        ).sum(axis=-1).float().mean()
                        self.mean_consistency.append(data_ens_rebellion)

                        # compute scores
                        if aq_func == 'tdper':
                            scores, _, _, _ = self.PER_curr(transitions, net_args.tdper_with_cons)
                        elif aq_func == 'relo_tdper':
                            scores, _, _, _ = self.relo_curr(transitions, net_args.tdper_with_cons)
                        elif net_args.sampler_manual_curr:
                            scores = man_scores
                        else:
                            # scores from the registered acquisition function
                            scores, _, _ = acquisitions.FUNCTIONS[aq_func](mu, all_actions, props)
                            scores = scores + scores.min().abs()

                        if aq_func != 'random' and net_args.episodic:
                            scores = self.process_scores(
                                scores.cpu().numpy(),
                                pick_meta_data,
                                dataset_batch,
                                select_eps=net_args.select_eps,
                            )
                        else:
                            scores = scores.cpu().numpy()

                        powered = scores ** active_alpha
                        p = (powered / (powered.sum() * 1.0)).astype('float64')

                    for _ in range(num_bkwds):
                        idx_all = np.random.choice(
                            len(p),
                            replace=False,
                            p=p,
                            size=n_draw,
                        )
                        idx_start = idx_all
                        if net_args.blur_traj:
                            # add the previous and next transition of every pick
                            idx_start = np.append(idx_start, np.clip(idx_all - 1, a_min=0, a_max=None))
                            idx_start = np.append(idx_start, np.clip(idx_all + 1, a_min=None, a_max=len(p) - 1))

                        selected_batch_trans = TransitionMiniBatch(
                            [transitions[i] for i in idx_start]
                        )
                        with logger.measure_time("algorithm_update"):
                            loss = self.update(selected_batch_trans)

                    # record metrics (of the last update of this iteration)
                    for name, val in loss.items():
                        logger.add_metric(name, val)
                        epoch_loss[name].append(val)

                    # update progress postfix with losses
                    if itr % 10 == 0:
                        mean_loss = {
                            k: np.mean(v) for k, v in epoch_loss.items()
                        }
                        range_gen.set_postfix(mean_loss)

                    wandb.log(
                        {"data_uneq_qmean": data_uneq_qmean.item()},
                        step=(max(0, epoch - 1) * n_steps_per_epoch) + itr,
                    )

                total_step += 1

                # call callback if given
                if callback:
                    callback(self, epoch, total_step)

            # save loss to loss history dict
            self._loss_history["epoch"].append(epoch)
            self._loss_history["step"].append(total_step)
            for name, vals in epoch_loss.items():
                if vals:
                    self._loss_history[name].append(np.mean(vals))

            # `scorers` defaults to None; treat that as "no evaluation".
            rew_dict = None
            for name, scorer in (scorers or {}).items():
                # evaluation with test data
                test_score = scorer(self, eval_episodes)
                if name == 'environment':
                    total_reward = []
                    rew_dict = test_score
                    for key, value in test_score.items():
                        logger.add_metric(key, value)
                        total_reward.append(value)
                    logger.add_metric('environment', np.mean(total_reward))
                else:
                    # logging metrics
                    logger.add_metric(name, test_score)
                # store metric locally
                if test_score is not None:
                    self._eval_results[name].append(test_score)

            metrics = logger.commit(epoch, total_step)

            # save model parameters
            if epoch % save_interval == 0:
                logger.save_model(total_step, self)

            # Only forward metrics that were actually produced this epoch, so
            # a run without some scorers does not fail on a missing key.
            log_dict = {
                wandb_key: metrics[metric_key]
                for wandb_key, metric_key in (
                    ("loss", "loss"),
                    ("reward", "environment"),
                    ("advantage", "advantage"),
                    ("td_error", "td_error"),
                    ("value_scale", "value_scale"),
                )
                if metric_key in metrics
            }
            log_dict["td_loss"] = np.mean(self.impl.td_losses)
            log_dict["cons_loss"] = np.mean(self.impl.cons_losses)

            if rew_dict is not None:
                for key, value in rew_dict.items():
                    log_dict['reward_' + key] = value

            wandb.log(
                log_dict,
                step=epoch * n_steps_per_epoch,
            )

            self.impl.td_losses.clear()
            self.impl.cons_losses.clear()

            yield epoch, metrics

        # drop reference to active logger since out of fit there is no active logger
        self._active_logger = None
