from typing import Union, Callable
import numpy as np
import torch as th
# import wandb
from copy import deepcopy
from rl.rl_algorithm import RLAlgorithm
from rl.utils.utils import get_tau
import rl.successor_features.pytorch_util as ptu
from rl.successor_features.risk import distortion_de

class DGPI(RLAlgorithm):
    """Keeps a growing list of policies and acts greedily over all of them.

    ``policies[i]`` is an object built by ``algorithm_constructor`` and
    ``tasks`` collects the weight vectors passed to :meth:`learn`.
    """

    def __init__(self,
                 env,
                 algorithm_constructor: Callable,
                 log: bool = True,
                 project_name: str = 'dgpi',
                 experiment_name: str = 'dgpi',
                 device: Union[th.device, str] = 'auto'):
        """Create an empty policy set.

        Args:
            env: Environment forwarded to ``RLAlgorithm.__init__``.
            algorithm_constructor: Zero-argument callable returning a new
                policy object; called by :meth:`learn`.
            log: Currently unused (experiment logging is disabled); kept so
                existing call sites keep working.
            project_name: Currently unused, see ``log``.
            experiment_name: Currently unused, see ``log``.
            device: Device specification forwarded to ``RLAlgorithm.__init__``.
        """
        super().__init__(env, device)

        self.algorithm_constructor = algorithm_constructor
        self.policies = []
        self.tasks = []

    def eval(self, obs, w, tau_hat, presum_tau, return_policy_index=False, exclude=None) -> Union[int, tuple]:
        """Pick the action with the highest Q-value across all policies.

        Args:
            obs: Observation, forwarded to each policy's ``q_values``.
            w: Task weight vector, forwarded to each policy's ``q_values``.
            tau_hat: Forwarded unchanged to each policy's ``q_values``.
            presum_tau: Forwarded unchanged to each policy's ``q_values``.
            return_policy_index: When true, also return the index (into
                ``self.policies``) of the policy that produced the maximum.
            exclude: Policy object to leave out of the maximisation. Only
                honoured in the tabular branch (policies with a
                ``dpsi_table`` attribute); the tensor branch ignores it.

        Returns:
            The selected action, or ``(action, policy_index)`` when
            ``return_policy_index`` is true.

        Raises:
            IndexError: If no policy has been added yet.
        """
        # The first policy decides which branch is used for the whole set.
        if not hasattr(self.policies[0], 'dpsi_table'):
            if isinstance(obs, np.ndarray):
                obs = th.tensor(obs).float().to(self.device)
                w = th.tensor(w).float().to(self.device)

            # One Q-value tensor per policy, stacked along a new leading axis.
            q_values = th.stack([policy.q_values(obs, w, tau_hat, presum_tau) for policy in self.policies])
            # Best value over policies, and which policy achieved it.
            bq, bp = th.max(q_values, dim=0)
            # Greedy action over dim 1 of the per-policy maximum.
            _, actions = th.max(bq, dim=1)
            # NOTE(review): the indexing below (actions[0], bp[a].item()) only
            # works if dim 0 of `bq`/`bp` has size 1, i.e. a single
            # observation is evaluated -- confirm against policy.q_values.
            bp = th.squeeze(bp, dim=0)
            a = actions[0].detach().long().item()
            if return_policy_index:
                return a, bp[a].detach().cpu().long().item()
            return a

        # Tabular branch. Remember the original position of every policy that
        # survives the `exclude` filter so the reported index refers to
        # self.policies rather than to the filtered list.
        candidates = [i for i, policy in enumerate(self.policies) if policy is not exclude]
        q_vals = np.stack([self.policies[i].q_values(obs, w, tau_hat, presum_tau) for i in candidates])
        row, action = np.unravel_index(np.argmax(q_vals), q_vals.shape)
        if return_policy_index:
            return action, candidates[row]
        return action

    def delete_policies(self, delete_indx):
        """Remove the entries at the given positions from policies and tasks.

        Args:
            delete_indx: Iterable of integer positions. Duplicates are
                ignored so that each listed position is removed only once.
        """
        # Pop from the back so earlier removals do not shift later positions.
        for i in sorted(set(delete_indx), reverse=True):
            self.policies.pop(i)
            self.tasks.pop(i)

    def learn(self, w, total_timesteps, writer, total_episodes=None, reset_num_timesteps=False, eval_env=None, eval_freq=1000, use_dgpi=True, reset_learning_starts=True, new_policy=True, reuse_value_ind=None):
        """Train the most recent policy on task ``w``.

        Args:
            w: Task weight vector; appended to ``self.tasks`` and forwarded
                to the policy's ``learn``.
            total_timesteps: Forwarded to the policy's ``learn``.
            writer: Forwarded to the policy's ``learn``.
            total_episodes: Forwarded to the policy's ``learn``.
            reset_num_timesteps: Forwarded to the policy's ``learn``.
            eval_env: Forwarded to the policy's ``learn``.
            eval_freq: Forwarded to the policy's ``learn``.
            use_dgpi: When true the policy's ``dgpi`` attribute is set to
                this object, otherwise to ``None``.
            reset_learning_starts: When true (and a previous policy exists)
                the policy's ``learning_starts`` is set to the previous
                policy's ``num_timesteps`` to reset the exploration schedule.
            new_policy: When true a fresh policy is built with
                ``algorithm_constructor`` and appended; otherwise the last
                existing policy is trained further.
            reuse_value_ind: Optional index of an existing policy whose
                values (``q_table``, or ``psi_net`` weights) initialise the
                policy being trained. Only applied when more than one
                policy exists.

        Returns:
            Whatever the policy's ``learn`` returns.

        Raises:
            IndexError: If ``new_policy`` is false and there is no policy to
                continue training.
        """
        if new_policy:
            self.policies.append(self.algorithm_constructor())
        elif not self.policies:
            # Fail before touching self.tasks so no stray task is recorded.
            raise IndexError("learn(new_policy=False) requires at least one existing policy")
        # NOTE(review): w is recorded even when no policy is added, so tasks
        # and policies can differ in length -- confirm this is intended.
        self.tasks.append(w)

        current = self.policies[-1]
        current.dgpi = self if use_dgpi else None

        if len(self.policies) > 1:
            previous = self.policies[-2]
            # Continue the global step/episode counters from the previous policy.
            current.num_timesteps = previous.num_timesteps
            current.num_episodes = previous.num_episodes
            if reset_learning_starts:
                current.learning_starts = previous.num_timesteps  # to reset exploration schedule

            if reuse_value_ind is not None:
                source = self.policies[reuse_value_ind]
                if hasattr(current, 'q_table'):
                    current.q_table = deepcopy(source.q_table)
                else:
                    # Online and target networks both start from the source
                    # policy's online network.
                    current.psi_net.load_state_dict(source.psi_net.state_dict())
                    current.target_psi_net.load_state_dict(source.psi_net.state_dict())

            # The replay buffer object is shared, not copied.
            current.replay_buffer = previous.replay_buffer

        target_fp = current.learn(w=w,
                                  total_timesteps=total_timesteps,
                                  writer=writer,
                                  total_episodes=total_episodes,
                                  reset_num_timesteps=reset_num_timesteps,
                                  eval_env=eval_env,
                                  eval_freq=eval_freq)
        return target_fp

    @property
    def gamma(self):
        """``gamma`` of the first policy (IndexError if there is none)."""
        return self.policies[0].gamma

    def train(self):
        """No-op; training is done by the individual policies in ``learn``."""
        pass

    def get_config(self) -> dict:
        """Return the first policy's config, or ``{}`` when there is none."""
        if len(self.policies) > 0:
            return self.policies[0].get_config()
        return {}
