import random
import numpy as np
import gymnasium as gym
import torch.nn as nn
import torch
import copy

class GridShootingVSRandom:
    """Single-agent view of ``GridShooting`` with a random opponent.

    The caller controls player 0; player 1 picks uniformly among the
    actions the environment currently reports as legal for it. ``step``
    returns the 4-tuple ``(observation, reward, done, info)`` from
    player 0's point of view.
    """

    single_action_space = gym.spaces.Discrete(9)
    single_observation_space = gym.spaces.Box(np.zeros(56), np.ones(56) * 9)

    def __init__(self):
        self.env = GridShooting()

    def reset(self):
        """Start a new episode and return player 0's observation."""
        self.episode_return = 0
        both_observations = self.env.reset()
        return both_observations[0]

    def step(self, action):
        """Play ``action`` for player 0 against a random legal opponent move.

        When the episode ends, ``info`` is rewritten to carry the winner
        under ``'win'`` and an episode summary under ``'final_info'``.
        """
        opponent_mask = self.env.legal_actions()[1]
        opponent_choices = list(np.where(opponent_mask == 1)[0])
        action2 = random.sample(opponent_choices, 1)[0]

        state, reward, done, info = self.env.step(action, action2)

        # The environment only includes 'win' in the info of a finished episode.
        if 'win' in info:
            winner = info['win']
            episode_summary = {
                'r': info['episode_returns'][0],
                'win': winner == 0,
                'l': self.env.frame_no,
                'star': info['star'],
                'shoot': info['shoot'],
            }
            info = {'win': winner, 'final_info': [{'episode': episode_summary}]}

        return state[0], reward[0], done, info

    def close(self):
        """Close the wrapped environment."""
        self.env.close()
        
class GridShootingVSGiven:
    """Single-agent view of ``GridShooting`` with a fixed Q-network opponent.

    The caller controls player 0; player 1 plays the greedy action of a
    ``QNetwork`` whose weights are loaded from ``model_dir``. ``step``
    returns the 4-tuple ``(observation, reward, done, info)`` from
    player 0's point of view.
    """

    single_action_space = gym.spaces.Discrete(9)
    single_observation_space = gym.spaces.Box(np.zeros(56), np.ones(56) * 9)

    def __init__(self, model_dir='/home/wjh/myh/cleanrl-master/useful_models/grid_shooting_onlyshoot/dqn_env_1000000.cleanrl_model'):
        """Create the environment and load the opponent's weights.

        Args:
            model_dir: Path of the saved ``state_dict`` for the opponent.
        """
        self.env = GridShooting()

        self.model = QNetwork(self)
        # The opponent is evaluated on CPU tensors, so remap the checkpoint
        # to CPU; otherwise a checkpoint saved on a GPU fails to load on a
        # machine without one.
        self.model.load_state_dict(torch.load(model_dir, map_location='cpu'))
        self.model.eval()

    def reset(self):
        """Start a new episode and return player 0's observation."""
        self.episode_return = 0
        # Both observations are kept: index 1 feeds the opponent in step().
        self.obs = self.env.reset()
        return self.obs[0]

    def step(self, action):
        """Play ``action`` for player 0 against the network's greedy move.

        When the episode ends, ``info`` is rewritten to carry the winner
        under ``'win'`` and an episode summary under ``'final_info'``.
        """
        # Pure inference: no autograd graph is needed for the opponent.
        with torch.no_grad():
            q_values = self.model(torch.Tensor(self.obs[1]))
        # Hand the environment a plain int. A 0-d numpy array is mutable and
        # in-place arithmetic on it inside the environment can change which
        # action is actually applied.
        action2 = int(torch.argmax(q_values, dim=-1).item())

        self.obs, reward, done, info = self.env.step(action, action2)

        # The environment only includes 'win' in the info of a finished episode.
        if 'win' in info:
            winner = info['win']
            episode_summary = {
                'r': info['episode_returns'][0],
                'win': winner == 0,
                'l': self.env.frame_no,
                # Same statistics that GridShootingVSRandom reports.
                'star': info['star'],
                'shoot': info['shoot'],
            }
            info = {'win': winner, 'final_info': [{'episode': episode_summary}]}

        return self.obs[0], reward[0], done, info

    def close(self):
        """Close the wrapped environment."""
        self.env.close()
        
class GridShooting:
    """Two-player shooting game on a ``SIZE`` x ``SIZE`` grid.

    Each agent is stored as ``[coord0, coord1, cooldown]`` in
    ``self.agents`` under team id 0 or 1. A star sits on one cell and is
    re-drawn whenever an agent stands on it after moving.

    Actions (per agent):
        0      stay
        1 / 3  coord0 +1 / -1 (clamped to the grid)
        2 / 4  coord1 +1 / -1 (clamped to the grid)
        5..8   shoot along the direction of move 1..4 respectively

    A shot hits when the enemy shares the other coordinate and lies at or
    beyond the shooter in the shot direction (sharing a cell counts). A hit
    ends the episode. Shooting sets the shooter's cooldown to ``MAX_CD``;
    a shot requested while the cooldown is positive is treated as "stay".
    """

    SIZE = 9
    MAX_CD = 5
    # Episode is cut off once this many steps have been played.
    MAX_FRAMES = 100

    SHOOTING_REWARD = 10
    MOVE_REWARD = -1
    STAR_REWARD = 1

    # shot direction (action - 5) -> (index of the coordinate that must match,
    # whether the enemy must be at a >= position on the other coordinate)
    _SHOT_GEOMETRY = {0: (1, True), 1: (0, True), 2: (1, False), 3: (0, False)}

    def __init__(self):
        self.agents = {}
        self.reset()
        self.actions = {}

    @staticmethod
    def _to_int(action):
        """Return ``action`` (int, numpy scalar or single-element array) as an int."""
        return int(np.asarray(action).item())

    @staticmethod
    def _winner(final_returns):
        """Return 0 or 1 for the team with the higher return, -1 on a tie."""
        if final_returns[0] == final_returns[1]:
            return -1
        return 0 if final_returns[0] > final_returns[1] else 1

    def reset(self):
        """Randomise agents and star, clear counters, return both observations."""
        last = self.SIZE - 1
        self.agents = {
            0: [random.randint(0, last), random.randint(0, last), 0],
            1: [random.randint(0, last), random.randint(0, last), 0],
        }

        self.star = np.random.randint(0, self.SIZE, 2)
        self.frame_no = 0
        self.episode_returns = [0, 0]

        # Statistics that are only tracked for team 0.
        self.analyze_star = 0
        self.analyze_shoot = 0

        return [self.get_state(0), self.get_state(1)]

    def state_dim(self):
        """Length of the vector returned by ``get_state``."""
        return 6 * self.SIZE + 2

    def action_size(self):
        """Number of discrete actions per agent."""
        return 9

    def _encode_position(self, coord0, coord1):
        """One-hot encode a cell as two concatenated ``SIZE``-long segments."""
        encoded = [0.0] * (self.SIZE * 2)
        encoded[coord0] = 1
        encoded[coord1 + self.SIZE] = 1
        return encoded

    def get_state(self, team_id):
        """Return the observation of ``team_id`` as a 1-D numpy array.

        Layout: own position, enemy position, star position (each one-hot
        per coordinate), then own and enemy cooldown divided by ``MAX_CD``.
        """
        enemy_team_id = team_id ^ 1
        own = self.agents[team_id]
        enemy = self.agents[enemy_team_id]

        state = []
        state.extend(self._encode_position(own[0], own[1]))
        state.extend(self._encode_position(enemy[0], enemy[1]))
        state.extend(self._encode_position(self.star[0], self.star[1]))
        state.append(own[2] / float(self.MAX_CD))
        state.append(enemy[2] / float(self.MAX_CD))
        return np.array(state)

    def legal_actions(self):
        """Return a (2, 9) 0/1 mask; shooting is masked while on cooldown."""
        masks = []
        for team_id in range(2):
            if self.agents[team_id][2] > 0:
                masks.append([1] * 5 + [0] * 4)
            else:
                masks.append([1] * 9)
        return np.array(masks)

    def check_shooting(self, team_id, shooting_action):
        """Return True if ``team_id`` shooting with action 5..8 hits the enemy.

        The argument is never modified: a 0-d numpy array passed in keeps
        its value, so the caller can still tell it was a shooting action.
        """
        enemy_team_id = 1 if team_id == 0 else 0
        direction = self._to_int(shooting_action) - 5

        geometry = self._SHOT_GEOMETRY.get(direction)
        if geometry is None:
            return False
        equal_idx, enemy_is_ahead = geometry
        compare_idx = 1 if equal_idx == 0 else 0

        shooter = self.agents[team_id]
        target = self.agents[enemy_team_id]
        if shooter[equal_idx] != target[equal_idx]:
            return False
        if enemy_is_ahead:
            return shooter[compare_idx] <= target[compare_idx]
        return shooter[compare_idx] >= target[compare_idx]

    def move(self, team_id, action):
        """Apply a movement action and tick the cooldown.

        NOTE(review): action 0 returns before the cooldown decrement, so
        staying still does not reduce the cooldown. Kept as-is because
        changing it alters the game dynamics; confirm whether intended.
        """
        if action == 0:
            return
        agent = self.agents[team_id]
        if action == 1:
            agent[0] = min(agent[0] + 1, self.SIZE - 1)
        elif action == 2:
            agent[1] = min(agent[1] + 1, self.SIZE - 1)
        elif action == 3:
            agent[0] = max(agent[0] - 1, 0)
        elif action == 4:
            agent[1] = max(agent[1] - 1, 0)
        agent[2] = max(agent[2] - 1, 0)

    def _episode_info(self):
        """Summary dict attached to the final step of an episode."""
        return {
            'episode_returns': self.episode_returns,
            'win': self._winner(self.episode_returns),
            'star': self.analyze_star,
            'shoot': self.analyze_shoot,
        }

    def step(self, team_0_action_in, team_1_action_in):
        """Advance one frame with both agents' actions.

        Returns:
            ``([obs0, obs1], [reward0, reward1], done, info)``. ``info`` is
            empty unless the episode just ended.
        """
        # Work on plain ints. Actions may arrive as 0-d numpy arrays, which
        # are mutable; with those, the previous in-place "-= 5" in the hit
        # test turned a missed shot into an extra move in the same frame.
        actions = [self._to_int(team_0_action_in), self._to_int(team_1_action_in)]

        self.frame_no += 1
        epi_done = False
        rewards = [self.MOVE_REWARD, self.MOVE_REWARD]
        self.episode_returns[0] += self.MOVE_REWARD
        self.episode_returns[1] += self.MOVE_REWARD

        # A shot requested while on cooldown is downgraded to "stay".
        for team_id in range(2):
            if actions[team_id] >= 5 and self.agents[team_id][2] > 0:
                actions[team_id] = 0

        # Resolve shots before any movement; both agents may hit each other.
        for team_id in range(2):
            if actions[team_id] < 5:
                continue
            if self.check_shooting(team_id, actions[team_id]):
                enemy_team_id = team_id ^ 1
                epi_done = True
                rewards[team_id] += self.SHOOTING_REWARD
                rewards[enemy_team_id] -= self.SHOOTING_REWARD
                self.episode_returns[team_id] += self.SHOOTING_REWARD
                self.episode_returns[enemy_team_id] -= self.SHOOTING_REWARD
                if team_id == 0:
                    self.analyze_shoot += 1
            self.agents[team_id][2] = self.MAX_CD

        if epi_done:
            return [self.get_state(0), self.get_state(1)], rewards, epi_done, self._episode_info()

        for team_id in range(2):
            if actions[team_id] < 5:
                self.move(team_id, actions[team_id])

        # Star pickup is checked after both agents have moved; both can score.
        recreate_star = False
        for team_id in range(2):
            agent = self.agents[team_id]
            if self.star[0] == agent[0] and self.star[1] == agent[1]:
                recreate_star = True
                rewards[team_id] += self.STAR_REWARD
                self.episode_returns[team_id] += self.STAR_REWARD
                if team_id == 0:
                    self.analyze_star += 1

        if recreate_star:
            self.star = np.random.randint(0, self.SIZE, 2)

        if self.frame_no >= self.MAX_FRAMES:
            epi_done = True

        info = self._episode_info() if epi_done else {}
        return [self.get_state(0), self.get_state(1)], rewards, epi_done, info

    def close(self):
        """Nothing to release."""
        return

    def render(self):
        """Return the board as a list of strings ('.', '0', '1', 's').

        The star is drawn last, so it hides an agent on the same cell.
        """
        grid = [['.'] * self.SIZE for _ in range(self.SIZE)]
        grid[self.agents[0][0]][self.agents[0][1]] = '0'
        grid[self.agents[1][0]][self.agents[1][1]] = '1'
        grid[self.star[0]][self.star[1]] = 's'
        return [''.join(row) for row in grid]
    
class QNetwork(nn.Module):
    """Three-layer MLP mapping a flat observation to one Q-value per action.

    Layer sizes are read from ``env.single_observation_space.shape`` and
    ``env.single_action_space.n``; the hidden widths are 120 and 84.
    """

    def __init__(self, env):
        super().__init__()
        input_dim = np.array(env.single_observation_space.shape).prod()
        output_dim = env.single_action_space.n
        # Layers are created in this exact order so parameter names and
        # initialisation match previously saved state dicts.
        layers = [
            nn.Linear(input_dim, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, output_dim),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Return the Q-values for observation tensor ``x``."""
        q_values = self.network(x)
        return q_values