import torch
import torch.nn.functional as F
from src.utils import graph
import logging
log = logging.getLogger(__name__)

# Module-wide default device: the transition matrices built in this module are placed here.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

class DiffusionTransition:
    """Base class holding the per-timestep transition matrices of a discrete graph diffusion.

    Subclasses must assign the following tensors *before* calling ``super().__init__``:

    * ``Qts_x``, ``Qts_e``, ``Qts_y``: matrices returned by ``get_Qt``, indexed by timestep.
    * ``Qt_bars_x``, ``Qt_bars_e``, ``Qt_bars_y``: matrices returned by ``get_Qt_bar``, indexed
      by timestep (presumably the cumulative transitions up to t).
    * ``x_limit``, ``e_limit``, ``y_limit``: distributions returned by ``get_limit_dist``.

    The constructor isolates the masked node/edge classes (no transitions into or out of them),
    renormalises the rows of the node/edge matrices, removes masked classes from the node/edge
    limit distributions, and moves every tensor to the module-level ``device``. The ``y``
    matrices and ``y_limit`` are left untouched apart from the device move.
    """

    def __init__(self, x_classes: int, e_classes: int, y_classes: int, diffuse_edges: bool=True,
                 node_idx_to_mask: torch.Tensor=torch.ones((1,), dtype=torch.long), edge_idx_to_mask=None):
        """
        Args:
            x_classes: number of node classes.
            e_classes: number of edge classes (size of the identity used when edges are not diffused).
            y_classes: number of y classes.
            diffuse_edges: if False, ``get_Qt`` / ``get_Qt_bar`` return identity matrices for edges.
            node_idx_to_mask: node class index/indices that can neither be left nor reached
                (default: class 1). ``None`` disables node masking.
            edge_idx_to_mask: edge class index/indices masked the same way. ``None`` (default)
                disables edge masking.
        """
        self.x_classes = x_classes
        self.e_classes = e_classes
        self.y_classes = y_classes
        self.diffuse_edges = diffuse_edges

        # Make sure masked node/edge types don't mess with the diffusion, i.e. that we cannot
        # transition from or to them; the rows are then reweighted to sum to 1.
        self.Qts_x = self._mask_and_renormalize(self.Qts_x, node_idx_to_mask)
        self.Qts_e = self._mask_and_renormalize(self.Qts_e, edge_idx_to_mask)
        self.Qt_bars_x = self._mask_and_renormalize(self.Qt_bars_x, node_idx_to_mask)
        self.Qt_bars_e = self._mask_and_renormalize(self.Qt_bars_e, edge_idx_to_mask)

        # Make the limiting distributions put no probability mass on the masked classes.
        self.x_limit = self._mask_limit(self.x_limit, node_idx_to_mask)
        self.e_limit = self._mask_limit(self.e_limit, edge_idx_to_mask)

        # Move everything to GPU if we have one
        self.to_device(device)

    @staticmethod
    def _mask_and_renormalize(Q, idx_to_mask):
        """Isolate classes ``idx_to_mask`` in ``Q`` (in place) and rescale every row to sum to 1."""
        # `None` means "nothing to mask": indexing with None would add an axis and hit every entry.
        if idx_to_mask is not None:
            Q[..., :, idx_to_mask] = 0.  # nothing transitions *to* a masked class ...
            Q[..., idx_to_mask, :] = 0.  # ... a masked class transitions nowhere ...
            Q[..., idx_to_mask, idx_to_mask] = 1.  # ... except to itself
        # NOTE: a row whose whole mass sat on masked classes becomes all zeros and turns into NaN here.
        Q /= Q.sum(dim=-1).unsqueeze(dim=-1)
        return Q

    @staticmethod
    def _mask_limit(limit, idx_to_mask):
        """Zero the masked classes of a limit distribution (in place) and return it renormalised."""
        # Skip None explicitly: `limit[None] = 0.` would zero the whole tensor and make the
        # normalisation below divide 0 by 0 (all-NaN distribution).
        if idx_to_mask is not None:
            limit[idx_to_mask] = 0.
        return limit / limit.sum()

    def to_device(self, device):
        """Move all transition matrices and limit distributions to ``device``."""
        self.Qts_x = self.Qts_x.to(device)
        self.Qts_e = self.Qts_e.to(device)
        self.Qts_y = self.Qts_y.to(device)
        self.Qt_bars_x = self.Qt_bars_x.to(device)
        self.Qt_bars_e = self.Qt_bars_e.to(device)
        self.Qt_bars_y = self.Qt_bars_y.to(device)
        self.x_limit = self.x_limit.to(device)
        self.e_limit = self.e_limit.to(device)
        self.y_limit = self.y_limit.to(device)

    def _select(self, t, device, Q_x, Q_e, Q_y):
        """Gather the matrices for the timesteps in ``t[:, 0]`` into a ``graph.PlaceHolder``.

        ``t`` is indexed as ``t[:, 0]`` and cast with ``.long()``; one matrix per row of ``t``.
        """
        idx = t[:, 0].long()
        if self.diffuse_edges:
            E = Q_e[idx]
        else:
            # Identity transform for edges: one eye(e_classes) per row of t.
            E = torch.eye(self.e_classes, device=device)[None].repeat(t.shape[0], 1, 1)
        return graph.PlaceHolder(X=Q_x[idx], E=E, y=Q_y[idx])

    def get_Qt(self, t, device):
        """Return the ``Qts_*`` matrices at the timesteps ``t[:, 0]``."""
        return self._select(t, device, self.Qts_x, self.Qts_e, self.Qts_y)

    def get_Qt_bar(self, t, device):
        """Return the ``Qt_bars_*`` matrices at the timesteps ``t[:, 0]``."""
        return self._select(t, device, self.Qt_bars_x, self.Qt_bars_e, self.Qt_bars_y)

    def get_limit_dist(self):
        """Return the node/edge/y limit distributions wrapped in a ``graph.PlaceHolder``."""
        limit_dist = graph.PlaceHolder(X=self.x_limit, E=self.e_limit, y=self.y_limit)
        return limit_dist

class AbsorbingStateTransitionMaskNoEdge(DiffusionTransition):
    """Absorbing-state transitions: at every step each class may jump to a fixed absorbing class.

    Nodes are absorbed into class ``abs_state_position_x`` and edges into ``abs_state_position_e``.
    ``y`` gets one extra class appended at index ``y_classes``, so its matrices are
    ``(y_classes + 1) x (y_classes + 1)``; that extra class is its absorbing state.
    Masking, renormalisation and device placement are then done by ``DiffusionTransition.__init__``.
    """

    def __init__(self, timesteps: int, x_classes: int, e_classes: int, y_classes: int, diffuse_edges=True,
                 abs_state_position_x=-1, abs_state_position_e=0, node_idx_to_mask: torch.Tensor=3*torch.ones((1,),
                 dtype=torch.long), edge_idx_to_mask=None):
        """
        Args:
            timesteps: number of diffusion steps T; matrices are built for t = 0..T.
            x_classes, e_classes, y_classes: number of node / edge / y classes.
            diffuse_edges: forwarded to ``DiffusionTransition``.
            abs_state_position_x, abs_state_position_e: index of the absorbing node / edge class.
                Negative values count from the end like Python indexing (default -1 = last node class).
            node_idx_to_mask, edge_idx_to_mask: forwarded to ``DiffusionTransition``.
        """
        # F.one_hot rejects negative class values, so resolve Python-style positions first.
        abs_x = self._resolve_position(abs_state_position_x, x_classes)
        abs_e = self._resolve_position(abs_state_position_e, e_classes)

        self.Qts_x, self.Qt_bars_x = self._absorbing_schedule(x_classes, abs_x, timesteps, device)
        self.Qts_e, self.Qt_bars_e = self._absorbing_schedule(e_classes, abs_e, timesteps, device)
        # y gets an extra class (index y_classes) acting as its absorbing state.
        self.Qts_y, self.Qt_bars_y = self._absorbing_schedule(y_classes + 1, y_classes, timesteps, device)

        # Limit distributions put all mass on the absorbing class, matching Qt_bar at t = T
        # (where t/T = 1 sends every class to the absorbing one). [0] removes the batch dimension.
        self.x_limit = F.one_hot(torch.tensor([abs_x], dtype=torch.long), x_classes)[0].float()
        self.e_limit = F.one_hot(torch.tensor([abs_e], dtype=torch.long), e_classes)[0].float()
        self.y_limit = F.one_hot(torch.tensor([y_classes], dtype=torch.long), y_classes+1)[0].float()

        super().__init__(x_classes, e_classes, y_classes, diffuse_edges, node_idx_to_mask, edge_idx_to_mask)

    @staticmethod
    def _resolve_position(position, num_classes):
        """Translate a Python-style (possibly negative) class position into a non-negative index."""
        return position + num_classes if position < 0 else position

    @staticmethod
    def _absorbing_schedule(num_classes, abs_idx, T, dev):
        """Build stacked (Qts, Qt_bars), each of shape (T + 1, num_classes, num_classes), absorbing into abs_idx."""
        # Every row is the one-hot vector of the absorbing class.
        to_absorbing = F.one_hot(torch.tensor([abs_idx], dtype=torch.long, device=dev), num_classes).repeat(num_classes, 1)
        eye = torch.eye(num_classes, device=dev, dtype=torch.float)
        # t = 0 is the identity for both the one-step and the cumulative matrices.
        Qts = [torch.eye(num_classes, device=dev)]
        Qt_bars = [torch.eye(num_classes, device=dev)]
        for t in range(1, T + 1):
            # One step: jump to the absorbing class w.p. 1/(T - t + 1); the last step (t = T) absorbs surely.
            beta = 1 / (T - t + 1)
            Qts.append(beta * to_absorbing + (1 - beta) * eye)
            # Cumulative: absorbed by step t w.p. t/T = 1 - prod_{s<=t} (1 - 1/(T - s + 1)).
            Qt_bars.append(t / T * to_absorbing + (1 - t / T) * eye)
        return torch.stack(Qts), torch.stack(Qt_bars)