
from dataclasses import dataclass
import torch
import pickle

from TSProblemDef import get_random_problems, augment_xy_data_by_8_fold


@dataclass()
class Reset_State:
    """Initial observation handed back by ``TSPEnv.reset``."""

    problems: torch.Tensor  # node coordinates, shape: (batch, problem, 2)
    tunnels: torch.Tensor  # (source, target) node-index pairs, shape: (batch, tunnel, 2)


@dataclass()
class Step_State:
    """Per-step observation shared by ``TSPEnv.pre_step`` and ``TSPEnv.step``."""

    # Batch / rollout index helpers, both shape: (batch, pomo).
    BATCH_IDX: torch.Tensor
    POMO_IDX: torch.Tensor
    # Set by TSPEnv.step to the tunnel-resolved current node, shape: (batch, pomo).
    current_node: torch.Tensor = None
    # 0 = selectable, -inf = masked; shape: (batch, pomo, problem + 1) as built
    # by TSPEnv.reset, where the last column is the regret action.
    ninf_mask: torch.Tensor = None


class TSPEnv:
    """TSP environment with fixed "tunnel" edges and a regret action.

    Every instance is rolled out ``pomo_size`` times in parallel.  Nodes are
    indexed ``0..problem_size-1``; index ``problem_size`` is an extra "regret"
    action that rolls the previous move back (see ``step``).  Tunnel ``j``
    joins ``sources[:, j]`` and ``targets[:, j]``: selecting either endpoint
    also visits the other one, and the tunnel lengths are subtracted from the
    tour length (see ``_get_fixed_length``).
    """

    def __init__(self, device, env_params):
        """Store the environment constants.

        Args:
            device: device that ``__connect_source_target_city`` moves its
                lookup tensors to.
            env_params: dict providing ``'problem_size'``, ``'pomo_size'`` and
                ``'tunnel_size'``.
        """
        # Const @INIT
        ####################################
        self.env_params = env_params
        self.problem_size = env_params['problem_size']
        self.pomo_size = env_params['pomo_size']
        self.tunnel_size = env_params['tunnel_size']
        self.device = device

        # Const @Load_Problem
        ####################################
        self.batch_size = None
        self.BATCH_IDX = None
        self.POMO_IDX = None
        # IDX.shape: (batch, pomo)
        self.problems = None
        # shape: (batch, problem, 2)
        self.sources = None
        self.targets = None
        # sources/targets shape: (batch, tunnel), node indices
        self.tunnels = None
        # shape: (batch, tunnel, 2) -- (source, target) pairs
        self.fixed_edge_length = None
        # shape: (batch, pomo)

        # Dynamic
        ####################################
        self.selected_count = None
        self.current_node = None
        self.last_current_node = None
        # Tunnel-resolved counterparts of current_node / last_current_node.
        self.cross_current_node = None
        self.cross_last_current_node = None
        self.mode = None
        # shape: (batch, pomo)
        self.selected_node_list = None
        self.T_selected_node_list = None
        # shape: (batch, pomo, 0~problem)
        self.last1_crossbool_index = None
        self.last2_crossbool_index = None
        self.step_state = None

        # Masks cached by _get_travel_distance.
        self.regret_mask_matrix = None
        self.add_mask_matrix = None

        # Optional pre-loaded instances consumed sequentially by load_problems.
        self.use_load_data = False
        self.data = None
        self.offset = 0

    def load_problems(self, batch_size, aug_factor=1):
        """Load the next batch of instances and build the index helpers.

        Args:
            batch_size: number of instances before augmentation.
            aug_factor: 1 (no augmentation) or 8 (8-fold augmentation).

        Raises:
            NotImplementedError: for any other ``aug_factor`` greater than 1.
        """
        self.batch_size = batch_size

        if self.use_load_data:
            # NOTE(review): this path does not refresh self.sources /
            # self.targets; the caller must assign them for this batch,
            # otherwise tunnels from an earlier batch (or None) are used.
            self.problems = self.data[self.offset:self.offset + batch_size]
            self.offset += batch_size
        else:
            self.problems, self.sources, self.targets = get_random_problems(
                batch_size, self.problem_size, self.tunnel_size)
        # problems.shape: (batch, problem, 2)

        if aug_factor > 1:
            if aug_factor == 8:
                self.batch_size = self.batch_size * 8
                self.problems = augment_xy_data_by_8_fold(self.problems)
                # shape: (8*batch, problem, 2)
                # The tunnel endpoints are node indices, which the coordinate
                # augmentation presumably leaves untouched; tile them so the
                # tunnel tensors match the augmented batch size (previously
                # they stayed at `batch` rows and every later lookup failed).
                # Assumes augment_xy_data_by_8_fold concatenates the 8
                # variants of the whole batch along dim 0 -- TODO confirm.
                self.sources = self.sources.repeat(8, 1)
                self.targets = self.targets.repeat(8, 1)
            else:
                raise NotImplementedError

        # shape: (batch, tunnel, 2)
        self.tunnels = torch.cat([self.sources.unsqueeze(-1), self.targets.unsqueeze(-1)], dim=2)

        self.BATCH_IDX = torch.arange(self.batch_size)[:, None].expand(self.batch_size, self.pomo_size)
        self.POMO_IDX = torch.arange(self.pomo_size)[None, :].expand(self.batch_size, self.pomo_size)

    def reset(self):
        """Start a new episode on the currently loaded batch.

        Returns:
            ``(Reset_State(problems, tunnels), None, False)``.
        """
        self.selected_count = torch.zeros((self.batch_size, self.pomo_size), dtype=torch.long)
        # Every rollout starts in mode 1 (see ``step``): regret stays
        # unavailable for the first two moves.
        self.mode = torch.full((self.batch_size, self.pomo_size), 1, dtype=torch.long)
        self.current_node = None
        self.last_current_node = None
        # shape: (batch, pomo)
        self.selected_node_list = torch.zeros((self.batch_size, self.pomo_size, 0), dtype=torch.long)
        self.T_selected_node_list = torch.zeros((self.batch_size, self.pomo_size, 0), dtype=torch.long)
        # shape: (batch, pomo, 0~problem)
        self.last1_crossbool_index = torch.zeros((self.batch_size, self.pomo_size), dtype=torch.bool)
        self.last2_crossbool_index = torch.zeros((self.batch_size, self.pomo_size), dtype=torch.bool)

        # CREATE STEP STATE
        self.step_state = Step_State(BATCH_IDX=self.BATCH_IDX, POMO_IDX=self.POMO_IDX)
        self.step_state.ninf_mask = torch.zeros((self.batch_size, self.pomo_size, self.problem_size + 1))
        # Column `problem_size` is the regret action; it starts masked.
        self.step_state.ninf_mask[:, :, self.problem_size] = float('-inf')
        # shape: (batch, pomo, problem + 1)
        self.fixed_edge_length = self._get_fixed_length()
        reward = None
        done = False
        return Reset_State(self.problems, self.tunnels), reward, done

    def pre_step(self):
        """Return the current step state with no reward and ``done=False``."""
        reward = None
        done = False
        return self.step_state, reward, done

    def step(self, selected):
        """Apply one action per rollout and update masks and bookkeeping.

        How ``selected`` is interpreted depends on each rollout's ``mode``:

        * mode 0, node index (action0): ordinary move; mode becomes 0.
        * mode 0, ``problem_size`` (action1, regret): the current node is
          rolled back to the previous one, the visit count drops by one and
          the mode becomes 1.
        * mode 1 (action2): move with regret still masked; outside the first
          step, the node undone by the last regret is unmasked again; mode
          becomes 2.
        * mode 2 (action3): move, then unmask regret; mode becomes 0.
        * mode 3 (done): none of the above applies.  On the step a rollout
          finishes, its current node is unmasked and regret is masked.

        Selecting a tunnel endpoint (action0/2/3) also masks the other
        endpoint.

        Args:
            selected: (batch, pomo) long tensor of node indices or the regret
                index ``problem_size``.

        Returns:
            ``(step_state, reward, all_done)``.  ``reward`` is the negated
            ``_get_travel_distance()`` (batch, pomo) once every rollout is
            done, else None.  ``all_done`` is a 0-dim bool tensor.
        """
        # selected.shape: (batch, pomo)

        action0_bool_index = ((self.mode == 0) & (selected != self.problem_size))  # no regret
        action1_bool_index = ((self.mode == 0) & (selected == self.problem_size))  # regret
        action2_bool_index = self.mode == 1
        action3_bool_index = self.mode == 2

        action0_index = torch.nonzero(action0_bool_index)
        action1_index = torch.nonzero(action1_bool_index)
        action2_index = torch.nonzero(action2_bool_index)
        action3_index = torch.nonzero(action3_bool_index)

        # Tunnel-resolved selection: the node a rollout ends up at after `selected`.
        T_selected = self.__connect_source_target_city(selected)
        crosstunnel_bool_index = (T_selected != selected)
        crosstunnel0_index = torch.nonzero(crosstunnel_bool_index & action0_bool_index)
        crosstunnel2_index = torch.nonzero(crosstunnel_bool_index & action2_bool_index)
        crosstunnel3_index = torch.nonzero(crosstunnel_bool_index & action3_bool_index)

        first_step = self.current_node is None
        second_step = (self.current_node is not None) and (self.last_current_node is None)

        # 1 change self.current_node and self.last_current_node
        if first_step:
            self.cross_current_node = T_selected
            self.current_node = selected
        elif second_step:
            self.cross_last_current_node = self.cross_current_node.clone()
            self.cross_current_node = T_selected
            self.last_current_node = self.current_node.clone()
            self.current_node = selected
        else:
            # Regret rollouts return to the node they stood on before.
            regret_restore_node = self.last_current_node[action1_index[:, 0], action1_index[:, 1]].clone()
            # Mode 1 here only follows a regret: this is the undone node.
            temp_last_current_node_action2 = self.last_current_node[action2_index[:, 0], action2_index[:, 1]].clone()
            # NOTE(review): crosstunnel2_index is keyed on whether *this*
            # step's `selected` crosses a tunnel, but the gathered value is the
            # tunnel partner of the undone node. If the undone node was a tunnel
            # endpoint and `selected` is not, that partner is never unmasked
            # below. Confirm this is intended.
            cross_temp_last_current_node_action2 = self.cross_last_current_node[crosstunnel2_index[:, 0], crosstunnel2_index[:, 1]].clone()
            self.cross_last_current_node = self.cross_current_node.clone()
            self.last_current_node = self.current_node.clone()
            self.current_node = selected.clone()
            self.current_node[action1_index[:, 0], action1_index[:, 1]] = regret_restore_node.clone()
            self.cross_current_node = self.__connect_source_target_city(self.current_node)
            current_crossbool_index = (self.cross_current_node != self.current_node)
            crosstunnel2_index = torch.nonzero(current_crossbool_index & action2_bool_index)

        # 2 change self.step_state.ninf_mask
        # action0: mask the chosen node and, for tunnels, its partner
        self.step_state.ninf_mask[action0_index[:, 0], action0_index[:, 1], self.current_node[action0_index[:, 0], action0_index[:, 1]]] = float('-inf')
        self.step_state.ninf_mask[crosstunnel0_index[:, 0], crosstunnel0_index[:, 1], T_selected[crosstunnel0_index[:, 0], crosstunnel0_index[:, 1]]] = float('-inf')
        # action1: mask the regret column
        self.step_state.ninf_mask[action1_index[:, 0], action1_index[:, 1], selected[action1_index[:, 0], action1_index[:, 1]]] = float('-inf')
        # action2: unmask the undone node (and partner), then mask the new node
        if not (first_step or second_step):
            self.step_state.ninf_mask[action2_index[:, 0], action2_index[:, 1], temp_last_current_node_action2] = float(0)
            self.step_state.ninf_mask[crosstunnel2_index[:, 0], crosstunnel2_index[:, 1], cross_temp_last_current_node_action2] = float(0)
        self.step_state.ninf_mask[action2_index[:, 0], action2_index[:, 1], self.current_node[action2_index[:, 0], action2_index[:, 1]]] = float('-inf')
        self.step_state.ninf_mask[crosstunnel2_index[:, 0], crosstunnel2_index[:, 1], T_selected[crosstunnel2_index[:, 0], crosstunnel2_index[:, 1]]] = float('-inf')
        # action3: mask the new node (and partner), then re-enable regret
        self.step_state.ninf_mask[action3_index[:, 0], action3_index[:, 1], self.current_node[action3_index[:, 0], action3_index[:, 1]]] = float('-inf')
        self.step_state.ninf_mask[crosstunnel3_index[:, 0], crosstunnel3_index[:, 1], T_selected[crosstunnel3_index[:, 0], crosstunnel3_index[:, 1]]] = float('-inf')
        self.step_state.ninf_mask[action3_index[:, 0], action3_index[:, 1], self.problem_size] = float(0)

        # 3 record the raw and the tunnel-resolved selections
        self.selected_node_list = torch.cat((self.selected_node_list, selected[:, :, None]), dim=2)
        self.T_selected_node_list = torch.cat((self.T_selected_node_list, T_selected[:, :, None]), dim=2)

        # 4 change self.selected_count
        moved = action0_bool_index | action2_bool_index | action3_bool_index
        self.selected_count[moved] = self.selected_count[moved] + 1
        self.selected_count[action1_bool_index] = self.selected_count[action1_bool_index] - 1

        ####################
        # done: shape: (batch, pomo)
        # Each tunnel covers two nodes with one selection.
        done = (self.selected_count == self.problem_size - self.tunnel_size)

        # 5 change self.mode
        self.mode[action1_bool_index] = 1
        self.mode[action2_bool_index] = 2
        self.mode[action3_bool_index] = 0
        self.mode[done] = 3

        done_index = torch.nonzero(done)

        self.step_state.ninf_mask[done_index[:, 0], done_index[:, 1], self.current_node[done_index[:, 0], done_index[:, 1]]] = float(0)
        self.step_state.ninf_mask[done_index[:, 0], done_index[:, 1], self.problem_size] = float("-inf")

        # The model is shown the tunnel-resolved position.
        self.step_state.current_node = self.cross_current_node

        # returning values
        if done.all():
            reward = -self._get_travel_distance()  # note the minus sign!
        else:
            reward = None

        return self.step_state, reward, done.all()

    def _get_travel_distance(self):
        """Return the tour length of every rollout, shape (batch, pomo).

        Works on ``selected_node_list`` (raw actions incl. regret tokens) and
        ``T_selected_node_list`` (tunnel-resolved nodes).  Position t adds
        |T_t - s_t| + |s_t - T_prev|, where T_prev is the tunnel-resolved node
        the rollout arrived from (``roll`` wraps around, closing the tour).
        Regretted positions are skipped, repeated positions are offset, and
        finally ``fixed_edge_length`` (the tunnel lengths) is subtracted.
        Also caches ``regret_mask_matrix`` / ``add_mask_matrix``.
        """
        # m1: positions holding the regret token.
        m1 = (self.selected_node_list == self.problem_size)
        # m2: regret positions plus the position right before each regret
        #     (the move the regret undid).
        m2 = (m1.roll(dims=2, shifts=-1) | m1)
        # m3: the position right after a regret; its predecessor on the tour
        #     is three positions back (before the undone move and the regret).
        m3 = m1.roll(dims=2, shifts=1)
        # m4: ordinary positions; predecessor is one position back.
        m4 = ~(m2 | m3)
        # m5: positions repeating the previous action; 2*|s_t - T_t| is
        #     subtracted for them below.
        m5 = (self.selected_node_list == self.selected_node_list.roll(dims=2, shifts=1))

        selected_node_list_right = self.T_selected_node_list.roll(dims=2, shifts=1)
        selected_node_list_right2 = self.T_selected_node_list.roll(dims=2, shifts=3)
        self.regret_mask_matrix = m1
        self.add_mask_matrix = (~m2)

        travel_distances = torch.zeros((self.batch_size, self.pomo_size))

        def _dist(b, i, j):
            # Euclidean distance between nodes i and j of instances b.
            return ((self.problems[b, i, :] - self.problems[b, j, :]) ** 2).sum(1).sqrt()

        for t in range(self.selected_node_list.shape[2]):
            add1_index = (m4[:, :, t].unsqueeze(2)).nonzero()
            add3_index = (m3[:, :, t].unsqueeze(2)).nonzero()
            noadd_index = (m5[:, :, t].unsqueeze(2)).nonzero()

            b, p = add1_index[:, 0], add1_index[:, 1]
            travel_distances[b, p] = (
                travel_distances[b, p].clone()
                + _dist(b, self.T_selected_node_list[b, p, t], self.selected_node_list[b, p, t])
                + _dist(b, self.selected_node_list[b, p, t], selected_node_list_right[b, p, t])
            )

            b, p = add3_index[:, 0], add3_index[:, 1]
            travel_distances[b, p] = (
                travel_distances[b, p].clone()
                + _dist(b, self.T_selected_node_list[b, p, t], self.selected_node_list[b, p, t])
                + _dist(b, self.selected_node_list[b, p, t], selected_node_list_right2[b, p, t])
            )

            b, p = noadd_index[:, 0], noadd_index[:, 1]
            travel_distances[b, p] = (
                travel_distances[b, p].clone()
                - 2 * _dist(b, self.selected_node_list[b, p, t], self.T_selected_node_list[b, p, t])
            )

        travel_distances = travel_distances - self.fixed_edge_length
        return travel_distances

    def __connect_source_target_city(self, selected_idx_mat):
        """Map every selected node to its tunnel partner, if it has one.

        A node equal to ``sources[:, j]`` maps to ``targets[:, j]`` and vice
        versa; any other index (e.g. the regret index) is returned unchanged.
        The per-tunnel offsets are summed, so this assumes each node is an
        endpoint of at most one tunnel.

        Args:
            selected_idx_mat: (batch, pomo) long tensor of node indices.

        Returns:
            (batch, pomo) tensor of tunnel-resolved node indices.
        """
        bsz, len_st = self.sources.shape
        _, m = selected_idx_mat.shape
        Gsource = self.sources.unsqueeze(1).expand(bsz, m, len_st).long().to(self.device)
        Gtarget = self.targets.unsqueeze(1).expand(bsz, m, len_st).long().to(self.device)
        Ginput = selected_idx_mat.unsqueeze(2).expand(bsz, m, len_st).to(self.device)
        source_match = (Gsource == Ginput).nonzero(as_tuple=True)
        target_match = (Gtarget == Ginput).nonzero(as_tuple=True)
        EXSinput = Ginput.clone()
        EXGinput = Ginput.clone()
        EXSinput[source_match] = Gtarget[source_match]
        EXGinput[target_match] = Gsource[target_match]
        # Per tunnel, the offset is (partner - input) on a match, else 0.
        Minus = torch.sum(EXGinput + EXSinput - 2 * Ginput, dim=2)
        output = selected_idx_mat + Minus
        return output

    def _get_fixed_length(self):
        """Return the summed tunnel lengths per instance, expanded to (batch, pomo).

        ``_get_travel_distance`` subtracts this from the tour length.
        """
        arange_vec = torch.arange(self.batch_size)
        fixed_length = torch.zeros(self.batch_size)
        for j in range(self.tunnel_size):
            source_xy = self.problems[arange_vec, self.sources[:, j].long(), :]
            target_xy = self.problems[arange_vec, self.targets[:, j].long(), :]
            fixed_length += torch.sum((target_xy - source_xy) ** 2, dim=1) ** 0.5
        return fixed_length.unsqueeze(-1).expand(self.batch_size, self.pomo_size)
