import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.nn.modules.loss import CrossEntropyLoss


class RegressionTrain(torch.nn.Module):
    """Wrap a multi-task model and return one cross-entropy loss per task.

    Args:
        model: Callable invoked as ``model(x, pref_idx, embd)``. Its output is
            indexed along dim 1 as the task axis, i.e. it is expected to be
            ``(batch, n_tasks, n_classes)`` logits.
    """

    def __init__(self, model):
        super(RegressionTrain, self).__init__()

        self.model = model
        # Default reduction ('mean'): one scalar loss per task.
        self.ce_loss = CrossEntropyLoss()

    def forward(self, x, ts, pref_idx, embd=None):
        """Compute the per-task losses.

        Args:
            x: Input batch, forwarded unchanged to the wrapped model.
            ts: Integer class targets; ``ts[:, i]`` holds the targets of task i.
            pref_idx: Forwarded unchanged to the wrapped model.
            embd: Optional embedding, forwarded unchanged to the wrapped model.

        Returns:
            1-D tensor with one mean cross-entropy loss per task.
        """
        ys = self.model(x, pref_idx, embd)

        # Infer the task count from the model output instead of hard-coding 2,
        # so wrapped models with a different number of heads also work.
        n_tasks = ys.size(1)
        return torch.stack(
            [self.ce_loss(ys[:, i], ts[:, i]) for i in range(n_tasks)]
        )


class RegressionModel(torch.nn.Module):
    """Preference-conditioned two-head network whose weights come from a hypernetwork.

    A three-layer MLP (the hypernetwork) maps a 2-element preference
    embedding to a 10-element code. Consecutive pairs of that code are
    projected by learnable matrices into the weights of the main network:
    conv1 (20x1x9x9), conv2 (10x20x5x5), a shared fully connected layer
    (250 -> 100) and two task-specific output layers (100 -> 10). None of the
    generated layers has a bias. The main network flattens to 5*5*10
    features, which e.g. 36x36 single-channel inputs produce.

    NOTE(review): after the hypernetwork the code is sliced as ``code[0:2]``,
    ``code[2:4]``, ... and fed to ``F.linear`` with 2-column weights, so
    ``embd`` must be a single 1-D preference vector, not a batch.

    Args:
        n_tasks: Stored as ``self.n_tasks``. The architecture always builds
            exactly two output heads regardless of this value.
    """

    # Constant added after every hypernetwork matmul. The code this replaces
    # passed ``bias=True`` to ``F.linear``: current PyTorch rejects a bool bias
    # (TypeError), while older releases (Python ``output += bias``) added 1.0
    # for 1-D inputs. The offset keeps that historical behaviour so results and
    # checkpoints stay comparable.
    # NOTE(review): if a learnable bias was intended, use nn.Parameter biases
    # instead (this changes the state_dict).
    _HYPER_OFFSET = 1.0

    def __init__(self, n_tasks):
        super(RegressionModel, self).__init__()
        self.n_tasks = n_tasks

        # Parameters are created in a fixed order; keep it stable so seeded
        # initialisation and state_dict key order are reproducible.

        #### trainable parameters for HyperNet ####
        self.hyper_fc1_para = self._normal_param((20, 2), std=1.0)
        self.hyper_fc2_para = self._normal_param((20, 20), std=0.1)
        self.hyper_fc3_para = self._normal_param((10, 20), std=0.1)

        #### projections from hypernet code pairs to main-network weights ####
        self.conv1_para = self._normal_param((20 * 1 * 9 * 9, 2), std=0.01)
        self.conv2_para = self._normal_param((10 * 20 * 5 * 5, 2), std=0.01)
        self.fc1_para = self._normal_param((5 * 5 * 10 * 100, 2), std=0.01)
        self.outputpara1 = self._normal_param((100 * 10, 2), std=0.01)
        self.outputpara2 = self._normal_param((100 * 10, 2), std=0.01)

    @staticmethod
    def _normal_param(shape, std):
        """Create a trainable parameter of ``shape`` drawn from N(0, std**2)."""
        param = nn.Parameter(torch.empty(*shape), requires_grad=True)
        torch.nn.init.normal_(param, mean=0.0, std=std)
        return param

    def forward(self, x, pref_idx, embd=None):
        """Run the main network with weights generated from a preference embedding.

        Args:
            x: Image batch for ``F.conv2d``; must have 1 input channel.
            pref_idx: Index into ``self.embd``; used only when ``embd`` is None.
            embd: Optional 1-D preference embedding of length 2.

        Returns:
            Tensor of shape ``(batch, 2, 10)``: the logits of both task heads
            stacked on dim 1.

        Raises:
            AttributeError: if ``embd`` is None and no ``embd`` table has been
                assigned to the model.
        """
        if embd is None:
            embd = self._lookup_embedding(pref_idx)

        code = self._hypernet(embd)
        weights = self._generate_main_params(code)
        return self._main_net(x, *weights)

    def _lookup_embedding(self, pref_idx):
        """Return ``self.embd[pref_idx]``, failing clearly when no table is set."""
        # ``embd`` is never created in __init__; callers must assign it.
        table = getattr(self, "embd", None)
        if table is None:
            raise AttributeError(
                "RegressionModel has no 'embd' table: pass `embd` explicitly "
                "or assign `model.embd` before calling forward with embd=None"
            )
        return table[pref_idx]

    def _hypernet(self, embd):
        """Map the preference embedding to the 10-element weight-generating code."""
        for weight in (self.hyper_fc1_para, self.hyper_fc2_para, self.hyper_fc3_para):
            embd = F.relu(F.linear(embd, weight) + self._HYPER_OFFSET)
        return embd

    def _generate_main_params(self, code):
        """Project code pairs into (conv1, conv2, fc1, output1, output2) weights."""
        conv1 = F.linear(code[0:2], self.conv1_para).reshape(20, 1, 9, 9)
        conv2 = F.linear(code[2:4], self.conv2_para).reshape(10, 20, 5, 5)
        fc1 = F.linear(code[4:6], self.fc1_para).reshape(100, 5 * 5 * 10)
        out1 = F.linear(code[6:8], self.outputpara1).reshape(10, 100)
        out2 = F.linear(code[8:10], self.outputpara2).reshape(10, 100)
        return conv1, conv2, fc1, out1, out2

    def _main_net(self, x, conv1, conv2, fc1, out1, out2):
        """Apply the shared trunk and both task heads with the generated weights."""
        # conv1
        x = F.max_pool2d(F.relu(F.conv2d(x, conv1)), 2, 2)

        # conv2 (dropout is active only in training mode)
        x = F.conv2d(x, conv2)
        x = F.dropout(x, p=0.2, training=self.training)
        x = F.max_pool2d(F.relu(x), 2, 2)

        # fc1
        x = F.relu(F.linear(x.view(-1, 5 * 5 * 10), fc1))

        # task-specific output layers
        return torch.stack([F.linear(x, out1), F.linear(x, out2)], dim=1)
        
        

