import torch
import torch.nn as nn

from .parallelism_prior import ParallelCommCostMean, MaxMemoryMean


# Value that the models' ``forward`` outputs are pulled toward as their
# sigmoid memory gate saturates towards 1. Presumably this is the label used
# for configurations that run out of memory (OOM); confirm against how the
# training targets are built.
OOM_DEFAULT_VALUE = -1.


class Network(nn.Module):
    """Fully connected MLP that maps inputs of width ``dim`` to one scalar each.

    Layout:
    - ``Linear(dim, hidden_width)``,
    - then ``hidden_depth - 1`` extra ``Linear(hidden_width, hidden_width)`` layers,
    - then a ``Linear(hidden_width, 1)`` readout.

    Every layer except the readout is followed by the activation. The
    trailing singleton axis of the readout is squeezed away.

    Args:
        dim: number of input features (size of the last input axis).
        hidden_width: width of every hidden layer.
        hidden_depth: number of activated hidden layers, counting the first
            projection. Values below 1 behave like 1 (no extra layers).
        activation: module applied after each hidden layer. Defaults to a
            fresh ``nn.ReLU()`` when ``None``.
    """

    def __init__(self, dim, hidden_width=100, hidden_depth=1, activation=None):
        super().__init__()
        if activation is None:
            activation = nn.ReLU()
        # Registration order (activation, first, hidden, last) is kept fixed so
        # parameter initialisation and state_dict key order are unchanged.
        self.activate = activation
        self.fc_first = nn.Linear(dim, hidden_width)
        extra_layers = hidden_depth - 1
        self.fc_hidden = nn.ModuleList(
            nn.Linear(hidden_width, hidden_width) for _ in range(extra_layers)
        )
        self.fc_last = nn.Linear(hidden_width, 1)

    def forward(self, x):
        """Return the network output with the last (singleton) axis removed."""
        # Rescale inputs from the unit hypercube [0, 1]^d to [-1, 1]^d so
        # they are centred on the origin.
        h = 2. * (x - 0.5)
        for layer in (self.fc_first, *self.fc_hidden):
            h = self.activate(layer(h))
        return self.fc_last(h).squeeze(-1)
    
    
class ParallelismInformedLearner(nn.Module):
    """Prior-only predictor that gates a time prior with a memory prior.

    The time and memory predictions come directly from
    ``ParallelCommCostMean`` and ``MaxMemoryMean``. The only trainable
    parameter is the scalar ``c``; its exponential scales the memory
    prediction before the sigmoid gate in ``forward``.

    Args:
        dim: not used by this class. It is kept in the signature, which
            mirrors the leading parameters of ``ParallelismInformedNetwork``.
        x_upper_bound: passed to both priors.
        max_mem_GB: passed to ``MaxMemoryMean``.
        consider_comm: passed to ``ParallelCommCostMean``.
    """

    def __init__(
        self,
        dim,
        x_upper_bound,
        max_mem_GB,
        consider_comm=True,
    ):
        super().__init__()
        self.time_prior = ParallelCommCostMean(
            x_upper_bound=x_upper_bound, consider_comm=consider_comm
        )
        self.mem_prior = MaxMemoryMean(
            x_upper_bound=x_upper_bound, max_mem_GB=max_mem_GB
        )
        # Log of the gate sharpness; forward uses exp(c), so the scale stays positive.
        self.c = torch.nn.Parameter(torch.tensor(0.))

    def pred_time(self, x):
        """Return the time prior's prediction for ``x``."""
        time_mean = self.time_prior.forward(x)
        return time_mean

    def pred_mem(self, x):
        """Return the memory prior's prediction for ``x``; used as the gate logit."""
        mem_mean = self.mem_prior.forward(x)
        return mem_mean

    def forward(self, x):
        """Blend the time prediction with ``OOM_DEFAULT_VALUE`` through a memory gate.

        The gate is ``sigmoid(exp(c) * pred_mem(x))``. It lies in (0, 1) and
        tends to 1 as the memory prediction becomes large and positive.

        Returns:
            ``gate * OOM_DEFAULT_VALUE + (1 - gate) * (pred_time(x) - OOM_DEFAULT_VALUE)``.
        """
        gate = torch.sigmoid(torch.exp(self.c) * self.pred_mem(x))
        # NOTE(review): as gate -> 0 the output tends to
        # pred_time(x) - OOM_DEFAULT_VALUE (pred_time + 1 while the constant is
        # -1), not pred_time(x). Confirm this offset is intended and not a mix-up
        # with gate * OOM_DEFAULT_VALUE + (1 - gate) * pred_time(x).
        shifted_time = self.pred_time(x) - OOM_DEFAULT_VALUE
        return gate * OOM_DEFAULT_VALUE + (1. - gate) * shifted_time
    

class ParallelismInformedNetwork(nn.Module):
    """Prior-plus-residual predictor for run time and memory.

    Each prediction is an analytic prior plus ``factor_scale`` times the
    output of its own ``Network`` MLP. The time prediction is clipped at 0
    with ``relu``. The memory prediction then gates the output toward
    ``OOM_DEFAULT_VALUE`` in ``forward``.

    Args:
        dim: input dimensionality of both MLPs.
        x_upper_bound: passed to both priors.
        max_mem_GB: passed to ``MaxMemoryMean``.
        hidden_width: passed to both ``Network`` MLPs.
        hidden_depth: passed to both ``Network`` MLPs.
        activation: passed to both ``Network`` MLPs. A module instance
            given here is shared by the two MLPs.
        consider_comm: passed to ``ParallelCommCostMean``.
        factor_scale: multiplier applied to each MLP output before it is
            added to its prior.
    """

    def __init__(
        self,
        dim,
        x_upper_bound,
        max_mem_GB,
        hidden_width=128,
        hidden_depth=1,
        activation=None,
        consider_comm=True,
        factor_scale=0.1,
    ):
        super().__init__()
        mlp_kwargs = dict(
            dim=dim,
            hidden_width=hidden_width,
            hidden_depth=hidden_depth,
            activation=activation,
        )
        # Keep the construction order (time MLP, time prior, memory MLP,
        # memory prior) so initialisation RNG draws and the state_dict layout
        # stay the same.
        self.time_nn = Network(**mlp_kwargs)
        self.time_prior = ParallelCommCostMean(
            x_upper_bound=x_upper_bound, consider_comm=consider_comm
        )
        self.mem_nn = Network(**mlp_kwargs)
        self.mem_prior = MaxMemoryMean(
            x_upper_bound=x_upper_bound, max_mem_GB=max_mem_GB
        )
        self.factor_scale = factor_scale
        # Log of the gate sharpness; forward uses exp(c), so the scale stays positive.
        self.c = torch.nn.Parameter(torch.tensor(0.))

    def _with_residual(self, prior, mlp, x):
        """Return ``prior.forward(x) + factor_scale * mlp(x)``, with the MLP output reshaped to the prior's shape."""
        base = prior.forward(x)
        correction = mlp(x).reshape(base.shape)
        return base + self.factor_scale * correction

    def pred_time(self, x):
        """Return the non-negative time prediction (prior plus scaled MLP residual)."""
        return torch.relu(self._with_residual(self.time_prior, self.time_nn, x))

    def pred_mem(self, x):
        """Return the memory prediction (prior plus scaled MLP residual); used as the gate logit."""
        return self._with_residual(self.mem_prior, self.mem_nn, x)

    def forward(self, x):
        """Blend the time prediction with ``OOM_DEFAULT_VALUE`` through a memory gate.

        The gate is ``sigmoid(exp(c) * pred_mem(x))``. It lies in (0, 1) and
        tends to 1 as the memory prediction becomes large and positive.

        Returns:
            ``gate * OOM_DEFAULT_VALUE + (1 - gate) * (pred_time(x) - OOM_DEFAULT_VALUE)``.
        """
        gate = torch.sigmoid(torch.exp(self.c) * self.pred_mem(x))
        # NOTE(review): as gate -> 0 the output tends to
        # pred_time(x) - OOM_DEFAULT_VALUE (pred_time + 1 while the constant is
        # -1), not pred_time(x). Confirm this offset is intended and not a mix-up
        # with gate * OOM_DEFAULT_VALUE + (1 - gate) * pred_time(x).
        shifted_time = self.pred_time(x) - OOM_DEFAULT_VALUE
        return gate * OOM_DEFAULT_VALUE + (1. - gate) * shifted_time
