import numpy as np

from rlkit.envs.ant import AntEnv
from . import register_env

from PIL import Image
import mujoco_py
import os

def generate_png(hfield, filename, size=(16, 16)):
    """Write a normalised heightfield to ``filename`` as an 8-bit greyscale PNG.

    Args:
        hfield: array-like of heights, every value within [0, 1].
        filename: destination path; its directory is not created here.
        size: unused; kept only for backward compatibility (it belonged to an
            earlier random black/white test-pattern variant of this function).

    Raises:
        AssertionError: if any value of ``hfield`` lies outside [0, 1].
    """
    hfield = np.asarray(hfield, dtype=float)
    assert (0 <= hfield).all() and (hfield <= 1).all()

    # Map [0, 1] onto grey levels with the original ``* 256`` scaling, but clamp
    # so that exactly 1.0 becomes 255 instead of overflowing past 8 bits.
    # The uint8 cast is essential: PIL's "L" mode consumes one byte per pixel
    # straight from the array buffer, so a wider integer dtype (as produced by
    # ``astype('int')``) yields a scrambled image.  With uint8 input, PIL
    # infers mode "L" on its own.
    pixels = np.minimum(hfield * 256, 255).astype(np.uint8)

    Image.fromarray(pixels).save(filename)

@register_env('ant-cir-forever')
class AntCirForeverEnv(AntEnv):
    """Never-ending ant task whose rewarded heading rotates slowly over time.

    The reward direction completes one full turn every ``self.freq`` calls to
    ``step``.  To give the ant unbounded room while keeping the simulated world
    small, the ant is periodically shifted back toward the origin along x (and
    re-centred in y).  The accumulated shift is tracked in ``x_offset`` /
    ``y_offset``, and the heightfield is regenerated as a window onto a
    piecewise-linear terrain profile whose knots live in ``self.map``.
    """

    def __init__(self, task=None, n_tasks=2, forward_backward=False, randomize_tasks=True, **kwargs):
        # ``task``, ``n_tasks``, ``randomize_tasks`` and ``kwargs`` are accepted
        # for signature compatibility but are not used by this environment.
        self.t = 0  # number of step() calls so far; drives the rotating goal
        self.freq = 100 * 5 * 1000  # step() calls per full turn of the goal heading

        # Terrain geometry (much smaller values, e.g. single_len=2 with
        # switch_dist=single_len*5, are handy for debugging).
        self.single_len = 40  # half-width of the terrain window sampled into the hfield
        self.refresh_range = self.single_len / 2  # |x| or |y| drift that triggers a refresh
        self.switch_dist = self.single_len * 50  # x spacing between terrain knots

        # Terrain profile: knot x-position (world frame) -> height.
        self.map = {-self.switch_dist: 0, 0: 0, self.switch_dist: 0}
        self.x_offset = 0  # accumulated x shift (world x = qpos[0] + x_offset)
        self.y_offset = 0  # accumulated y shift (world y = qpos[1] + y_offset)
        self.z_offset = 0  # height subtracted so the current hfield's minimum is 0
        self.slope = 0  # height change per unit x between knots (0 = flat terrain)
        self.res = 16  # heightfield resolution: res x res samples

        self.forward_backward = forward_backward
        super(AntCirForeverEnv, self).__init__()

        generate_png(np.zeros((self.res, self.res)), self._hfield_png_path())

    def _hfield_png_path(self):
        """Return the path this env writes the current heightfield PNG to.

        Presumably the model XML at ``self.fullpath`` references this file;
        ``self.timestamp`` is expected to be set by the base class.
        """
        return os.path.join(os.path.dirname(__file__), "multi_assets/" + str(self.timestamp), "current.png")

    def _check_qpos_bounds(self, phase, limit=100):
        """Raise if the root x, y or z coordinate has drifted beyond ``limit``.

        Args:
            phase: label included in the error message (e.g. 'step', 'reset').
            limit: largest accepted absolute coordinate.

        Raises:
            RuntimeError: on the first of x, y, z whose magnitude exceeds ``limit``.
        """
        for axis, name in enumerate('xyz'):
            value = self.sim.data.qpos[axis]
            if abs(value) > limit:
                raise RuntimeError(f'{phase}, {name} very large: {value}')

    def step(self, action):
        """Advance the simulation one control step.

        Returns:
            (obs, reward, done, info) where ``reward`` is the torso's planar
            velocity projected on the current goal heading, minus control and
            contact costs, plus a constant survive bonus.  ``done`` is always
            False.

        Raises:
            RuntimeError: if the root position has blown up (see
                ``_check_qpos_bounds``).
        """
        self.t += 1
        self._check_qpos_bounds('step')
        torso_xyz_before = np.array(self.get_body_com("torso"))

        # The rewarded heading rotates by 2*pi every ``self.freq`` steps.
        goal = (self.t / self.freq) * 2 * np.pi
        direct = (np.cos(goal), np.sin(goal))

        self.do_simulation(action, self.frame_skip)
        torso_xyz_after = np.array(self.get_body_com("torso"))
        torso_velocity = torso_xyz_after - torso_xyz_before
        forward_reward = np.dot((torso_velocity[:2] / self.dt), direct)

        ctrl_cost = .5 * np.square(action).sum()
        contact_cost = 0.5 * 1e-3 * np.sum(
            np.square(np.clip(self.sim.data.cfrc_ext, -1, 1)))
        survive_reward = 1.0
        reward = forward_reward - ctrl_cost - contact_cost + survive_reward
        # Termination is deliberately disabled for this "forever" task.
        done = False
        # NOTE(review): the observation is taken before the terrain refresh
        # below, which may shift qpos and reset velocities; confirm intended.
        ob = self._get_obs()

        self._update_terrain(torso_xyz_after)

        return ob, reward, done, dict(
            reward_forward=forward_reward,
            reward_ctrl=-ctrl_cost,
            reward_contact=-contact_cost,
            reward_survive=survive_reward,
            torso_velocity=torso_velocity,
        )

    def _update_terrain(self, position):
        """Re-centre the ant and rebuild the terrain once it leaves the window.

        Nothing happens while both ``position[0]`` and ``position[1]`` stay
        within ``refresh_range``.  When x leaves the range, the ant is moved
        back by ``refresh_range`` and ``x_offset`` absorbs that shift.  On any
        refresh, y is moved back to 0 (absorbed into ``y_offset``) and z is
        re-levelled so the new heightfield's minimum is 0.  The heightfield PNG
        is rewritten and the model is reloaded from ``self.fullpath``.

        Args:
            position: torso centre-of-mass position (x, y, z) after the step.
        """
        # Copy, don't alias: arrays exposed by ``self.sim.data`` belong to the
        # current MjSim, which ``_initialize_simulation`` below replaces.  A view
        # kept across that call could read memory owned by the discarded sim.
        new_qpos = self.sim.data.qpos.copy()
        x, y = position[0], position[1]
        if x > self.refresh_range:
            print('terrain left', x, self.x_offset)
            new_qpos[0] -= self.refresh_range
            self.x_offset += self.refresh_range
        elif x < -self.refresh_range:
            print('terrain right', x, self.x_offset)
            new_qpos[0] += self.refresh_range
            self.x_offset -= self.refresh_range
        elif abs(y) > self.refresh_range:
            print('terrain y axis', y, self.y_offset)
        else:
            return

        new_hfield = self.construct_hfield(self.x_offset)

        old_z_offset = self.z_offset
        self.z_offset = new_hfield.min()
        old_y_offset = self.y_offset
        self.y_offset = new_qpos[1] + old_y_offset
        # Normalise heights into [0, 1] for the PNG (generate_png asserts this).
        new_hfield = (new_hfield - self.z_offset) / (2 * self.single_len)
        new_qpos[2] += -self.z_offset + old_z_offset
        new_qpos[1] += -self.y_offset + old_y_offset

        generate_png(new_hfield, self._hfield_png_path())
        self._initialize_simulation()

        # NOTE(review): velocities come from the freshly created sim (presumably
        # all zeros), so the ant's momentum is dropped on every refresh; confirm
        # this is intended.
        self.set_state(new_qpos, self.sim.data.qvel)

    def update_id(self, id):
        """Create the terrain knot at x-position ``id`` next to an existing knot.

        The right neighbour (``id + switch_dist``) is preferred, otherwise the
        left one.  The new segment has slope magnitude ``self.slope`` with the
        opposite sign of the neighbouring segment, so the profile zig-zags.  If
        the neighbouring segment is flat, the new knot is placed
        ``slope * switch_dist`` above its neighbour.

        Raises:
            RuntimeError: if neither neighbouring knot exists.
            KeyError: if the neighbour's own outer neighbour is missing.
        """
        step = self.switch_dist
        if id + step in self.map:
            near, far = id + step, id + 2 * step
        elif id - step in self.map:
            near, far = id - step, id - 2 * step
        else:
            raise RuntimeError(f'update_id: no neighbouring knot for {id}')

        prev = self.map[near] - self.map[far]
        if prev == 0:
            rd = 1
        else:
            # ``round`` rather than ``int`` so float drift (e.g. 0.9999...)
            # cannot collapse the sign to 0 and flatten the zig-zag.
            rd = -int(round(prev / (self.slope * step)))

        self.map[id] = self.map[near] + rd * self.slope * step
        print('id', id, id in self.map, self.map)

    def construct_hfield(self, center):
        """Sample the terrain profile around ``center`` into a square heightfield.

        Samples ``self.res`` evenly spaced x-positions over
        ``[center - 2*refresh_range, center + 2*refresh_range)``.  Missing knots
        are created via ``update_id``, from left to right.

        Returns:
            ``(res, res)`` array; row i holds the height at the i-th x sample,
            repeated across all columns.
        """
        left = center - self.refresh_range * 2
        right = center + self.refresh_range * 2

        # Knots covering [left, right]: the last one at or below ``left``, then
        # successive knots until one reaches ``right``.
        knot_ids = [left // self.switch_dist * self.switch_dist]
        while knot_ids[-1] < right:
            knot_ids.append(knot_ids[-1] + self.switch_dist)
        for knot in knot_ids:
            if knot not in self.map:
                self.update_id(knot)

        x = left + np.arange(self.res) / self.res * (right - left)
        profile = np.interp(x, knot_ids, [self.map[knot] for knot in knot_ids])
        return np.repeat(profile[:, np.newaxis], self.res, axis=1)

    def get_all_task_idx(self):
        """Return ten dummy task indices (this env has a single, implicit task)."""
        return [0] * 10

    def reset_task(self, id):
        """No-op: there are no discrete tasks to switch between."""
        return

    def reset(self):
        """Reset the ant's pose and velocity without resetting terrain or offsets.

        The ant keeps its current x/y position and is placed 0.75 above its
        current z, with small random noise on the remaining joints and
        velocities.

        Raises:
            RuntimeError: if the root position has blown up (see
                ``_check_qpos_bounds``).
        """
        print('db0331 reset', self.sim.data.qpos[:3], 't', self.t)
        print('resetx', self.sim.data.qpos[0] + self.x_offset)
        print('resety', self.sim.data.qpos[1] + self.y_offset)
        self._check_qpos_bounds('reset')

        qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-.1, high=.1)
        qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1
        qpos[:2] = self.sim.data.qpos[:2]
        qpos[2] = self.sim.data.qpos[2] + 0.75
        super().reset()
        self.set_state(qpos, qvel)
        return self._get_obs()

    def _initialize_simulation(self):
        """(Re)load the MuJoCo model from ``self.fullpath`` and create a fresh sim."""
        self.model = mujoco_py.load_model_from_path(self.fullpath)
        self.sim = mujoco_py.MjSim(self.model)
        self.data = self.sim.data