from enum import Enum, auto

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
import torch.nn.functional as F
from tqdm import tqdm

from gatv3.pyg import GATv1Layer, GATv2Layer


class GAT(torch.nn.Module):
    """Multi-layer graph network assembled from a configurable ``base_layer``.

    The stack always consists of an input layer, ``num_layers - 2`` hidden layers and an
    output layer, so at least two layers are built.  Optional ``LayerNorm`` modules and
    linear residual projections are created alongside the layers.

    Two families of forward / inference paths exist:

    * ``forward`` (``forward_saint`` or ``forward_neighbor_sampler``) and ``inference``:
      the non-linearity is applied first and the residual is only added on layers after
      the first and before the last.
    * ``exp_forward_neighbor_sampler`` and ``exp_inference``: the residual is added on
      every layer (through ``self.residuals`` when ``use_residual_linear`` is set) and the
      non-linearity is applied afterwards on all but the last layer.

    Args:
        base_layer: Layer class, instantiated as ``base_layer(in, out, **kwargs)``.
        in_channels: Size of the input features.
        hidden_channels: Width used for the ``LayerNorm`` and residual projections of the
            hidden layers; each hidden layer is built with
            ``hidden_channels // num_heads`` output channels.
        out_channels: Output size of the last layer.
        num_layers: Requested depth (see the note above about the minimum of two layers).
        num_heads: Passed as ``heads`` to GAT layers; forced to 1 for any other layer.
        dropout: Dropout probability applied after every non-final layer in the
            training-time forward passes.
        device: Device the inference methods move each batch to.
        saint: If truthy, ``forward`` dispatches to ``forward_saint``.
        use_layer_norm: Apply ``LayerNorm`` after every non-final layer.
        use_residual: Enable residual connections.
        use_residual_linear: Build linear residual projections (only together with
            ``use_residual``).
        convolve, lambda_policy, share_weights_value, share_weights_score, gcn_mode:
            Forwarded untouched to ``base_layer`` when it is ``GATv1Layer`` or
            ``GATv2Layer``; ignored otherwise.
    """

    def __init__(self, base_layer, in_channels, hidden_channels, out_channels, num_layers, num_heads,
                 dropout, device, saint, use_layer_norm, use_residual, use_residual_linear, convolve, lambda_policy,
                 share_weights_value, share_weights_score, gcn_mode):

        super().__init__()

        self.layers = torch.nn.ModuleList()
        layer_kwargs = {'bias': True}

        self.is_gat = base_layer in (GATv1Layer, GATv2Layer)
        if self.is_gat:
            layer_kwargs.update({
                'heads': num_heads,
                'convolve': convolve, 'lambda_policy': lambda_policy,
                'share_weights_value': share_weights_value, 'share_weights_score': share_weights_score,
                'gcn_mode': gcn_mode
            })
        else:
            # Other layer types receive no ``heads`` argument, so the hidden width is not split.
            num_heads = 1

        self.use_layer_norm = use_layer_norm
        self.use_residual = use_residual
        self.use_residual_linear = use_residual_linear
        self.dropout = dropout
        self.device = device
        self.saint = saint
        self.non_linearity = F.relu
        self.num_layers = num_layers

        channels_per_head = hidden_channels // num_heads
        build_residual_linear = use_residual_linear and use_residual

        # NOTE: modules are created in the same order as before so that seeded
        # initialisation and ``state_dict`` ordering stay unchanged.
        self.layers.append(base_layer(in_channels, channels_per_head, **layer_kwargs))

        self.layer_norms = torch.nn.ModuleList()
        self.residuals = torch.nn.ModuleList()

        if use_layer_norm:
            self.layer_norms.append(nn.LayerNorm(hidden_channels))
        if build_residual_linear:
            self.residuals.append(nn.Linear(in_channels, hidden_channels))

        for _ in range(num_layers - 2):
            self.layers.append(base_layer(hidden_channels, channels_per_head, **layer_kwargs))
            if use_layer_norm:
                self.layer_norms.append(nn.LayerNorm(hidden_channels))
            if build_residual_linear:
                self.residuals.append(nn.Linear(hidden_channels, hidden_channels))

        if self.is_gat:
            # The output layer uses a single head producing ``out_channels`` directly.
            layer_kwargs['heads'] = 1

        self.layers.append(base_layer(hidden_channels, out_channels, **layer_kwargs))
        if build_residual_linear:
            self.residuals.append(nn.Linear(hidden_channels, out_channels))

        # Scalar parameter added (multiplied by zero) to the input of every checkpointed layer.
        # NOTE(review): presumably present so that re-entrant checkpointing always sees an
        # input that requires grad -- confirm before removing.
        self.dummy = nn.Parameter(torch.zeros([]), requires_grad=True)

        print(f"learnable_params: {sum(p.numel() for p in self.parameters() if p.requires_grad)}")

    def reset_parameters(self):
        """Re-initialise all layers, layer norms and residual projections."""
        for layer in self.layers:
            layer.reset_parameters()
        for layer in self.layer_norms:
            layer.reset_parameters()
        for layer in self.residuals:
            layer.reset_parameters()

    @property
    def lmbda(self):
        """Per-layer ``lmbda`` values as floats (all ``0.`` for non-GAT layers)."""
        return [float(l.lmbda) for l in self.layers] if self.is_gat else [0.] * len(self.layers)

    @property
    def lmbda2(self):
        """Per-layer ``lmbda2`` values as floats (all ``1.`` for non-GAT layers)."""
        return [float(l.lmbda2) for l in self.layers] if self.is_gat else [1.] * len(self.layers)

    def _checkpointed(self, layer, *args):
        """Run ``layer(*args)`` under activation checkpointing.

        ``use_reentrant`` is passed explicitly, as PyTorch asks callers to do; ``True``
        is the historical default this code has always relied on.
        """
        return checkpoint(layer, *args, use_reentrant=True, preserve_rng_state=False)

    def forward_neighbor_sampler(self, x, adjs):
        """Training forward pass over a list of sampled bipartite hops.

        Args:
            x: Features of the source nodes of the first hop.
            adjs: Iterable of ``(edge_index, _, size)`` triples, one per layer; the first
                ``size[1]`` rows of ``x`` are treated as the target nodes of the hop.

        Returns:
            The output of the last layer.
        """
        last = self.num_layers - 1
        for i, (edge_index, _, size) in enumerate(adjs):
            x = x + self.dummy * 0.
            x_target = x[:size[1]]  # Target nodes are always placed first.
            new_x = self._checkpointed(self.layers[i], x, edge_index, size[1])

            if i != last:
                new_x = self.non_linearity(new_x)
            # Identity residual on strictly-inner layers only.
            if 0 < i < last and self.use_residual:
                x = new_x + x_target
            else:
                x = new_x
            if i < last:
                if self.use_layer_norm:
                    x = self.layer_norms[i](x)
                x = F.dropout(x, p=self.dropout, training=self.training)
        return x

    def exp_forward_neighbor_sampler(self, x, adjs):
        """Variant of ``forward_neighbor_sampler`` with a residual on every layer.

        The residual (through ``self.residuals[i]`` when ``use_residual_linear`` is set,
        identity otherwise) is added before the non-linearity; non-linearity, layer norm
        and dropout are skipped on the last layer.
        """
        last = self.num_layers - 1
        for i, (edge_index, _, size) in enumerate(adjs):
            x = x + self.dummy * 0.
            x_target = x[:size[1]]  # Target nodes are always placed first.
            new_x = self._checkpointed(self.layers[i], x, edge_index, size[1])

            if self.use_residual:
                if self.use_residual_linear:
                    x = new_x + self.residuals[i](x_target)
                else:
                    x = new_x + x_target
            else:
                x = new_x

            if i < last:
                x = self.non_linearity(x)
                if self.use_layer_norm:
                    x = self.layer_norms[i](x)
                x = F.dropout(x, p=self.dropout, training=self.training)
        return x

    def forward_saint(self, x, adj_t):
        """Training forward pass where every layer sees the same adjacency ``adj_t``.

        Every non-final layer applies the non-linearity, then (for layers after the first,
        when ``use_residual`` is set) adds the residual exactly once, then layer norm and
        dropout.  The final layer is applied without any post-processing.
        """
        for i, layer in enumerate(self.layers[:-1]):
            x = x + self.dummy * 0.
            new_x = self.non_linearity(self._checkpointed(layer, x, adj_t))
            if i > 0 and self.use_residual:
                # The residual is added once; a stray second ``x = new_x + x`` used to
                # follow this branch and counted ``new_x`` twice.
                skip = self.residuals[i](x) if self.use_residual_linear else x
                x = new_x + skip
            else:
                x = new_x
            if self.use_layer_norm:
                x = self.layer_norms[i](x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        return self._checkpointed(self.layers[-1], x, adj_t)

    def forward(self, x, adjs):
        """Dispatch to ``forward_saint`` or ``forward_neighbor_sampler`` based on ``self.saint``."""
        if self.saint:
            return self.forward_saint(x, adjs)
        return self.forward_neighbor_sampler(x, adjs)

    @torch.no_grad()
    def inference(self, x, subgraph_loader):
        """Layer-wise evaluation: run each layer over all batches before moving to the next.

        Args:
            x: Feature matrix indexed by the ``n_id`` values the loader yields.
            subgraph_loader: Iterable of ``(batch_size, n_id, adj)``; ``adj.to(device)``
                must unpack into ``(edge_index, _, size)``.

        Returns:
            Concatenation (on CPU) of the last layer's outputs over all batches.
        """
        last = len(self.layers) - 1
        with tqdm(total=x.size(0) * len(self.layers), leave=False, desc='Evaluating', disable=False) as pbar:
            for i, layer in enumerate(self.layers):
                xs = []
                for batch_size, n_id, adj in subgraph_loader:
                    edge_index, _, size = adj.to(self.device)
                    x_source = x[n_id].to(self.device)
                    new_x = layer(x_source, edge_index, size[1])

                    if i < last:
                        new_x = self.non_linearity(new_x)
                        if i > 0 and self.use_residual:
                            # Target nodes are always placed first.
                            new_x = new_x + x_source[:size[1]]
                        if self.use_layer_norm:
                            new_x = self.layer_norms[i](new_x)
                    xs.append(new_x.cpu())
                    pbar.update(batch_size)
                x = torch.cat(xs, dim=0)
        return x

    @torch.no_grad()
    def exp_inference(self, x, subgraph_loader):
        """Layer-wise evaluation counterpart of ``exp_forward_neighbor_sampler``.

        Takes the same arguments as ``inference``.  Runs without gradient tracking, like
        ``inference``.
        """
        last = self.num_layers - 1
        with tqdm(total=x.size(0) * len(self.layers), leave=False, desc='Evaluating', disable=False) as pbar:
            for i, layer in enumerate(self.layers):
                xs = []
                for batch_size, n_id, adj in subgraph_loader:
                    edge_index, _, size = adj.to(self.device)
                    x_source = x[n_id].to(self.device)
                    x_target = x_source[:size[1]]  # Target nodes are always placed first.
                    new_x = layer(x_source, edge_index, size[1])

                    if self.use_residual:
                        if self.use_residual_linear:
                            x_target = new_x + self.residuals[i](x_target)
                        else:
                            x_target = new_x + x_target
                    else:
                        x_target = new_x
                    if i < last:
                        x_target = self.non_linearity(x_target)
                        if self.use_layer_norm:
                            x_target = self.layer_norms[i](x_target)

                    xs.append(x_target.cpu())
                    pbar.update(batch_size)
                x = torch.cat(xs, dim=0)
        return x


class GAT_TYPE(Enum):
    """Selectable model variants and factory helpers for building them."""

    GAT = auto()
    GAT2 = auto()
    GCN = auto()

    @staticmethod
    def from_string(s):
        """Return the member whose name is exactly ``s``.

        Raises:
            ValueError: If ``s`` is not the name of a member.
        """
        try:
            return GAT_TYPE[s]
        except KeyError:
            valid = ", ".join(member.name for member in GAT_TYPE)
            # ``from None``: the KeyError adds nothing beyond the message below.
            raise ValueError(f"Unknown GAT_TYPE {s!r}; expected one of: {valid}") from None

    def __str__(self):
        """Return the bare member name (e.g. ``"GAT2"``)."""
        return self.name

    def get_model(self, in_channels, hidden_channels, out_channels, num_layers, num_heads, dropout, device, saint,
                  use_layer_norm, use_residual, use_residual_linear, convolve, lambda_policy, share_weights_value,
                  share_weights_score):
        """Build the ``GAT`` model corresponding to this variant.

        ``GAT`` and ``GAT2`` forward every argument unchanged with ``gcn_mode=False``.
        ``GCN`` ignores the caller's ``num_heads``, ``convolve``, ``lambda_policy`` and
        weight-sharing arguments: it uses one head, ``convolve=True``,
        ``lambda_policy=None``, both weight-sharing flags ``True`` and ``gcn_mode=True``.
        """
        is_gcn = self is GAT_TYPE.GCN
        if is_gcn:
            num_heads = 1
            convolve = True
            lambda_policy = None
            share_weights_value = True
            share_weights_score = True

        return GAT(self.get_base_layer(), in_channels, hidden_channels, out_channels, num_layers, num_heads,
                   dropout, device, saint, use_layer_norm, use_residual, use_residual_linear, convolve,
                   lambda_policy, share_weights_value, share_weights_score, is_gcn)

    def get_base_layer(self):
        """Return the layer class used by this variant.

        GCN is built from ``GATv2Layer`` (run in ``gcn_mode`` by ``get_model``) rather
        than from a dedicated ``GraphConv`` layer.
        """
        if self is GAT_TYPE.GAT:
            return GATv1Layer
        elif self is GAT_TYPE.GAT2:
            return GATv2Layer
        elif self is GAT_TYPE.GCN:
            return GATv2Layer
