import numpy as np
import hydra
import isaacgym
import gym
from isaacgym import gymtorch
import torch


from allegro_gym.envs.table_env import Allegro
from isaacgym.torch_utils import *
from plots.data_logging import Log, ListOfLogs, NoLog, SimpleLog 

#@hydra.main(version_base = '1.2', config_path = 'configs', config_name = 'run_env')
def main():
    """Replay a recorded trajectory in the ``allegro_gym:allegrogym-v0`` env.

    Loads ``trajectory_3.npy`` from the current working directory, reorders
    its columns and replicates every step across 36 environments (see
    ``_reorder_trajectory``), then sends one step per ``env.step`` call.
    Replay stops when the trajectory is exhausted or the environment reports
    ``done``. Afterwards the environment's log, if any, is finalised:
    ``finish_log()`` for a ``ListOfLogs``, ``save()`` otherwise.

    Raises:
        ValueError: if the loaded trajectory is not 2-D with 22 columns.
    """
    allegroenv = gym.make('allegro_gym:allegrogym-v0')
    state = allegroenv.reset()
    state = to_torch(state, dtype=torch.float, device='cpu')
    print("state")
    print(state)

    trajectory = np.load('trajectory_3.npy')
    print(trajectory.shape)
    actions = to_torch(
        _reorder_trajectory(trajectory), dtype=torch.float, device='cpu'
    )

    # Iterating over `actions` yields one (num_envs, num_dofs) target per step.
    # An empty trajectory therefore runs zero steps instead of failing on
    # actions[0].
    for target in actions:
        print("True Actions")
        _nextstate, _reward, done, _info = allegroenv.step(target)
        # NOTE(review): `done` is truth-tested exactly as before. If the env
        # ever returns a per-environment tensor, this needs .any()/.all().
        if done:
            break

    log = allegroenv.return_log()
    if log is not None:
        if isinstance(log, ListOfLogs):
            log.finish_log()
        else:
            log.save()


def _reorder_trajectory(trajectory, num_envs=36, num_dofs=22):
    """Reorder trajectory columns and replicate each step across environments.

    The column order becomes ``0..9, 17..num_dofs-1, 10..16``: columns 0-9
    stay in place, the tail block starting at column 17 is moved in front of
    columns 10-16. Presumably this maps the recorded joint order onto the
    env's DOF order; TODO confirm against the env definition.

    Args:
        trajectory: array-like of shape ``(T, num_dofs)``, one row per step.
        num_envs: number of copies of each step (one per parallel env).
        num_dofs: expected number of columns; must be at least 17.

    Returns:
        A float64 ``np.ndarray`` of shape ``(T, num_envs, num_dofs)``.

    Raises:
        ValueError: if ``trajectory`` is not 2-D with ``num_dofs`` columns.
    """
    trajectory = np.asarray(trajectory)
    if trajectory.ndim != 2 or trajectory.shape[1] != num_dofs:
        raise ValueError(
            f"expected trajectory of shape (T, {num_dofs}), "
            f"got {trajectory.shape}"
        )
    perm = np.r_[0:10, 17:num_dofs, 10:17]
    reordered = trajectory[:, perm].astype(np.float64)
    # np.repeat (not broadcast_to) yields a writable, contiguous copy, which
    # torch can ingest without warnings.
    return np.repeat(reordered[:, np.newaxis, :], num_envs, axis=1)
        
                
# Run the replay only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
