import torch
import time
import numpy as np

from sgcrl.gym_helpers import Bot

# Values taken by `DualPolicy.current_phase` (which is None until the first plan is computed).
# SUBGOALS_PHASE: the agent follows the planned subgoal tokens with the subgoal low-level policy.
SUBGOALS_PHASE = 1
# GOAL_PHASE: the agent acts with the goal low-level policy.
GOAL_PHASE = 2

class DualPolicy(Bot):
    """Hierarchical agent: a token-level planner driving two low-level policies.

    The high-level policy produces a sequence of tokens (subgoals) from the token
    of the current state to the token of the goal. While subgoals remain
    (``SUBGOALS_PHASE``) actions come from ``low_level_policy_subgoal``, which is
    given the next subgoal as its target; once the plan is exhausted
    (``GOAL_PHASE``) actions come from ``low_level_policy_goal``.

    Args:
        low_level_policy_subgoal: policy exposing ``_action(frame, stochastic=...)``,
            used to reach the next subgoal token.
        low_level_policy_goal: policy exposing ``_action(frame, stochastic=...)``,
            used to reach the final goal.
        high_level_policy: planner exposing ``predict``, ``predict_k_sample`` and
            ``predict_beam``.
        tokenizer: object exposing ``tokenize_tensor``, ``_tokens`` (mapping from
            token representative to token index), ``_number_of_tokens`` and ``_device``.
        sos_token: start-of-sequence token index fed to the planner.
        eos_token: end-of-sequence token index; the plan is truncated at the first
            planned token that is ``>= eos_token``.
        keys_to_tokenize: frame keys concatenated to build the state that is tokenized.
        frame_relabellers: optional callables ``frame -> frame`` applied in order
            before every action. Defaults to no relabelling.
        use_obs_representation: if True the subgoal given to the low-level policy is
            the token representative, otherwise the raw token index
            (non-godot frames only).
        only_goal_policy: stored on the instance; not used by the current logic.
    """

    def __init__(self, low_level_policy_subgoal, low_level_policy_goal, high_level_policy,
                 tokenizer, sos_token, eos_token, keys_to_tokenize, frame_relabellers=None,
                 use_obs_representation=True, only_goal_policy=False):
        super().__init__()
        self.low_level_policy_subgoal = low_level_policy_subgoal
        self.low_level_policy_goal = low_level_policy_goal
        self.high_level_policy = high_level_policy
        self.tokenizer = tokenizer
        self.use_obs_representation = use_obs_representation

        self.sos_token = sos_token
        self.eos_token = eos_token

        # Episode state; also (re)initialised by `reset`. Setting everything here
        # lets `_action` be called even if `reset` was never invoked.
        self.subgoals = []
        self.subgoals_history = []
        self.current_phase = None
        self.current_token = None
        self.goal = None
        self.goal_token = None
        self.replaning_step = 0

        # Tensor of token representatives, indexed by token index.
        token_keys = list(tokenizer._tokens.keys())
        self.labels = torch.ones(tokenizer._number_of_tokens + 1, len(token_keys[0]))
        for key, val in tokenizer._tokens.items():
            self.labels[val] = torch.tensor(key)

        self.keys_to_tokenize = keys_to_tokenize
        # A fresh list per instance: a `[]` default would be shared between instances.
        self.frame_relabellers = [] if frame_relabellers is None else frame_relabellers

        self.obs_representation_of_tokens = torch.tensor(token_keys, dtype=torch.float, device=tokenizer._device)

        self.only_goal_policy = only_goal_policy

    def cpu(self):
        """Move every sub-module to the CPU and return ``self``."""
        return self.to("cpu")

    def to(self, device):
        """Move both low-level policies, the planner and the tokenizer to ``device``; return ``self``."""
        self.low_level_policy_goal.to(device)
        self.low_level_policy_subgoal.to(device)
        self.high_level_policy.to(device)
        self.tokenizer.to(device)
        return self

    def eval(self):
        """Call ``eval()`` on both low-level policies, the planner and the tokenizer; return ``self``."""
        self.low_level_policy_goal.eval()
        self.low_level_policy_subgoal.eval()
        self.high_level_policy.eval()
        self.tokenizer.eval()
        return self

    @staticmethod
    def _is_godot_frame(frame):
        """Return True when any frame key contains 'sensor' (heuristic used to detect godot frames)."""
        return any('sensor' in key for key in frame.keys())

    def _tokenizable_state(self, frame):
        """Concatenate ``keys_to_tokenize`` from ``frame`` and add a leading batch dimension.

        For godot frames only components 0 and 2 of the concatenated vector are kept.
        """
        state = torch.cat([frame[key] for key in self.keys_to_tokenize], dim=0)
        if self._is_godot_frame(frame):
            state = state[[0, 2]]
        return state.unsqueeze(0)

    def _timed_plan(self, frame, plan_kwargs):
        """Run `init_env(frame, **plan_kwargs)` and return the elapsed wall-clock seconds."""
        start = time.time()
        self.init_env(frame, **plan_kwargs)
        return time.time() - start

    def _subgoal_policy_action(self, frame, stochastic):
        """Write the next subgoal (``self.subgoals[-1]``) into ``frame`` and query the subgoal policy.

        ``frame`` is modified in place; `_action` passes its own copy.
        """
        target = self.subgoals[-1]
        if self._is_godot_frame(frame):
            frame['token/next/obs_representation'] = self.obs_representation_of_tokens[target].cpu()
        else:
            frame['observation'] = frame['obs/partial']
            if self.use_obs_representation:
                frame['goal'] = self.obs_representation_of_tokens[target].cpu()
            else:
                frame['goal'] = torch.tensor([target])
        return self.low_level_policy_subgoal._action(frame, stochastic=stochastic)

    def init_env(self, frame, track_token_history=False, n_plan_steps=False, argmax=False, pred_type='sample', k_samples=10, beam_width=10, choice='best_score'):
        """Plan a subgoal path from the current frame to its goal and set the phase.

        Sets ``self.goal``, ``self.goal_token``, ``self.subgoals`` (reversed, so the
        next subgoal is the last element), ``self.current_token`` and
        ``self.current_phase``.

        Args:
            frame: mapping holding the keys in ``keys_to_tokenize`` plus either
                ``"sensor/absolute_goal_position"`` (godot) or ``"goal"``.
            track_token_history: if True, ``self.subgoals_history`` is passed to the
                planner as ``token_history``.
            n_plan_steps: forwarded to the planner.
            argmax: forwarded to ``predict``; the other planners are always called
                with ``argmax=False``.
            pred_type: ``'sample'``, ``'k_sample'`` or ``'beam_search'``.
            k_samples: forwarded to ``predict_k_sample``.
            beam_width: forwarded to ``predict_beam``.
            choice: forwarded to ``predict_beam``.

        Raises:
            ValueError: if ``pred_type`` is not one of the supported values.
        """
        # Get current state and goal
        s0 = self._tokenizable_state(frame)
        if self._is_godot_frame(frame):
            self.goal = frame["sensor/absolute_goal_position"][[0, 2]].unsqueeze(0)
        else:
            self.goal = frame["goal"].unsqueeze(0)

        # Plan the subgoal path
        s0_token = self.tokenizer.tokenize_tensor(s0)[0]
        goal_token = self.tokenizer.tokenize_tensor(self.goal)[0]
        print(f's0_token: {s0_token}, goal_token: {goal_token}, sos_soken: {self.sos_token}, eos_token: {self.eos_token}')
        # Build the sos tensor on the tokens' device so the concatenation cannot fail on a device mismatch.
        sos = torch.tensor([[self.sos_token]], device=goal_token.device)
        inpt = torch.cat((goal_token, sos, s0_token), dim=1)
        self.goal_token = goal_token

        history_kwargs = {'token_history': self.subgoals_history} if track_token_history else {}
        if pred_type == 'sample':
            plan = self.high_level_policy.predict(inpt, n_plan_steps=n_plan_steps, argmax=argmax, **history_kwargs)
        elif pred_type == 'k_sample':
            plan = self.high_level_policy.predict_k_sample(inpt, n_plan_steps=n_plan_steps, argmax=False, k_samples=k_samples, **history_kwargs)
        elif pred_type == 'beam_search':
            plan = self.high_level_policy.predict_beam(inpt, n_plan_steps=n_plan_steps, argmax=False, beam_width=beam_width, choice=choice, **history_kwargs)
        else:
            raise ValueError(f'Unknown pred_type: {pred_type}')
        self.subgoals = plan.detach().cpu().numpy()[0]

        # Filter eos tokens: keep sos token + s0_token + every token that is before eos.
        # The first two positions are never scanned, so an eos found only there
        # leaves the plan untouched instead of raising IndexError.
        if self.eos_token in self.subgoals:
            stops = np.where(self.subgoals[2:] >= self.eos_token)[0]
            if stops.size:
                self.subgoals = self.subgoals[:2 + stops[0]]

        # Drop sos and s0, then reverse: [goal_token, ..., s1], so the next subgoal is popped from the end.
        self.subgoals = list(self.subgoals[2:][::-1])
        self.current_token = s0_token.item()

        if self.current_token == goal_token.item():
            self.current_phase = GOAL_PHASE
        else:
            self.current_phase = SUBGOALS_PHASE

    def reset(self, seed):
        """Clear all per-episode planning state. ``seed`` is accepted but not used."""
        self.subgoals = []
        self.subgoals_history = []
        self.current_phase = None
        self.current_token = None
        self.goal = None
        self.goal_token = None
        self.replaning_step = 0

    @torch.no_grad()
    def _action(
        self,
        frame,
        add_when_out=False,
        replan_when_out=False,
        replan_every_n_steps=False,
        add_when_out_goal=False,
        replan_when_out_goal=False,
        track_token_history=False,
        n_plan_steps=None,
        argmax=False,
        pred_type='sample',
        k_samples=10,
        beam_width=10,
        choice='best_score',
        plot_infos=False,
        **kwargs):
        """Return the action for ``frame``, planning or replanning subgoals when needed.

        Args:
            frame: mapping of observations; it is copied before being modified.
            add_when_out: when the agent enters a token that is not the next subgoal,
                push the token it just left back onto the plan.
            replan_when_out: in that same situation, recompute the whole plan instead.
                Mutually exclusive with ``add_when_out``.
            replan_every_n_steps: if truthy, replan once more than this many
                consecutive steps elapsed without any planning.
            add_when_out_goal: in ``GOAL_PHASE``, when the token changes, push the
                previous token onto the plan and go back to ``SUBGOALS_PHASE``.
            replan_when_out_goal: in ``GOAL_PHASE``, when the token changes, recompute
                the plan. Mutually exclusive with ``add_when_out_goal``.
            track_token_history, n_plan_steps, argmax, pred_type, k_samples,
            beam_width, choice: forwarded to `init_env` on every (re)planning.
            plot_infos: if True, also return timing/count information.
            **kwargs: ``stochastic`` is forwarded to the low-level policies
                (default False); the mere presence of an ``eval`` key switches the
                sub-modules to eval mode, whatever its value.

        Returns:
            ``action`` or, if ``plot_infos`` is True, ``(action, infos)`` where
            ``infos`` has keys ``'time_dictionnary'``, ``'n_plannings'`` and
            ``'n_inferences'``.
        """
        # Verify planning params
        assert replan_when_out + add_when_out < 2, 'One replanning strategy at a time.'
        assert add_when_out_goal + replan_when_out_goal < 2, 'One goal replanning strategy at a time.'

        # Set eval mode
        if 'eval' in kwargs:
            self.eval()

        # Set stochastic inference
        stochastic = kwargs.get("stochastic", False)

        # Compute relevant features in frame
        frame = frame.copy()
        for relabeller in self.frame_relabellers:
            frame = relabeller(frame)

        # Initialize time dictionnary
        time_dictionnary = {'planning_time': 0, 'inference_time': 0}
        n_plannings = 0
        n_inferences = 0

        # Every (re)planning uses the same planner configuration.
        plan_kwargs = dict(
            track_token_history=track_token_history,
            n_plan_steps=n_plan_steps,
            argmax=argmax,
            pred_type=pred_type,
            k_samples=k_samples,
            beam_width=beam_width,
            choice=choice,
        )

        # If start, compute subgoals
        if self.current_phase is None:
            time_dictionnary['planning_time'] += self._timed_plan(frame, plan_kwargs)
            n_plannings += 1

        # get current token
        token_frame = self.tokenizer.tokenize_tensor(self._tokenizable_state(frame))[0].item()

        if self.current_phase == SUBGOALS_PHASE:

            # replan if replaning_step is too high
            if replan_every_n_steps and (self.replaning_step > replan_every_n_steps):
                time_dictionnary['planning_time'] += self._timed_plan(frame, plan_kwargs)
                self.replaning_step = 0
                print('Replanned')
                n_plannings += 1

            # check if we are in a new area
            if self.current_token != token_frame:
                print(f'SUBGOAL_PHASE - former token: {self.current_token}, new token: {token_frame}, remaining tokens: {self.subgoals}')
                self.subgoals_history.append(self.current_token)

                # check if we diverged from planning path (i.e. we did not land on the next subgoal)
                if token_frame not in self.subgoals[-1:]:
                    if add_when_out:
                        t = time.time()
                        self.subgoals.append(self.current_token)  # Add back last current token to backtrack
                        self.replaning_step = 0
                        print(f'Out of path, added {self.current_token} in subgoals: {self.subgoals}')
                        time_dictionnary['planning_time'] += time.time() - t
                        n_plannings += 1
                    elif replan_when_out:
                        # Replan all the trajectory
                        time_dictionnary['planning_time'] += self._timed_plan(frame, plan_kwargs)
                        self.replaning_step = 0
                        n_plannings += 1
                        print(f'Out of path, replanned to subgoals: {self.subgoals}')
                    else:
                        print(f'Out of path detected for token: {token_frame}')

                # depile all target tokens until there is no more current token in the path to avoid cycles
                self.current_token = token_frame
                while self.current_token in self.subgoals:
                    self.subgoals.pop()

            if len(self.subgoals) == 0:
                # if we reach last subgoal, enter pure goal reaching phase
                if self.current_token == self.goal_token.item():
                    self.current_phase = GOAL_PHASE
                    print(f'started GOAL_PHASE')
                else:
                    # Plan exhausted away from the goal token: replan all the trajectory
                    time_dictionnary['planning_time'] += self._timed_plan(frame, plan_kwargs)
                    self.replaning_step = 0
                    n_plannings += 1
                    print(f'Out of path, replanned to subgoals: {self.subgoals}')

                    if len(self.subgoals) == 0:
                        self.current_phase = GOAL_PHASE
                        print(f'started GOAL_PHASE')
                    else:
                        t = time.time()
                        action = self._subgoal_policy_action(frame, stochastic)
                        time_dictionnary['inference_time'] += time.time() - t
                        n_inferences += 1
            else:
                t = time.time()
                action = self._subgoal_policy_action(frame, stochastic)
                time_dictionnary['inference_time'] += time.time() - t
                n_inferences += 1

        # Deliberately `if`, not `elif`: the block above may have just switched to GOAL_PHASE.
        if self.current_phase == GOAL_PHASE:
            use_goal_policy = True

            # check if we are in a new area out of goal token area
            if (self.current_token != token_frame) and (add_when_out_goal or replan_when_out_goal):
                print(f'Out of goal_token - back to SUBOAL_PHASE - former token: {self.current_token}, new token: {token_frame}, remaining tokens: {self.subgoals}')
                self.subgoals_history.append(token_frame)

                if add_when_out_goal:
                    t = time.time()
                    self.subgoals.append(self.current_token)  # Add back last current token to backtrack
                    self.replaning_step = 0
                    self.current_phase = SUBGOALS_PHASE
                    time_dictionnary['planning_time'] += time.time() - t
                    n_plannings += 1
                    print(f'Out of path, added {self.current_token} in subgoals: {self.subgoals}')
                else:
                    # replan_when_out_goal: replan all the trajectory with the same
                    # planner configuration as every other replanning.
                    time_dictionnary['planning_time'] += self._timed_plan(frame, plan_kwargs)
                    self.replaning_step = 0
                    # An empty plan leaves nothing to follow: stay in GOAL_PHASE, as
                    # is done after the replanning in the subgoal phase.
                    self.current_phase = SUBGOALS_PHASE if self.subgoals else GOAL_PHASE
                    n_plannings += 1
                    print(f'Out of path, replanned to subgoals: {self.subgoals}')

                use_goal_policy = not self.subgoals

            t = time.time()
            if use_goal_policy:
                # Perform goal policy inference
                action = self.low_level_policy_goal._action(frame, stochastic=stochastic)
            else:
                # Perform subgoal policy inference
                action = self._subgoal_policy_action(frame, stochastic)
            time_dictionnary['inference_time'] += time.time() - t
            n_inferences += 1

        if n_plannings == 0:  # if we performed no planning
            self.replaning_step += 1
        infos = {'time_dictionnary': time_dictionnary, 'n_plannings': n_plannings, 'n_inferences': n_inferences}
        if plot_infos:
            return action, infos
        return action