import torch 
import torch.nn as nn 

import torch_geometric as pyg 
import tqdm 

class WarmupReward(nn.Module):
    """Estimate running statistics of the log-reward by rolling out a policy.

    ``forward`` repeatedly creates a fresh environment, drives it with
    ``forward_flow`` and folds the mean and variance of the resulting
    log-rewards into ``mu`` and ``std``.
    """

    def __init__(self, forward_flow, warmup_epochs):
        """Store the policy and the number of warm-up rollouts.

        Args:
            forward_flow: callable invoked as ``forward_flow(env, off_policy=False)``;
                the first element of its return value is used as the actions.
            warmup_epochs: number of rollouts performed by ``forward``.
        """
        super().__init__()
        self.forward_flow = forward_flow
        self.warmup_epochs = warmup_epochs

        # Running statistics; registered as parameters with gradients disabled.
        self.mu = nn.Parameter(torch.tensor(0.), requires_grad=False)
        self.std = nn.Parameter(torch.tensor(1.), requires_grad=False)

        # Weight given to the previous estimate when a new batch is folded in.
        self.alpha = .5

    def forward(self, create_env):
        """Run ``warmup_epochs`` rollouts and update ``mu`` / ``std`` in place.

        Args:
            create_env: zero-argument callable returning a fresh environment
                exposing ``stopped``, ``apply(actions)`` and ``log_reward()``.

        Returns:
            None. The results are stored in ``self.mu`` and ``self.std``.
        """
        # The warm-up only gathers statistics and nothing is back-propagated,
        # so gradient tracking is disabled to avoid building autograd graphs.
        with torch.no_grad():
            for _ in tqdm.tqdm(range(self.warmup_epochs)):
                env = create_env()
                # NOTE(review): the loop ends as soon as ANY environment has
                # stopped >= 1 -- confirm every environment in the batch
                # terminates on the same step.
                while (env.stopped < 1.).all():
                    actions = self.forward_flow(env, off_policy=False)[0]
                    env.apply(actions)

                log_rewards = env.log_reward()
                keep = self.alpha
                fresh = 1 - keep
                self.mu.data = keep * self.mu.data + fresh * log_rewards.mean()
                # Blend variances (not standard deviations), then take the root.
                # NOTE(review): var() is NaN for a single-element batch.
                self.std.data = (keep * self.std.data ** 2 + fresh * log_rewards.var()).sqrt()
                # alpha shrinks after every epoch, so later batches weigh more;
                # the decayed value is kept on the instance across calls.
                self.alpha /= 1.5

class ForwardFlow(nn.Module):
    """Forward policy: a GIN encoder followed by an MLP producing action logits.

    Node embeddings are summed per sample, mapped to one logit per action,
    masked with ``batch_state.mask`` and normalised with a softmax.
    """

    def __init__(self, hidden_dim, num_leaves, masked_value=-1e5):
        """Build the encoder and the action head.

        Args:
            hidden_dim: width of the GIN layers and of the MLP hidden layer.
            num_leaves: the GIN input has ``num_leaves + 2`` channels and the
                head emits ``num_leaves * (num_leaves - 1) // 2`` logits (the
                number of unordered pairs among ``num_leaves`` items).
            masked_value: logit assigned to actions whose mask entry is 0.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_leaves = num_leaves
        self.gcn = pyg.nn.GIN(in_channels=num_leaves + 2, hidden_channels=hidden_dim, num_layers=2)

        # Only one head is built: a GELU variant used to be constructed first
        # and then immediately overwritten by this LeakyReLU one.
        num_actions = (num_leaves * (num_leaves - 1)) // 2
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, num_actions),
        )

        self.masked_value = masked_value

    def forward(self, batch_state, off_policy=False, actions=None):
        """Sample (or score) one action per sample in the batch.

        Args:
            batch_state: object exposing ``expanded_data``, ``edge_list_t()``,
                ``batch_size``, ``mask`` and ``batch_ids``.
            off_policy: when true, actions are drawn from a distribution that
                depends only on the mask instead of from the policy.
            actions: optional pre-selected actions; when given nothing is
                sampled and these actions are scored and returned unchanged.

        Returns:
            Tuple ``(actions, log_probs)`` where ``log_probs`` holds the policy
            log-probability of each chosen action.
        """
        embeddings = self.gcn(batch_state.expanded_data, batch_state.edge_list_t())
        # Group node embeddings per sample and sum them into one vector each.
        pooled = embeddings.view(batch_state.batch_size, -1, self.hidden_dim).sum(dim=1)
        logits = self.mlp(pooled)

        mask = batch_state.mask
        logits = logits * mask + self.masked_value * (1 - mask)
        # log_softmax stays finite where log(softmax(.)) underflows to -inf.
        log_probs = torch.log_softmax(logits, dim=-1)

        if actions is not None:
            return actions, log_probs[batch_state.batch_ids, actions]

        if off_policy:
            # Same logit for every allowed action, masked_value for the rest.
            uniform = torch.where(mask == 1., 1., self.masked_value)
            sampling_probs = torch.softmax(uniform, dim=-1)
        else:
            sampling_probs = torch.softmax(logits, dim=-1)

        actions = torch.multinomial(sampling_probs, num_samples=1, replacement=True)
        # Drop only the sample axis so a batch of one keeps its batch dimension.
        actions = actions.squeeze(-1)
        return actions, log_probs[batch_state.batch_ids, actions]

class BackwardFlow(nn.Module):
    """Parameter-free backward flow derived from the parent count of each state."""

    def forward(self, batch_state, actions):
        """Return ``-log(batch_state.num_parents)``, i.e. ``log(1 / num_parents)``.

        Args:
            batch_state: object exposing a ``num_parents`` tensor.
            actions: accepted for interface compatibility; not used.
        """
        log_parent_count = torch.log(batch_state.num_parents)
        return log_parent_count.neg()

class StateFlow(nn.Module):
    """Scalar flow estimator: a GIN encoder followed by an MLP with one output."""

    def __init__(self, num_leaves, hidden_dim):
        """Build the encoder and the scalar head.

        Args:
            num_leaves: the GIN input has ``num_leaves + 2`` channels.
            hidden_dim: width of the GIN layers and of the MLP hidden layer.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.gcn = pyg.nn.GIN(in_channels=num_leaves + 2, hidden_channels=hidden_dim, num_layers=2)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, batch_state):
        """Return one scalar per sample, shaped ``(batch_state.batch_size,)``.

        Args:
            batch_state: object exposing ``expanded_data``, ``edge_list_t()``
                and ``batch_size``.
        """
        embeddings = self.gcn(batch_state.expanded_data, batch_state.edge_list_t())
        # Group node embeddings per sample and sum them into one vector each.
        pooled = embeddings.view(batch_state.batch_size, -1, self.hidden_dim).sum(dim=1)
        flow = self.mlp(pooled)
        # Drop only the trailing size-1 axis so a batch of one stays 1-D.
        return flow.squeeze(-1)