import torch
import torch.nn as nn
import torch.nn.functional as F

# Compute device for every tensor allocated in this module: CUDA when available, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Default floating-point dtype (torch.float32, of which torch.float is an alias).
dtype = torch.float32


def param_init(model):
    """Re-initialise the parameters of every supported sub-module of ``model`` in place.

    - ``nn.Conv2d`` / ``nn.ConvTranspose2d``: weight ~ N(0, 0.01); any bias is left untouched.
    - ``nn.Linear``: weight ~ N(0, 0.01); bias (if present) set to 0.
    - ``nn.BatchNorm1d`` / ``nn.BatchNorm2d``: affine weight set to 1 and bias to 0.
      Layers built with ``affine=False`` have no such parameters and are skipped.

    Args:
        model: any ``nn.Module``. Every module yielded by ``model.modules()``
            (including ``model`` itself) is visited, in registration order.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, 0, 0.01)
        elif isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, 0, 0.01)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
            # With affine=False, weight/bias are None; touching them would crash.
            if m.weight is not None:
                nn.init.constant_(m.weight, 1)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)


class conv_layer_module(nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU building block.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels (also the BatchNorm2d width).
        k: convolution kernel size.
        s: convolution stride.
        p: convolution zero padding.
        bias: whether the convolution has its own bias term (default ``False``;
            the BatchNorm2d that follows has its own affine shift).
    """

    def __init__(self, in_ch, out_ch, k, s, p, bias=False):
        super().__init__()
        # Attribute names (conv / bat / act) are part of the state_dict keys.
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=k, stride=s,
                              padding=p, bias=bias)
        self.bat = nn.BatchNorm2d(out_ch)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply convolution, batch normalisation and ReLU, in that order."""
        return self.act(self.bat(self.conv(x)))


class conv_transposed_layer_module(nn.Module):
    """ConvTranspose2d -> BatchNorm2d -> ReLU building block.

    Args:
        in_ch: number of input channels.
        out_ch: number of output channels (also the BatchNorm2d width).
        k: kernel size.
        s: stride.
        p: padding.
        op: ``output_padding`` of the transposed convolution.
        bias: whether the transposed convolution has its own bias term.
    """

    def __init__(self, in_ch, out_ch, k, s, p, op, bias=False):
        super().__init__()
        # Attribute names (conv_t / bat / act) are part of the state_dict keys.
        self.conv_t = nn.ConvTranspose2d(
            in_ch, out_ch, kernel_size=k, stride=s, padding=p,
            output_padding=op, bias=bias,
        )
        self.bat = nn.BatchNorm2d(out_ch)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply transposed convolution, batch normalisation and ReLU, in that order."""
        return self.act(self.bat(self.conv_t(x)))


class fc_layer_module(nn.Module):
    """Linear -> BatchNorm1d -> ReLU building block.

    Args:
        in_dim: input feature size.
        out_dim: output feature size (also the BatchNorm1d width).
        bias: whether the linear layer has its own bias term (default ``False``).
    """

    def __init__(self, in_dim, out_dim, bias=False):
        super().__init__()
        # Attribute names (fc / bat / act) are part of the state_dict keys.
        self.fc = nn.Linear(in_dim, out_dim, bias=bias)
        self.bat = nn.BatchNorm1d(out_dim)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply the linear map, batch normalisation and ReLU, in that order."""
        return self.act(self.bat(self.fc(x)))



class FeatNet(nn.Module):
    """Feature encoder: 2-channel patch -> L2-normalised 128-d embedding."""

    def __init__(self):
        super().__init__()
        # Four stride-2 conv blocks. Flattened output must be 256-d (64 x 2 x 2),
        # e.g. for 27x27 inputs: 27 -> 14 -> 7 -> 4 -> 2.
        widths = (2, 9, 16, 32, 64)
        encoder_blocks = [conv_layer_module(c_in, c_out, 3, 2, 1)
                          for c_in, c_out in zip(widths, widths[1:])]
        self.enc = nn.Sequential(*encoder_blocks, nn.Flatten(1))
        self.proj = nn.Sequential(fc_layer_module(256, 256), nn.Linear(256, 128))

        param_init(self)

    def forward(self, o):
        """Return the unit-norm (along dim 1) embedding of observation batch ``o``."""
        features = self.enc(o)
        embedding = self.proj(features)
        return F.normalize(embedding, dim=1)



class CoordNet(nn.Module):
    """Coordinate encoder: 2-channel patch -> (unit-norm 128-d embedding, 2-d coordinate).

    The coordinate comes from a Linear + Tanh head applied to the embedding, so
    each of its components lies in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        # Four stride-2 conv blocks. Flattened output must be 256-d (64 x 2 x 2),
        # e.g. for 27x27 inputs: 27 -> 14 -> 7 -> 4 -> 2.
        widths = (2, 9, 16, 32, 64)
        conv_blocks = [conv_layer_module(c_in, c_out, 3, 2, 1)
                       for c_in, c_out in zip(widths, widths[1:])]
        self.conv = nn.Sequential(*conv_blocks, nn.Flatten(1))
        self.proj = nn.Sequential(fc_layer_module(256, 128))
        self.head = nn.Sequential(nn.Linear(128, 2), nn.Tanh())

        param_init(self)

    def forward(self, o):
        """Return ``(z, c)``: the unit-norm embedding and the coordinate predicted from it."""
        embedding = self.proj(self.conv(o))
        z = F.normalize(embedding, dim=1)
        return z, self.head(z)



class RelCoordnet(nn.Module):
    """Relative-coordinate encoder.

    Takes a 2-channel observation plus a 2-channel base-coordinate map (4 input
    channels in total). Returns a unit-norm 128-d embedding and a 2-d relative
    coordinate.
    """

    def __init__(self):
        super().__init__()
        # Four stride-2 conv blocks. Flattened output must be 256-d (64 x 2 x 2),
        # e.g. for 27x27 inputs: 27 -> 14 -> 7 -> 4 -> 2.
        widths = (4, 9, 16, 32, 64)
        conv_blocks = [conv_layer_module(c_in, c_out, 3, 2, 1)
                       for c_in, c_out in zip(widths, widths[1:])]
        self.conv = nn.Sequential(*conv_blocks, nn.Flatten(1))
        self.proj = nn.Sequential(fc_layer_module(256, 128))
        self.head = nn.Sequential(nn.Linear(128, 2), nn.Tanh())

        param_init(self)

    def forward(self, o, base_c):
        """Return ``(z, relative_c)`` for observation ``o`` and base-coordinate map ``base_c``."""
        stacked = torch.cat((o, base_c), dim=1)
        z = F.normalize(self.proj(self.conv(stacked)), dim=1)
        # NOTE(review): self.head already ends in Tanh, so tanh is applied twice.
        # The output is therefore bounded by +-2*tanh(1) (about +-1.52), not +-2.
        # Kept unchanged because loaded checkpoints may depend on it; confirm before changing.
        relative_c = 2 * torch.tanh(self.head(z))
        return z, relative_c



class dirPolNet(nn.Module):
    """Direction policy scoring 8 movement directions.

    The network input is the dim-1 concatenation of the current feature and
    coordinate embeddings, their preferred (prior) embeddings, and the two
    mixing weights: 128 + 128 + 128 + 128 + 1 + 1 = 514 features. The output is
    un-normalised logits of shape ``[batch, 8]``.
    """

    def __init__(self):
        super().__init__()
        self.habit_net = nn.Sequential(
            fc_layer_module(128 * 2 + 128 * 2 + 1 + 1, 512),
            fc_layer_module(512, 256),
            nn.Linear(256, 8))

        param_init(self)

    def forward(self, z_f, z_c, preference_f, preference_c, lambda_ft, lambda_cd):
        """Return direction logits ``[batch, 8]``.

        Args:
            z_f: current feature embedding, ``[batch, 128]``.
            z_c: current coordinate embedding, ``[batch, 128]``.
            preference_f: preferred feature embedding, ``[batch, 128]``.
            preference_c: preferred coordinate embedding, ``[batch, 128]``.
            lambda_ft: feature weight, ``[batch, 1]``.
            lambda_cd: coordinate weight, ``[batch, 1]``.
        """
        # The second 128-d slot holds z_c. Previously z_f was fed twice and z_c ignored.
        # NOTE(review): weights trained with the old input layout should be retrained.
        x = torch.cat((z_f, z_c, preference_f, preference_c, lambda_ft, lambda_cd), dim=1)
        return self.habit_net(x)



class Agent(nn.Module):
    """Landmark-search agent using a feature prior (FeatNet) and a coordinate prior (CoordNet).

    State is kept per (landmark, batch element) pair. Each call to
    :meth:`progress_detection` handles the pairs that are still searching:
    1. Embed their current observations.
    2. Score the embeddings against the landmarks' preferred embeddings.
    3. Keep the best (distance, inferred coordinate) seen so far for every
       mixing weight ``lambda``.
    4. Mark a landmark as detected once that best distance falls under the
       current threshold.
    5. Return the next move for the environment.

    Expected call order: ``set_hyperparams`` -> ``set_prior`` (which calls
    ``init_tensors``) -> ``progress_detection`` repeatedly.
    """

    def __init__(self, batch_size, maximum_stage):
        super(Agent, self).__init__()
        self.img_size = [255, 255]
        self.view_size = [27, 27]

        self.eps = 0.5
        self.max_stepsize = 30  # largest step size produced by stepsize_calc

        self.batch_size = batch_size
        self.maximum_stage = maximum_stage

        # AIF prior networks; their weights are loaded in set_prior.
        self.coordnet = CoordNet()
        self.featnet = FeatNet()

    def set_hyperparams(self, thr_init, thr_control_start, thr_freq, thr_increase,
                        lambda_control_start, lambda_freq, lambda_ft):
        """Store the detection-threshold and lambda schedules.

        Args:
            thr_init: initial detection threshold, indexed per landmark
                (expected tensor ``[n_landmarks]`` on ``device``).
            thr_control_start: first timestep at which the threshold may grow.
            thr_freq: the threshold grows every ``thr_freq`` timesteps from then on.
            thr_increase: amount added at each threshold update, indexed like ``thr_init``.
            lambda_control_start: first timestep at which ``lambda_idx`` may advance.
            lambda_freq: ``lambda_idx`` advances every ``lambda_freq`` timesteps from then on.
            lambda_ft: feature-distance weights ``[n_landmarks, n_lambdas]``.
                The coordinate weights are ``1 - lambda_ft``.
        """
        self.thr_init = thr_init
        self.thr_control_start = thr_control_start
        self.thr_freq = thr_freq
        self.thr_increase = thr_increase
        self.lambda_control_start = lambda_control_start
        self.lambda_freq = lambda_freq
        self.lambda_ft = lambda_ft
        self.lambda_cd = 1 - self.lambda_ft

    def init_tensors(self):
        """(Re)allocate the per-episode search state for ``n_landmarks`` x ``batch_size`` pairs."""
        n_l, b = self.n_landmarks, self.batch_size
        # Number of steps each pair has spent searching.
        self.detection_time = torch.zeros(n_l, b, device=device)
        # 1 once the landmark has been detected for that batch element, else 0.
        self.landmark_detection_mask = torch.zeros(n_l, b, device=device)
        # True while the pair is still searching.
        self.detection_incomplete = torch.ones(n_l, b, dtype=torch.bool, device=device)
        # Inferred landmark coordinate, written when the landmark is detected.
        self.landmark_coords = torch.zeros(n_l, b, 2, device=device)
        self.n_lambdas = self.lambda_ft.size(1)
        # Highest lambda index currently allowed per landmark; higher ones are ignored.
        self.lambda_idx = torch.zeros(n_l, dtype=torch.long, device=device)
        # Best (distance, coord_0, coord_1) seen so far per lambda, initialised to distance 5.
        # (Squared distances between unit vectors are at most 4, assuming unit-norm priors.)
        self.SHT = torch.full((n_l, b, self.n_lambdas, 3), 5.0, device=device)
        # Copy: detection_check grows this in place. Without the copy the growth
        # would leak back into thr_init, and later episodes would start from it.
        self.detect_threshold = torch.as_tensor(self.thr_init).clone()

    def set_prior(self, state, stage):
        """Select the preferred embeddings for ``stage`` and reset the search state.

        On ``stage == 1`` the network weights and all preferred embeddings are
        loaded from ``state``. Stages up to ``maximum_stage`` use the first
        preferred embedding only. Later stages use all remaining ones, one per
        landmark.
        """
        if stage == 1:
            self.featnet.load_state_dict(state['f_net'])
            self.prior_z_ft_all = state['preferred_z_f']
            self.coordnet.load_state_dict(state['c_net'])
            self.prior_z_cd_all = state['preferred_z_c']

        if stage <= self.maximum_stage:
            self.prior_z_ft = self.prior_z_ft_all[0].view(1, 128)
            self.prior_z_cd = self.prior_z_cd_all[0].view(1, 128)
        else:
            self.prior_z_ft = self.prior_z_ft_all[1:]
            self.prior_z_cd = self.prior_z_cd_all[1:]

        self.stage = stage
        self.n_landmarks = self.prior_z_ft.size(0)
        self.init_tensors()

    def progress_detection(self, o0, dirpolnet, end_step=False):
        """Run one search step for every pair whose landmark is not yet detected.

        Args:
            o0: observations indexed ``[n_landmarks, batch_size, ...]``. Only the
                entries of pairs still searching are used.
            dirpolnet: direction policy called on the embeddings, the priors and
                the current lambda weights. Its argmax over dim 1 must be a
                direction index in 0..7.
            end_step: final step. Detection uses a fixed threshold of 1.0 and no
                move is produced.

        Returns:
            Long tensor ``[n_landmarks, batch_size, 2]`` of moves (zero for
            finished pairs), or ``None`` on ``end_step`` or once every pair has
            finished.
        """
        self.detection_time += self.detection_incomplete
        self.n_samples = self.detection_incomplete.sum()
        o0 = o0[self.detection_incomplete]

        # Landmark index of every searching pair (row-major over [n_l, batch]).
        indices = self.detection_incomplete.nonzero(as_tuple=False)[:, 0]

        (D_ft, D_cd, inferred_coord,
         embedding_ft, embedding_cd, prior_ft, prior_cd) = self.calc_D(o0, indices)

        D0 = self.detection_check(D_ft, D_cd, inferred_coord, indices, end_step)

        self.stepsize_calc(D0)

        if end_step:
            return None

        rows = torch.arange(self.n_samples, device=device)
        lambda_ft = self.lambda_ft[indices][rows, self.lambda_idx[indices]].view(-1, 1)
        lambda_cd = 1 - lambda_ft

        direction_prob = dirpolnet(embedding_ft, embedding_cd,
                                   prior_ft, prior_cd,
                                   lambda_ft, lambda_cd)
        direction_prob = F.softmax(direction_prob, dim=1)
        direction = torch.argmax(direction_prob, dim=1)

        act_to_env = self.act_to_env_calc(direction)

        self.detection_incomplete = self.landmark_detection_mask == 0
        if self.detection_incomplete.sum() == 0:
            return None

        return act_to_env

    def calc_D(self, o, indices):
        """Squared distances of the current embeddings to the selected landmarks' priors.

        Channels ``0:2`` of ``o`` feed FeatNet and channels ``1:`` feed CoordNet.

        Returns:
            ``(D_ft, D_cd, inferred_coord, embedding_ft, embedding_cd,
            selected_prior_ft, selected_prior_cd)``.
        """
        embedding_ft = self.featnet(o[:, :2])
        selected_prior_ft = self.prior_z_ft[indices]
        D_ft = torch.norm(embedding_ft - selected_prior_ft, dim=1) ** 2

        embedding_cd, inferred_coord = self.coordnet(o[:, 1:])
        selected_prior_cd = self.prior_z_cd[indices]
        D_cd = torch.norm(embedding_cd - selected_prior_cd, dim=1) ** 2

        return D_ft, D_cd, inferred_coord, embedding_ft, embedding_cd, selected_prior_ft, selected_prior_cd

    def calc_distance(self, D_ft, D_cd, indices):
        """Lambda-weighted distance for every lambda: ``[n_samples, n_lambdas]``."""
        return (self.lambda_ft[indices] * D_ft.unsqueeze(1)
                + self.lambda_cd[indices] * D_cd.unsqueeze(1))

    def detection_check(self, D_ft, D_cd, inferred_coord, indices, end_step):
        """Update the best-so-far table, register detections and advance the schedules.

        Returns:
            The weighted distance under each pair's current lambda, ``[n_samples]``.
        """
        rows = torch.arange(self.n_samples, device=device)
        D0_lambdas = self.calc_distance(D_ft, D_cd, indices)
        D0_current_lambda = D0_lambdas[rows, self.lambda_idx[indices]]

        # 1) Record (distance, coordinate) wherever this step beats the best so far.
        SHT_ = self.SHT[self.detection_incomplete]  # [n_samples, n_lambdas, 3]
        min_D_case = D0_lambdas < SHT_[:, :, 0]
        if min_D_case.any():
            coords = inferred_coord.unsqueeze(1).expand(-1, self.n_lambdas, -1)
            SHT_[min_D_case] = torch.cat((D0_lambdas[min_D_case].view(-1, 1),
                                          coords[min_D_case]), dim=1)
            self.SHT[self.detection_incomplete] = SHT_

        # 2) Best stored distance among the lambdas unlocked so far (index <= lambda_idx).
        #    Locked lambdas are excluded with +inf. Distances can reach 4, so a
        #    finite penalty would not reliably exclude them.
        locked = (torch.arange(self.n_lambdas, device=device).unsqueeze(0)
                  > self.lambda_idx[indices].unsqueeze(1))
        min_D_, min_D_idx_ = torch.min(SHT_[:, :, 0].masked_fill(locked, float("inf")), -1)
        threshold = 1.0 if end_step else self.detect_threshold[indices]
        detect_l_case = min_D_ <= threshold

        num_of_detection = int(detect_l_case.sum())

        # 3) Mark detections and store the coordinate of the winning lambda.
        if num_of_detection != 0:
            landmark_detection_mask_ = self.landmark_detection_mask[self.detection_incomplete]
            landmark_detection_mask_[detect_l_case] = 1
            self.landmark_detection_mask[self.detection_incomplete] = landmark_detection_mask_

            landmark_coords_ = self.landmark_coords[self.detection_incomplete]
            detected_rows = torch.arange(num_of_detection, device=device)
            landmark_coords_[detect_l_case] = \
                SHT_[detect_l_case][detected_rows, min_D_idx_[detect_l_case], 1:]
            self.landmark_coords[self.detection_incomplete] = landmark_coords_

        # 4) Threshold / lambda schedules, driven by the longest search time so far.
        current_timestep = self.detection_time.max().item()
        thr_control_step = current_timestep - self.thr_control_start
        lambda_control_step = current_timestep - self.lambda_control_start

        control_threshold_case = (thr_control_step >= 0) & (thr_control_step % self.thr_freq == 0)
        if control_threshold_case.any():
            self.detect_threshold[control_threshold_case] += self.thr_increase[control_threshold_case]

        control_lambda_case = (lambda_control_step >= 0) & (lambda_control_step % self.lambda_freq == 0)
        if control_lambda_case.any():
            self.lambda_idx[control_lambda_case] = torch.clamp(
                self.lambda_idx[control_lambda_case] + 1, 0, self.n_lambdas - 1)

        return D0_current_lambda

    def stepsize_calc(self, E0):
        """Map distances to integer step sizes in ``[1, max_stepsize]``.

        The step size is 1 plus the number of thresholds
        ``0.001 + 0.003 * k`` (``k = 0..max_stepsize-2``) that ``E0`` exceeds.
        The result is stored in ``self.stepsize`` (0 for finished pairs) and
        the per-sample values are returned.
        """
        thresholds = 0.001 + 0.003 * torch.arange(self.max_stepsize - 1, device=device)
        stepsize_ = (E0.view(-1, 1) > thresholds).sum(dim=1) + 1
        self.stepsize = torch.zeros(self.n_landmarks, self.batch_size, dtype=torch.long, device=device)
        self.stepsize[self.detection_incomplete] = stepsize_
        return stepsize_

    def act_to_env_calc(self, direction):
        """Convert direction indices (0..7) into scaled 2-D integer moves.

        Returns:
            Long tensor ``[n_landmarks, batch_size, 2]``. Pairs that are no
            longer searching get a zero move.
        """
        # Row d is the unit offset for direction index d.
        offsets = torch.tensor([[-1, -1], [-1, 0], [-1, 1],
                                [0, -1],           [0, 1],
                                [1, -1], [1, 0], [1, 1]], dtype=torch.long, device=device)
        act_to_env = torch.zeros(self.n_landmarks, self.batch_size, 2, dtype=torch.long, device=device)
        stepsize_ = self.stepsize[self.detection_incomplete]
        act_to_env[self.detection_incomplete] = offsets[direction] * stepsize_.view(-1, 1)
        return act_to_env



class Agent_relative(nn.Module):
    """Landmark-search agent that can measure coordinates relative to a reference point.

    Same search procedure as a feature/coordinate-prior agent, with one change.
    When ``reference_c`` is passed to :meth:`progress_detection`, coordinates
    come from RelCoordnet (reference + predicted offset) and are compared
    against the relative-coordinate priors. Otherwise CoordNet is used.

    Expected call order: ``set_hyperparams`` -> ``set_prior`` (which calls
    ``init_tensors``) -> ``progress_detection`` repeatedly.
    """

    def __init__(self, batch_size, maximum_stage):
        super(Agent_relative, self).__init__()
        self.img_size = [255, 255]
        self.view_size = [27, 27]

        self.eps = 0.5
        self.max_stepsize = 30  # largest step size produced by stepsize_calc

        self.batch_size = batch_size
        self.maximum_stage = maximum_stage

        # Prior networks; their weights are loaded in set_prior.
        self.coordnet = CoordNet()
        self.relcoordnet = RelCoordnet()
        self.featnet = FeatNet()

    def set_hyperparams(self, thr_init, thr_control_start, thr_freq, thr_increase,
                        lambda_control_start, lambda_freq, lambda_ft):
        """Store the detection-threshold and lambda schedules.

        Args:
            thr_init: initial detection threshold, indexed per landmark
                (expected tensor ``[n_landmarks]`` on ``device``).
            thr_control_start: first timestep at which the threshold may grow.
            thr_freq: the threshold grows every ``thr_freq`` timesteps from then on.
            thr_increase: amount added at each threshold update, indexed like ``thr_init``.
            lambda_control_start: first timestep at which ``lambda_idx`` may advance.
            lambda_freq: ``lambda_idx`` advances every ``lambda_freq`` timesteps from then on.
            lambda_ft: feature-distance weights ``[n_landmarks, n_lambdas]``.
                The coordinate weights are ``1 - lambda_ft``.
        """
        self.thr_init = thr_init
        self.thr_control_start = thr_control_start
        self.thr_freq = thr_freq
        self.thr_increase = thr_increase
        self.lambda_control_start = lambda_control_start
        self.lambda_freq = lambda_freq
        self.lambda_ft = lambda_ft
        self.lambda_cd = 1 - self.lambda_ft

    def init_tensors(self):
        """(Re)allocate the per-episode search state for ``n_landmarks`` x ``batch_size`` pairs."""
        n_l, b = self.n_landmarks, self.batch_size
        # Number of steps each pair has spent searching.
        self.detection_time = torch.zeros(n_l, b, device=device)
        # 1 once the landmark has been detected for that batch element, else 0.
        self.landmark_detection_mask = torch.zeros(n_l, b, device=device)
        # True while the pair is still searching.
        self.detection_incomplete = torch.ones(n_l, b, dtype=torch.bool, device=device)
        # Inferred landmark coordinate, written when the landmark is detected.
        self.landmark_coords = torch.zeros(n_l, b, 2, device=device)
        self.lambda_cd = 1 - self.lambda_ft  # [n_l, n_lambdas]
        self.n_lambdas = self.lambda_ft.size(1)
        # Highest lambda index currently allowed per landmark; higher ones are ignored.
        self.lambda_idx = torch.zeros(n_l, dtype=torch.long, device=device)
        # Best (distance, coord_0, coord_1) seen so far per lambda, initialised to distance 5.
        self.SHT = torch.full((n_l, b, self.n_lambdas, 3), 5.0, device=device)
        # Copy: detection_check grows this in place. Without the copy the growth
        # would leak back into thr_init, and later episodes would start from it.
        self.detect_threshold = torch.as_tensor(self.thr_init).clone()  # [n_l]

    def set_prior(self, state, stage):
        """Select the preferred embeddings for ``stage`` and reset the search state.

        On ``stage == 1`` all network weights and preferred embeddings are loaded
        from ``state``. Stages up to ``maximum_stage`` use the first feature and
        coordinate prior. Later stages use the remaining feature priors and the
        remaining relative-coordinate priors; ``prior_z_cd`` is left unchanged.
        """
        if stage == 1:
            self.featnet.load_state_dict(state['f_net'])
            self.prior_z_ft_all = state['preferred_z_f']
            self.coordnet.load_state_dict(state['c_net'])
            self.prior_z_cd_all = state['preferred_z_c']
            self.relcoordnet.load_state_dict(state['relative_c_net'])
            self.prior_z_rcd_all = state['preferred_z_rc_22']

        if stage <= self.maximum_stage:
            self.prior_z_ft = self.prior_z_ft_all[0].view(1, 128)
            self.prior_z_cd = self.prior_z_cd_all[0].view(1, 128)
        else:
            self.prior_z_ft = self.prior_z_ft_all[1:]
            self.prior_z_rcd = self.prior_z_rcd_all[1:]

        self.stage = stage
        self.n_landmarks = self.prior_z_ft.size(0)
        self.init_tensors()

    def progress_detection(self, o0, dirpolnet, reference_c=None, end_step=False):
        """Run one search step for every pair whose landmark is not yet detected.

        Args:
            o0: observations indexed ``[n_landmarks, batch_size, ...]``. Only the
                entries of pairs still searching are used.
            dirpolnet: direction policy. Its argmax over dim 1 must be a
                direction index in 0..7.
            reference_c: optional reference coordinates, repeated across
                landmarks (presumably ``[batch_size, 2]`` — confirm with caller).
                When given, relative-coordinate inference is used.
            end_step: final step. Detection uses a fixed threshold of 1.0 and no
                move is produced.

        Returns:
            Long tensor ``[n_landmarks, batch_size, 2]`` of moves, or ``None`` on
            ``end_step`` or once every pair has finished.
        """
        self.detection_time += self.detection_incomplete
        if reference_c is not None:
            reference_c = reference_c.repeat(self.n_landmarks, 1, 1)[self.detection_incomplete]

        self.n_samples = self.detection_incomplete.sum()
        o0 = o0[self.detection_incomplete]

        # Landmark index of every searching pair (row-major over [n_l, batch]).
        indices = self.detection_incomplete.nonzero(as_tuple=False)[:, 0]

        (D_ft, D_cd, inferred_coord,
         embedding_ft, embedding_cd, prior_ft, prior_cd) = self.calc_D(o0, reference_c, indices)

        E0 = self.detection_check(D_ft, D_cd, inferred_coord, indices, end_step)

        self.stepsize_calc(E0)

        if end_step:
            return None

        rows = torch.arange(self.n_samples, device=device)
        lambda_ft = self.lambda_ft[indices][rows, self.lambda_idx[indices]].view(-1, 1)
        lambda_cd = 1 - lambda_ft

        direction_prob = dirpolnet(embedding_ft, embedding_cd,
                                   prior_ft, prior_cd,
                                   lambda_ft, lambda_cd)
        direction_prob = F.softmax(direction_prob, dim=1)
        direction = torch.argmax(direction_prob, dim=1)

        act_to_env = self.act_to_env_calc(direction)

        self.detection_incomplete = self.landmark_detection_mask == 0
        if self.detection_incomplete.sum() == 0:
            return None

        return act_to_env

    def calc_D(self, o, reference_c, indices):
        """Squared distances of the current embeddings to the selected landmarks' priors.

        Channels ``0:2`` of ``o`` feed FeatNet. Channels ``1:`` feed CoordNet,
        or RelCoordnet together with a constant map of ``reference_c``.

        Returns:
            ``(D_ft, D_cd, inferred_coord, embedding_ft, embedding_cd,
            selected_prior_ft, selected_prior_cd)``.
        """
        embedding_ft = self.featnet(o[:, :2])  # [n_samples, 128]
        selected_prior_ft = self.prior_z_ft[indices]
        D_ft = torch.norm(embedding_ft - selected_prior_ft, dim=1) ** 2

        coord_obs = o[:, 1:]
        if reference_c is None:
            embedding_cd, inferred_coord = self.coordnet(coord_obs)
            selected_prior_cd = self.prior_z_cd[indices]
        else:
            # Broadcast each 2-d reference coordinate to a constant 2-channel map
            # with the observation's spatial size (previously hard-coded to 27x27).
            reference_c_input = (reference_c.view(-1, 2, 1, 1)
                                 .to(coord_obs.dtype)
                                 .expand(-1, -1, coord_obs.size(2), coord_obs.size(3)))
            embedding_cd, relative_coord = self.relcoordnet(coord_obs, reference_c_input)
            selected_prior_cd = self.prior_z_rcd[indices]
            inferred_coord = reference_c + relative_coord
        D_cd = torch.norm(embedding_cd - selected_prior_cd, dim=1) ** 2

        return D_ft, D_cd, inferred_coord, embedding_ft, embedding_cd, selected_prior_ft, selected_prior_cd

    def calc_distance(self, D_ft, D_cd, indices):
        """Lambda-weighted distance for every lambda: ``[n_samples, n_lambdas]``."""
        return (self.lambda_ft[indices] * D_ft.unsqueeze(1)
                + self.lambda_cd[indices] * D_cd.unsqueeze(1))

    def detection_check(self, D_ft, D_cd, inferred_coord, indices, end_step):
        """Update the best-so-far table, register detections and advance the schedules.

        Returns:
            The weighted distance under each pair's current lambda, ``[n_samples]``.
        """
        rows = torch.arange(self.n_samples, device=device)
        D0_lambdas = self.calc_distance(D_ft, D_cd, indices)
        D0_current_lambda = D0_lambdas[rows, self.lambda_idx[indices]]

        # 1) Record (distance, coordinate) wherever this step beats the best so far.
        SHT_ = self.SHT[self.detection_incomplete]  # [n_samples, n_lambdas, 3]
        min_D_case = D0_lambdas < SHT_[:, :, 0]
        if min_D_case.any():
            coords = inferred_coord.unsqueeze(1).expand(-1, self.n_lambdas, -1)
            SHT_[min_D_case] = torch.cat((D0_lambdas[min_D_case].view(-1, 1),
                                          coords[min_D_case]), dim=1)
            self.SHT[self.detection_incomplete] = SHT_

        # 2) Best stored distance among the lambdas unlocked so far (index <= lambda_idx).
        #    Locked lambdas are excluded with +inf. A finite penalty would not
        #    reliably exclude them.
        locked = (torch.arange(self.n_lambdas, device=device).unsqueeze(0)
                  > self.lambda_idx[indices].unsqueeze(1))
        min_D_, min_D_idx_ = torch.min(SHT_[:, :, 0].masked_fill(locked, float("inf")), -1)
        threshold = 1.0 if end_step else self.detect_threshold[indices]
        detect_l_case = min_D_ <= threshold

        num_of_detection = int(detect_l_case.sum())

        # 3) Mark detections and store the coordinate of the winning lambda.
        if num_of_detection != 0:
            landmark_detection_mask_ = self.landmark_detection_mask[self.detection_incomplete]
            landmark_detection_mask_[detect_l_case] = 1
            self.landmark_detection_mask[self.detection_incomplete] = landmark_detection_mask_

            landmark_coords_ = self.landmark_coords[self.detection_incomplete]
            detected_rows = torch.arange(num_of_detection, device=device)
            landmark_coords_[detect_l_case] = \
                SHT_[detect_l_case][detected_rows, min_D_idx_[detect_l_case], 1:]
            self.landmark_coords[self.detection_incomplete] = landmark_coords_

        # 4) Threshold / lambda schedules, driven by the longest search time so far.
        current_timestep = self.detection_time.max().item()
        thr_control_step = current_timestep - self.thr_control_start  # [n_l]
        lambda_control_step = current_timestep - self.lambda_control_start  # [n_l]

        control_threshold_case = (thr_control_step >= 0) & (thr_control_step % self.thr_freq == 0)
        if control_threshold_case.any():
            self.detect_threshold[control_threshold_case] += self.thr_increase[control_threshold_case]

        control_lambda_case = (lambda_control_step >= 0) & (lambda_control_step % self.lambda_freq == 0)
        if control_lambda_case.any():
            self.lambda_idx[control_lambda_case] = torch.clamp(
                self.lambda_idx[control_lambda_case] + 1, 0, self.n_lambdas - 1)

        return D0_current_lambda

    def stepsize_calc(self, E0):
        """Map distances to integer step sizes in ``[1, max_stepsize]``.

        The step size is 1 plus the number of thresholds
        ``0.001 + 0.003 * k`` (``k = 0..max_stepsize-2``) that ``E0`` exceeds.
        The result is stored in ``self.stepsize`` (0 for finished pairs) and
        the per-sample values are returned.
        """
        thresholds = 0.001 + 0.003 * torch.arange(self.max_stepsize - 1, device=device)
        stepsize_ = (E0.view(-1, 1) > thresholds).sum(dim=1) + 1
        self.stepsize = torch.zeros(self.n_landmarks, self.batch_size, dtype=torch.long, device=device)
        self.stepsize[self.detection_incomplete] = stepsize_
        return stepsize_

    def act_to_env_calc(self, direction):
        """Convert direction indices (0..7) into scaled 2-D integer moves.

        Returns:
            Long tensor ``[n_landmarks, batch_size, 2]``. Pairs that are no
            longer searching get a zero move.
        """
        # Row d is the unit offset for direction index d.
        offsets = torch.tensor([[-1, -1], [-1, 0], [-1, 1],
                                [0, -1],           [0, 1],
                                [1, -1], [1, 0], [1, 1]], dtype=torch.long, device=device)
        act_to_env = torch.zeros(self.n_landmarks, self.batch_size, 2, dtype=torch.long, device=device)
        stepsize_ = self.stepsize[self.detection_incomplete]
        act_to_env[self.detection_incomplete] = offsets[direction] * stepsize_.view(-1, 1)
        return act_to_env
    
