import torch
import torch.nn as nn
import torch.nn.functional as F

class HyperSegModel(torch.nn.Module):
    """Wrap a two-task (segmentation + depth) network and return its task losses.

    ``forward`` runs ``model(x, embd)`` and compares the two outputs against the
    targets ``ts = [seg_target, depth_target]``.
    """

    def __init__(self, model, embd=None):
        """Store ``model``.

        Args:
            model: callable invoked as ``model(x, embd)``; it must return a
                two-element sequence ``[seg_log_probs, depth_pred]``.
            embd: accepted for backward compatibility; it is not used.
        """
        super(HyperSegModel, self).__init__()

        self.model = model

    def forward(self, x, ts, embd=None):
        """Return ``torch.stack([seg_loss, depth_loss])``.

        Args:
            x: input forwarded unchanged to ``self.model``.
            ts: ``[seg_target, depth_target]``. ``seg_target`` holds class
                indices and uses -1 for ignored pixels (``ignore_index=-1``).
                Depth pixels whose sum over dim 1 is 0 are treated as missing.
            embd: forwarded to the model as its second argument.

        Returns:
            A 1-D tensor of two losses: the NLL segmentation loss scaled by 1/10,
            and the masked mean L1 depth loss.
        """
        ys = self.model(x, embd)
        seg_target, depth_target = ts[0], ts[1]

        # 1.0 where the depth target is defined. The mask is built on the
        # target's own device instead of being forced onto the default CUDA
        # device, so CPU and multi-GPU runs work.
        binary_mask = (torch.sum(depth_target, dim=1) != 0).float().unsqueeze(1)

        # semantic loss: depth-wise cross entropy
        # TODO: better loss balancing method (the /10 is an ad-hoc weight)
        loss1 = F.nll_loss(ys[0], seg_target, ignore_index=-1) / 10

        # depth loss: L1 norm averaged over valid pixels. The mask sum equals the
        # number of valid pixels; clamping to 1 returns 0 instead of 0/0 = NaN
        # when a batch has no valid depth at all.
        n_valid = binary_mask.sum().clamp(min=1)
        loss2 = torch.sum(torch.abs(ys[1] - depth_target) * binary_mask) / n_valid

        return torch.stack([loss1, loss2])
        
        
class SegNet(nn.Module):
    """Preference-conditioned SegNet for joint semantic segmentation and depth.

    No convolution kernel is a trained tensor of its own. A small hypernetwork
    turns a two-entry preference vector ``pref`` into every kernel of a 5-stage
    SegNet encoder/decoder, two shared 1x1 layers and two task-specific heads.
    ``forward`` returns ``[segmentation log-probabilities, depth prediction]``.

    Subclasses may override ``FILTERS`` to build a narrower network.
    """

    # Output channels of the five encoder stages (the decoder mirrors them).
    FILTERS = (64, 128, 256, 512, 512)

    def __init__(self):
        super(SegNet, self).__init__()
        filters = self.FILTERS

        ################ SegNet Parameters ##################################

        # Number of shared parameter tensors in each conv layer. Every kernel is
        # generated as `share_part` chunks. Each chunk is produced from `ns`
        # embedding entries through a (kernel_numel / share_part, ns) matrix.
        # Can be set to 2, 4, 6, 8, ...
        # TODO: more flexible sharing methods
        self.ns = 2
        self.share_part = self.ns

        # number of classes for segmentation
        self.class_nb = 7

        # Down/up sampling. The pooling indices are reused by the decoder.
        self.down_sampling = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
        self.up_sampling = nn.MaxUnpool2d(kernel_size=2, stride=2)

        ################ Hypernetwork ##################################

        # Two preference embeddings, mixed by the entries of `pref` in forward.
        self.embd = nn.ParameterList([self._new_param(64) for _ in range(2)])

        # Chunk embeddings. forward reads only the first 10 (5 encoder + 5
        # decoder stages); the remaining 10 are allocated but unused.
        self.chunk_embd = nn.ParameterList([self._new_param(64) for _ in range(20)])

        # MLPs mapping (preference ++ chunk) embeddings to per-stage codes, and
        # the preference embedding alone to the code for the top/head layers.
        self.hyper_fcnets_inc = self.fc_layer()
        self.hyper_fcnets_dec = self.fc_layer()
        self.hyper_fcnets_head = self.fc_layer_head()

        # Encoder weight-generator matrices and their parameter-free BN + ReLU.
        self.W_inc = nn.ParameterList()
        self.W_inc_conv1 = nn.ParameterList()
        self.W_inc_conv2 = nn.ParameterList()
        self.bn_inc = nn.ModuleList()
        self.bn_inc_conv1 = nn.ModuleList()
        self.bn_inc_conv2 = nn.ModuleList()

        for i in range(5):
            in_ch, out_ch = self._encoder_channels(i)
            self.W_inc.append(self._chunked_weight(in_ch * out_ch))
            self.bn_inc.append(self.bn_layer(out_ch))
            self.W_inc_conv1.append(self._chunked_weight(out_ch * out_ch))
            self.bn_inc_conv1.append(self.bn_layer(out_ch))
            # only the three deepest encoder stages have a third conv
            if i > 1:
                self.W_inc_conv2.append(self._chunked_weight(out_ch * out_ch))
                self.bn_inc_conv2.append(self.bn_layer(out_ch))

        # Decoder weight-generator matrices and their parameter-free BN + ReLU.
        self.W_dec = nn.ParameterList()
        self.W_dec_conv1 = nn.ParameterList()
        self.W_dec_conv2 = nn.ParameterList()
        self.bn_dec = nn.ModuleList()
        self.bn_dec_conv1 = nn.ModuleList()
        self.bn_dec_conv2 = nn.ModuleList()

        for i in range(5):
            in_ch, out_ch = self._decoder_channels(i)
            self.W_dec.append(self._chunked_weight(in_ch * out_ch))
            self.bn_dec.append(self.bn_layer(out_ch))
            self.W_dec_conv1.append(self._chunked_weight(out_ch * out_ch))
            self.bn_dec_conv1.append(self.bn_layer(out_ch))
            # only the three first (deepest) decoder stages have a third conv
            if i < 3:
                self.W_dec_conv2.append(self._chunked_weight(out_ch * out_ch))
                self.bn_dec_conv2.append(self.bn_layer(out_ch))

        # Shared top 1x1 conv layers. Despite the name, `bn_top` holds
        # ReLU-only blocks.
        self.W_top = nn.ParameterList()
        self.bn_top = nn.ModuleList()
        for _ in range(2):
            self.W_top.append(self._new_param(filters[0] * filters[0] * 1 * 1, 10))
            self.bn_top.append(self.relu_layer(filters[0]))

        # Task-specific heads: a 3x3 conv followed by a 1x1 projection.
        self.W_t1 = nn.ParameterList([
            self._new_param(filters[0] * filters[0] * 3 * 3, 10),
            self._new_param(filters[0] * self.class_nb * 1 * 1, 10),
        ])
        self.W_t2 = nn.ParameterList([
            self._new_param(filters[0] * filters[0] * 3 * 3, 10),
            self._new_param(filters[0] * 1 * 1 * 1, 10),
        ])

    @staticmethod
    def _new_param(*shape):
        """Return a trainable parameter of ``shape`` initialised from N(0, 0.1**2)."""
        param = nn.Parameter(torch.empty(*shape))
        torch.nn.init.normal_(param, mean=0., std=0.1)
        return param

    def _chunked_weight(self, channel_product):
        """Return the generator matrix for a 3x3 kernel with ``channel_product`` = in*out channels.

        Its shape is ``(channel_product // share_part * 9, ns)``. One F.linear
        with it yields one of the ``share_part`` chunks of the kernel.
        """
        return self._new_param(channel_product // self.share_part * 3 * 3, self.ns)

    def _encoder_channels(self, i):
        """Return ``(in_channels, out_channels)`` of the first conv in encoder stage ``i``."""
        filters = self.FILTERS
        return (3 if i == 0 else filters[i - 1]), filters[i]

    def _decoder_channels(self, i):
        """Return ``(in_channels, out_channels)`` of the first conv in decoder stage ``i``."""
        filters = self.FILTERS
        if i == 4:
            return filters[0], filters[0]
        return filters[4 - i], filters[3 - i]

    def _generate_weight(self, mid_embd, offset, weight, shape):
        """Map a slice of a hypernetwork code to one convolution kernel.

        For each chunk ``k`` in ``range(share_part)``, the entries
        ``mid_embd[ns*k + offset : ns*k + offset + ns]`` are projected by
        ``weight``. The chunks are then stacked and reshaped to ``shape``.
        """
        # NOTE(review): offsets advance by `ns` between the layers of a stage, so
        # chunk k at offset o and chunk k-1 at offset o + ns read the same
        # entries (e.g. [2:4] twice when ns == 2). This is preserved to keep the
        # model's behaviour; confirm it is intended.
        chunks = [
            F.linear(mid_embd[self.ns * k + offset : self.ns * k + offset + self.ns], weight)
            for k in range(self.share_part)
        ]
        return torch.stack(chunks).reshape(shape)

    def bn_layer(self, channel):
        """Return parameter-free BatchNorm2d (batch statistics only) followed by ReLU."""
        bn_block = nn.Sequential(
                nn.BatchNorm2d(num_features=channel, affine = False, track_running_stats = False),
                nn.ReLU(inplace=True)
            )
        return bn_block

    def relu_layer(self, channel):
        """Return a ReLU-only block; ``channel`` is accepted for symmetry with ``bn_layer``."""
        relu_block = nn.Sequential(
                nn.ReLU(inplace=True)
            )
        return relu_block

    def fc_layer(self):
        """Return the stage-code MLP: 128 -> 100 -> 100 -> 120."""
        fc_block = nn.Sequential(
                nn.Linear(128,100),
                nn.ReLU(inplace=True),
                nn.Linear(100,100),
                nn.ReLU(inplace=True),
                nn.Linear(100,120)
            )
        return fc_block

    def fc_layer_head(self):
        """Return the top/head-code MLP: 64 -> 100 -> 100 -> 120."""
        fc_head_block = nn.Sequential(
                nn.Linear(64,100),
                nn.ReLU(inplace=True),
                nn.Linear(100,100),
                nn.ReLU(inplace=True),
                nn.Linear(100,120)
            )
        return fc_head_block

    def forward(self, x, pref=None):
        """Generate every kernel from ``pref`` and run the SegNet on ``x``.

        Args:
            x: input batch ``(B, 3, H, W)``. H and W should be divisible by 32,
                so that unpooling after the five 2x poolings restores each size.
            pref: two preference weights, read as ``pref[0]`` and ``pref[1]``.

        Returns:
            ``[t1_pred, t2_pred]``. ``t1_pred`` holds log-softmax scores over
            ``class_nb`` channels; ``t2_pred`` is a single-channel map.

        Raises:
            ValueError: if ``pref`` is None.
        """
        if pref is None:
            raise ValueError("SegNet.forward requires a preference vector `pref`")

        filters = self.FILTERS
        ns = self.ns

        # indices for down/up sampling
        indices = [0] * 5

        ################ Hypernetwork Inference #############################

        # preference-weighted embedding
        pref_embd = pref[0] * self.embd[0] + pref[1] * self.embd[1]

        # one (preference ++ chunk) code per stage: 5 encoder + 5 decoder
        concated_embd = torch.stack(
            [torch.cat([pref_embd, self.chunk_embd[i]]) for i in range(10)]
        )
        inc_mid_embd = self.hyper_fcnets_inc(concated_embd[:5])
        dec_mid_embd = self.hyper_fcnets_dec(concated_embd[5:])
        task_mid_embd = self.hyper_fcnets_head(pref_embd)

        # encoder kernels
        W_inc_para_list, W_inc_conv1_para_list, W_inc_conv2_para_list = [], [], []
        for i in range(5):
            in_ch, out_ch = self._encoder_channels(i)
            stage_embd = inc_mid_embd[i]
            W_inc_para_list.append(
                self._generate_weight(stage_embd, 0, self.W_inc[i], (out_ch, in_ch, 3, 3)))
            W_inc_conv1_para_list.append(
                self._generate_weight(stage_embd, ns, self.W_inc_conv1[i], (out_ch, out_ch, 3, 3)))
            if i > 1:
                W_inc_conv2_para_list.append(
                    self._generate_weight(stage_embd, 2 * ns, self.W_inc_conv2[i - 2], (out_ch, out_ch, 3, 3)))

        # decoder kernels
        W_dec_para_list, W_dec_conv1_para_list, W_dec_conv2_para_list = [], [], []
        for i in range(5):
            in_ch, out_ch = self._decoder_channels(i)
            stage_embd = dec_mid_embd[i]
            W_dec_para_list.append(
                self._generate_weight(stage_embd, 0, self.W_dec[i], (out_ch, in_ch, 3, 3)))
            W_dec_conv1_para_list.append(
                self._generate_weight(stage_embd, ns, self.W_dec_conv1[i], (out_ch, out_ch, 3, 3)))
            if i < 3:
                W_dec_conv2_para_list.append(
                    self._generate_weight(stage_embd, 2 * ns, self.W_dec_conv2[i], (out_ch, out_ch, 3, 3)))

        # Shared top 1x1 kernels. Consecutive 10-entry slices of the head code
        # feed the top and head kernels.
        W_top_para_list = [
            F.linear(task_mid_embd[10 * i : 10 * i + 10], self.W_top[i]).reshape(filters[0], filters[0], 1, 1)
            for i in range(2)
        ]

        # task-specific head kernels
        W_t1_para_list = [
            F.linear(task_mid_embd[20:30], self.W_t1[0]).reshape(filters[0], filters[0], 3, 3),
            F.linear(task_mid_embd[30:40], self.W_t1[1]).reshape(self.class_nb, filters[0], 1, 1),
        ]
        W_t2_para_list = [
            F.linear(task_mid_embd[40:50], self.W_t2[0]).reshape(filters[0], filters[0], 3, 3),
            F.linear(task_mid_embd[50:60], self.W_t2[1]).reshape(1, filters[0], 1, 1),
        ]

        ################ Main MTL Network Inference #########################

        share_feature = x

        # encoder
        for i in range(5):
            share_feature = self.bn_inc[i](F.conv2d(share_feature, W_inc_para_list[i], padding=1))
            share_feature = self.bn_inc_conv1[i](F.conv2d(share_feature, W_inc_conv1_para_list[i], padding=1))
            if i > 1:
                share_feature = self.bn_inc_conv2[i - 2](
                    F.conv2d(share_feature, W_inc_conv2_para_list[i - 2], padding=1))
            share_feature, indices[i] = self.down_sampling(share_feature)

        # decoder: unpool with the matching encoder indices, deepest first
        for i in range(5):
            share_feature = self.up_sampling(share_feature, indices[-1 - i])
            share_feature = self.bn_dec[i](F.conv2d(share_feature, W_dec_para_list[i], padding=1))
            share_feature = self.bn_dec_conv1[i](F.conv2d(share_feature, W_dec_conv1_para_list[i], padding=1))
            if i < 3:
                share_feature = self.bn_dec_conv2[i](
                    F.conv2d(share_feature, W_dec_conv2_para_list[i], padding=1))

        # shared top conv layers
        for i in range(2):
            share_feature = self.bn_top[i](F.conv2d(share_feature, W_top_para_list[i], padding=0))

        # task specific heads
        t1 = F.conv2d(share_feature, W_t1_para_list[0], padding=1)
        t1 = F.conv2d(t1, W_t1_para_list[1], padding=0)

        t2 = F.conv2d(share_feature, W_t2_para_list[0], padding=1)
        t2 = F.conv2d(t2, W_t2_para_list[1], padding=0)

        return [F.log_softmax(t1, dim=1), t2]

    def model_fit(self, x_pred1, x_output1, x_pred2, x_output2):
        """Return ``[seg_loss, depth_loss]`` for one batch.

        ``seg_loss`` is the NLL of log-probabilities ``x_pred1`` against class
        indices ``x_output1``; label -1 is ignored. ``depth_loss`` is the mean L1
        between ``x_pred2`` and ``x_output2`` over pixels whose target sum over
        dim 1 is nonzero. It is 0 when there are no such pixels.
        """
        # Binary mask to mask out undefined pixels. It is built on the target's
        # device rather than forced onto the default CUDA device.
        binary_mask = (torch.sum(x_output2, dim=1) != 0).float().unsqueeze(1)

        # semantic loss: depth-wise cross entropy
        loss1 = F.nll_loss(x_pred1, x_output1, ignore_index=-1)

        # depth loss: L1 norm; clamp avoids 0/0 = NaN when no pixel is valid
        n_valid = binary_mask.sum().clamp(min=1)
        loss2 = torch.sum(torch.abs(x_pred2 - x_output2) * binary_mask) / n_valid

        return [loss1, loss2]

    def compute_miou(self, x_pred, x_output):
        """Return the batch mean IoU as a 0-d tensor.

        For each image, IoU is averaged over the classes whose union (predicted
        or labelled, counted only on labelled pixels) is nonzero. Images with no
        such class, i.e. with no labelled pixels, are skipped rather than
        crashing or reusing the previous image's value. Returns 0 if every image
        is skipped.

        Args:
            x_pred: class scores; the argmax is taken over dim 1.
            x_output: class indices with -1 for unlabelled pixels.
        """
        _, x_pred_label = torch.max(x_pred, dim=1)
        class_nb = x_pred.size(1)
        total = x_pred.new_zeros(())
        n_images = 0
        for pred_label, true_label in zip(x_pred_label, x_output):
            labelled = (true_label >= 0).float()
            class_ious = []
            for j in range(class_nb):
                pred_mask = pred_label == j
                true_mask = true_label == j
                # remove non-defined pixel predictions from the union
                union = torch.sum((pred_mask | true_mask).float() * labelled)
                if union == 0:
                    continue
                intsec = torch.sum((pred_mask & true_mask).float())
                class_ious.append(intsec / union)
            if class_ious:
                total = total + sum(class_ious) / len(class_ious)
                n_images += 1
        return total / max(n_images, 1)

    def compute_iou(self, x_pred, x_output):
        """Return the batch-mean pixel accuracy as a 0-d tensor (despite the name).

        Per image, this is the number of pixels where argmax(x_pred, dim=1)
        equals the label, divided by the number of labelled (>= 0) pixels. An
        image with no labelled pixels yields NaN.
        """
        _, x_pred_label = torch.max(x_pred, dim=1)
        batch_size = x_pred.size(0)
        correct = torch.eq(x_pred_label, x_output).float().flatten(1).sum(dim=1)
        labelled = (x_output >= 0).float().flatten(1).sum(dim=1)
        return (correct / labelled).sum() / batch_size

    def depth_error(self, x_pred, x_output):
        """Return ``(mean_abs_error, mean_rel_error)`` as Python floats.

        Errors are averaged over pixels whose target sum over dim 1 is nonzero.
        The result is NaN if there are none.
        """
        device = x_pred.device
        binary_mask = (torch.sum(x_output, dim=1) != 0).unsqueeze(1).to(device)
        x_pred_true = x_pred.masked_select(binary_mask)
        x_output_true = x_output.masked_select(binary_mask)
        # NOTE(review): the mask broadcasts over dim 1, so for multi-channel
        # input the sums cover all channels but are divided by the pixel count.
        n_valid = torch.nonzero(binary_mask, as_tuple=False).size(0)
        abs_err = torch.abs(x_pred_true - x_output_true)
        rel_err = abs_err / x_output_true
        return (torch.sum(abs_err) / n_valid).item(), \
               (torch.sum(rel_err) / n_valid).item()
