from Requierments import *

class MultiTaskNetwork(nn.Module):
    """Linear two-task network: one shared linear layer feeding two linear heads.

    All biases are set to zero.  Without explicit initial weights, the two
    task-head weight matrices are drawn from a normal distribution with mean 0
    and standard deviation ``sigmaweights``, while the shared layer keeps the
    default ``nn.Linear`` weight initialisation.  With explicit initial
    weights, all three matrices are copied into the layers.

    At the end of construction the norm of
    ``W_S W_S^T - (W1^T W1 + W2^T W2)`` is printed as a balance diagnostic.

    Args:
        input_dim: Number of input features.
        shared_dim: Width of the shared representation.
        task1_dim: Output width of the first task head.
        task2_dim: Output width of the second task head.
        sigmaweights: Standard deviation used for the random head weights.
        W1_init: Optional weight copied into the ``[task1_dim, shared_dim]``
            task-1 head.
        W2_init: Optional weight copied into the ``[task2_dim, shared_dim]``
            task-2 head.
        WS_init: Optional weight copied into the ``[shared_dim, input_dim]``
            shared layer.

    Raises:
        ValueError: If only some of ``W1_init``, ``W2_init`` and ``WS_init``
            are supplied.
    """

    def __init__(self, input_dim, shared_dim, task1_dim, task2_dim, sigmaweights, W1_init=None, W2_init=None, WS_init=None):
        super(MultiTaskNetwork, self).__init__()

        # The three matrices only make sense as a set: a partial set used to
        # fail later with an opaque error from copy_(None) or None.t().
        supplied = [w is not None for w in (W1_init, W2_init, WS_init)]
        if any(supplied) and not all(supplied):
            raise ValueError(
                "W1_init, W2_init and WS_init must be given together or not at all"
            )

        # `seed` is not defined in this class; it has to exist at module level.
        torch.manual_seed(seed)

        self.shared_layer = nn.Linear(input_dim, shared_dim)
        self.task1_layer = nn.Linear(shared_dim, task1_dim)
        self.task2_layer = nn.Linear(shared_dim, task2_dim)
        torch.nn.init.constant_(self.task1_layer.bias, 0)
        torch.nn.init.constant_(self.task2_layer.bias, 0)
        torch.nn.init.constant_(self.shared_layer.bias, 0)

        # Initialization
        if W1_init is None:
            torch.nn.init.normal_(self.task1_layer.weight, mean=0.0, std=sigmaweights)
            torch.nn.init.normal_(self.task2_layer.weight, mean=0.0, std=sigmaweights)
            # No matrices were supplied, so the diagnostic below inspects the
            # weights the layers actually hold (detached: this is not part of
            # any autograd graph).
            w_shared = self.shared_layer.weight.detach()
            w_task1 = self.task1_layer.weight.detach()
            w_task2 = self.task2_layer.weight.detach()
        else:
            with torch.no_grad():
                self.task1_layer.weight.copy_(W1_init)
                self.task2_layer.weight.copy_(W2_init)
                self.shared_layer.weight.copy_(WS_init)
            w_shared, w_task1, w_task2 = WS_init, W1_init, W2_init

        # Balance diagnostic; both sides have shape [shared_dim, shared_dim].
        lhs = w_shared @ w_shared.t()
        rhs = w_task1.t() @ w_task1 + w_task2.t() @ w_task2
        print("||W_S W_S^T - (W1^T W1 + W2^T W2)|| =", torch.norm(lhs - rhs))

    def forward(self, x):
        """Return ``(task1_out, task2_out)`` for the input batch ``x``."""
        # Shared representation
        shared_out = self.shared_layer(x)

        # Task-specific outputs
        task1_out = self.task1_layer(shared_out)
        task2_out = self.task2_layer(shared_out)

        return task1_out, task2_out

class SingleTaskNetwork(nn.Module):
    """Linear single-task network: a shared linear layer followed by one linear head.

    All biases are set to zero.  Without explicit initial weights, the head
    weight matrix is drawn from a normal distribution with mean 0 and standard
    deviation ``sigmaweights``, while the shared layer keeps the default
    ``nn.Linear`` weight initialisation.  With explicit initial weights, both
    matrices are copied into the layers.

    At the end of construction the norm of ``W_S W_S^T - W1^T W1`` is printed
    as a check of whether the weights are balanced.

    Args:
        input_dim: Number of input features.
        shared_dim: Width of the shared representation.
        task1_dim: Output width of the task head.
        sigmaweights: Standard deviation used for the random head weights.
        W1_init: Optional weight copied into the ``[task1_dim, shared_dim]``
            task head.
        W2_init: Unused; accepted so the signature parallels the two-task
            networks.
        WS_init: Optional weight copied into the ``[shared_dim, input_dim]``
            shared layer.

    Raises:
        ValueError: If only one of ``W1_init`` and ``WS_init`` is supplied.
    """

    def __init__(self, input_dim, shared_dim, task1_dim,sigmaweights,W1_init=None, W2_init=None, WS_init=None):
        super(SingleTaskNetwork, self).__init__()

        # Both matrices are needed together; a lone one used to fail later
        # with an opaque error from copy_(None) or None.t().
        if (W1_init is None) != (WS_init is None):
            raise ValueError(
                "W1_init and WS_init must be given together or not at all"
            )

        # `seed` is not defined in this class; it has to exist at module level.
        torch.manual_seed(seed)

        self.shared_layer = nn.Linear(input_dim, shared_dim)
        self.task1_layer = nn.Linear(shared_dim, task1_dim)
        torch.nn.init.constant_(self.shared_layer.bias, 0)
        torch.nn.init.constant_(self.task1_layer.bias, 0)

        # Initialization
        if W1_init is None:
            # Head weight has shape [task1_dim, shared_dim].
            torch.nn.init.normal_(self.task1_layer.weight, mean=0.0, std=sigmaweights)
            # No matrices were supplied, so the diagnostic below inspects the
            # weights the layers actually hold (detached: this is not part of
            # any autograd graph).
            w_shared = self.shared_layer.weight.detach()
            w_task1 = self.task1_layer.weight.detach()
        else:
            with torch.no_grad():
                self.task1_layer.weight.copy_(W1_init)
                self.shared_layer.weight.copy_(WS_init)
            w_shared, w_task1 = WS_init, W1_init

        # Check if weights are balanced; both sides are [shared_dim, shared_dim].
        lhs = w_shared @ w_shared.t()
        rhs = w_task1.t() @ w_task1
        # The label names only the terms that are actually computed here
        # (there is no second task head in this network).
        print("||W_S W_S^T - W1^T W1|| =", torch.norm(lhs - rhs))

    def forward(self, x):
        """Return the task output for the input batch ``x``."""
        # Shared representation
        shared_out = self.shared_layer(x)

        # Task-specific output
        task1_out = self.task1_layer(shared_out)

        return task1_out


class DeepMultiTaskNetwork(nn.Module):
    """Deep linear two-task network: a stack of shared layers, then one stack per task.

    The shared stack has ``nsharedL`` linear layers (``input_dim -> shared_dim``
    first, then ``shared_dim -> shared_dim``).  Each task stack has ``ntaskL``
    linear layers (``shared_dim -> task_dim`` first, then
    ``task_dim -> task_dim``).  All biases are set to zero.

    When ``W1_init`` is ``None`` the task-stack weights are drawn from a normal
    distribution with mean 0 and standard deviation ``sigmaweights``; the
    shared-stack weights keep the default ``nn.Linear`` initialisation.

    Args:
        input_dim: Number of input features.
        shared_dim: Width of the shared representation.
        task1_dim: Output width of the first task stack.
        task2_dim: Output width of the second task stack.
        nsharedL: Number of shared layers.
        ntaskL: Number of layers in each task stack.
        sigmaweights: Standard deviation used for the random task weights.
        W1_init: Only tested against ``None``; when it is given, the task
            weights keep the default ``nn.Linear`` initialisation.
        W2_init: Accepted for signature parity with the other networks; unused.
        WS_init: Accepted for signature parity with the other networks; unused.
    """

    def __init__(self, input_dim, shared_dim, task1_dim, task2_dim, nsharedL, ntaskL, sigmaweights, W1_init=None, W2_init=None, WS_init=None):
        super(DeepMultiTaskNetwork, self).__init__()

        # `seed` is not defined in this class; it has to exist at module level.
        torch.manual_seed(seed)
        self.nsharedL = nsharedL
        self.ntaskL = ntaskL

        self.sharedLayers = nn.ModuleList()
        in_features = input_dim
        for _ in range(nsharedL):
            layer = nn.Linear(in_features, shared_dim)
            nn.init.constant_(layer.bias, 0)
            self.sharedLayers.append(layer)
            in_features = shared_dim

        self.taskLayers = nn.ModuleList()
        for out_features in (task1_dim, task2_dim):
            head = nn.ModuleList()
            # NOTE(review): `taskiLayers` is a leftover attribute that ends up
            # aliasing the task-2 stack.  It is kept only because it is a
            # registered submodule, so dropping it would change the
            # state_dict keys; remove it once no saved state depends on it.
            self.taskiLayers = head
            in_features = shared_dim
            for _ in range(ntaskL):
                layer = nn.Linear(in_features, out_features)
                nn.init.constant_(layer.bias, 0)
                head.append(layer)
                in_features = out_features
            self.taskLayers.append(head)

        # Initialization
        if W1_init is None:
            for head in self.taskLayers:
                for layer in head:
                    # Each weight is deliberately sampled twice, as the code
                    # always did: dropping the redundant draw would shift the
                    # RNG stream and change every seeded initialisation.
                    torch.nn.init.normal_(layer.weight, mean=0.0, std=sigmaweights)
                    torch.nn.init.normal_(layer.weight, mean=0.0, std=sigmaweights)

    def forward(self, x):
        """Return ``(task1_out, task2_out)`` for the input batch ``x``.

        With ``nsharedL == 0`` the input is passed to the task stacks
        unchanged; with ``ntaskL == 0`` both outputs are the shared
        representation.
        """
        # Shared representation
        shared_out = x
        for layer in self.sharedLayers:
            shared_out = layer(shared_out)

        # Each task stack consumes the shared representation and then only
        # its own intermediate activations.
        task1_out = shared_out
        for layer in self.taskLayers[0]:
            task1_out = layer(task1_out)

        task2_out = shared_out
        for layer in self.taskLayers[1]:
            task2_out = layer(task2_out)

        return task1_out, task2_out
        
class NonLinearMultiTaskNetwork(nn.Module):
    """Two-task network: one shared linear layer feeding two linear heads with ReLU outputs.

    All biases are set to zero.  Without explicit initial weights, the two
    task-head weight matrices are drawn from a normal distribution with mean 0
    and standard deviation ``sigmaweights``, while the shared layer keeps the
    default ``nn.Linear`` weight initialisation.  With explicit initial
    weights, all three matrices are copied into the layers.

    At the end of construction the norm of
    ``W_S W_S^T - (W1^T W1 + W2^T W2)`` is printed as a balance diagnostic.

    Args:
        input_dim: Number of input features.
        shared_dim: Width of the shared representation.
        task1_dim: Output width of the first task head.
        task2_dim: Output width of the second task head.
        sigmaweights: Standard deviation used for the random head weights.
        W1_init: Optional weight copied into the ``[task1_dim, shared_dim]``
            task-1 head.
        W2_init: Optional weight copied into the ``[task2_dim, shared_dim]``
            task-2 head.
        WS_init: Optional weight copied into the ``[shared_dim, input_dim]``
            shared layer.

    Raises:
        ValueError: If only some of ``W1_init``, ``W2_init`` and ``WS_init``
            are supplied.
    """

    def __init__(self, input_dim, shared_dim, task1_dim, task2_dim, sigmaweights, W1_init=None, W2_init=None, WS_init=None):
        super(NonLinearMultiTaskNetwork, self).__init__()

        # The three matrices only make sense as a set: a partial set used to
        # fail later with an opaque error from copy_(None) or None.t().
        supplied = [w is not None for w in (W1_init, W2_init, WS_init)]
        if any(supplied) and not all(supplied):
            raise ValueError(
                "W1_init, W2_init and WS_init must be given together or not at all"
            )

        # `seed` is not defined in this class; it has to exist at module level.
        torch.manual_seed(seed)

        self.shared_layer = nn.Linear(input_dim, shared_dim)
        self.task1_layer = nn.Linear(shared_dim, task1_dim)
        self.task2_layer = nn.Linear(shared_dim, task2_dim)
        torch.nn.init.constant_(self.task1_layer.bias, 0)
        torch.nn.init.constant_(self.task2_layer.bias, 0)
        torch.nn.init.constant_(self.shared_layer.bias, 0)

        self.task1_act = nn.ReLU()
        self.task2_act = nn.ReLU()

        # Initialization
        if W1_init is None:
            torch.nn.init.normal_(self.task1_layer.weight, mean=0.0, std=sigmaweights)
            torch.nn.init.normal_(self.task2_layer.weight, mean=0.0, std=sigmaweights)
            # No matrices were supplied, so the diagnostic below inspects the
            # weights the layers actually hold (detached: this is not part of
            # any autograd graph).
            w_shared = self.shared_layer.weight.detach()
            w_task1 = self.task1_layer.weight.detach()
            w_task2 = self.task2_layer.weight.detach()
        else:
            with torch.no_grad():
                self.task1_layer.weight.copy_(W1_init)
                self.task2_layer.weight.copy_(W2_init)
                self.shared_layer.weight.copy_(WS_init)
            w_shared, w_task1, w_task2 = WS_init, W1_init, W2_init

        # Balance diagnostic; both sides have shape [shared_dim, shared_dim].
        lhs = w_shared @ w_shared.t()
        rhs = w_task1.t() @ w_task1 + w_task2.t() @ w_task2
        print("||W_S W_S^T - (W1^T W1 + W2^T W2)|| =", torch.norm(lhs - rhs))

    def forward(self, x):
        """Return ``(task1_out, task2_out)``, each passed through ReLU."""
        # Shared representation (no activation on the shared layer)
        shared_out = self.shared_layer(x)

        # Task-specific outputs
        task1_out = self.task1_act(self.task1_layer(shared_out))
        task2_out = self.task2_act(self.task2_layer(shared_out))

        return task1_out, task2_out