import math
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch_geometric.nn.dense.linear import Linear

from .generalgat import GeneralGATLayer


class GATAnsatzLayer(GeneralGATLayer):
    """Single-head GAT-style layer whose attention score is a small two-layer MLP.

    The raw score of an edge is ``att_out(leaky_relu(att_in([x_i || x_j])))``
    (see :meth:`compute_score`).

    When ``use_ansatz`` is set, :meth:`fix_parameters` overwrites the weights
    with a hand-crafted configuration derived from ``mu_norm``, ``d``, ``p``
    and ``q`` before every forward pass. Per the constructor comments this is
    only used in the synthetic experiments.
    """

    def __init__(
            self,
            in_channels: Union[int, Tuple[int, int]],
            out_channels: int,
            negative_slope: float = 0.2,
            add_self_loops: bool = True,
            heads: int = 1,
            bias: bool = True,
            convolve: bool = True,
            lambda_policy: Optional[str] = None,  # [None, 'learn1', 'learn2', 'learn12', 'gcn_gat', 'individual']
            gcn_mode: bool = False,
            share_weights_score: bool = True,
            share_weights_value: bool = True,
            version: Optional[str] = None,
            mu_norm: float = 0.0,
            d: int = 0,
            p: float = 0.0,
            q: float = 0.0,
            use_ansatz: bool = False,  # Only used in the synthetic experiments
            use_partial_ansatz: bool = False,  # Only used in the synthetic experiments
            **kwargs,
    ):
        """Build the layer.

        Args:
            in_channels: input feature size, forwarded to GeneralGATLayer and
                used for ``lin_l``/``lin_r``/``lin_v``.
            out_channels: output feature size; must be 1.
            negative_slope, add_self_loops, heads, bias, convolve,
            lambda_policy, gcn_mode, **kwargs: forwarded positionally to
                GeneralGATLayer. ``heads`` must end up as 1. ``convolve`` and
                ``gcn_mode`` also select the 'v2' ansatz bias values.
            share_weights_score: if True, ``lin_r`` is the same module as
                ``lin_l``.
            share_weights_value: if True, ``lin_v`` reuses ``lin_l`` when
                ``self.flow == 'source_to_target'``, else ``lin_r``.
            version: 'v1' gives a 4-unit hidden score layer without bias; any
                other value gives 8 units with bias. Only 'v1' and 'v2' have an
                ansatz in :meth:`fix_parameters`.
            mu_norm, d, p, q: ansatz parameters. All must be > 0; the defaults
                fail these checks, so callers must always pass them.
            use_ansatz: re-apply the ansatz at the start of every forward.
            use_partial_ansatz: with ``use_ansatz``, initialise everything to
                the ansatz at reset, then keep only the weights fixed (frozen)
                while biases stay trainable.
        """
        super().__init__(in_channels, out_channels, negative_slope, add_self_loops,
                         heads, bias, convolve, lambda_policy, gcn_mode,
                         **kwargs)
        assert self.heads == 1, f'Only one head implemented for {str(self)}'
        assert out_channels == 1
        assert d > 0
        assert mu_norm > 0.0
        assert p > 0.0
        assert q > 0.0

        self.mu_norm = mu_norm
        self.d = d
        self.q = q
        self.p = p
        self.use_ansatz = use_ansatz
        self.use_partial_ansatz = use_partial_ansatz

        self.version = version
        # The 'v1' ansatz uses 4 hidden units and no hidden bias; other versions
        # use 8 hidden units with a bias.
        mid_channels = 4 if version == 'v1' else 8
        bias_in = version != 'v1'
        self.att_in = Linear(out_channels * 2, mid_channels, bias=bias_in, weight_initializer='glorot')
        self.att_out = Linear(mid_channels, 1, bias=False, weight_initializer='glorot')

        self.lin_l = Linear(in_channels, out_channels, bias=False, weight_initializer='glorot')
        if share_weights_score:
            self.lin_r = self.lin_l
        else:
            self.lin_r = Linear(in_channels, out_channels, bias=False, weight_initializer='glorot')

        if share_weights_value:
            self.lin_v = self.lin_l if self.flow == 'source_to_target' else self.lin_r
        else:
            self.lin_v = Linear(in_channels, out_channels, bias=False, weight_initializer='glorot')

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise all sub-modules.

        For the partial ansatz, this also writes the full ansatz, biases
        included, as the starting point.
        """
        super().reset_parameters()
        self.att_in.reset_parameters()
        self.att_out.reset_parameters()
        self.lin_l.reset_parameters()
        self.lin_r.reset_parameters()
        self.lin_v.reset_parameters()

        if self.use_ansatz and self.use_partial_ansatz:
            # partial=False here so the biases also start at their ansatz values;
            # forward() later calls fix_parameters(True), which leaves them trainable.
            self.fix_parameters(False)

    def fix_parameters(self, partial=False):
        """Overwrite the layer's weights with the hand-crafted ansatz values.

        Args:
            partial: if True, first mark the weights of att_in, att_out, lin_l,
                lin_r and lin_v as not requiring gradients, and do not overwrite
                any bias. If False, the 'v2' ansatz also overwrites ``self.bias``
                and ``att_in.bias``.

        Only ``version`` 'v1' and 'v2' define an ansatz. For any other version,
        only the freezing step (when ``partial``) takes effect.

        Values are written in place under ``torch.no_grad()`` with
        ``fill_``/``copy_``, so every parameter keeps its device, dtype and
        shape. Assigning freshly built (CPU, float32) tensors to ``.data``
        would silently move parameters off an accelerator. For ``att_in.bias``
        it would also change the shape.
        """
        if partial:
            for lin in (self.att_in, self.att_out, self.lin_l, self.lin_r, self.lin_v):
                lin.weight.requires_grad = False

        if self.version not in ('v1', 'v2'):
            return

        mu_norm = self.mu_norm
        mu = mu_norm / math.sqrt(self.d)
        # Every entry of the feature projections is mu / mu_norm (= 1 / sqrt(d)).
        w = mu / mu_norm

        with torch.no_grad():
            # lin_r / lin_v may alias lin_l; filling the same tensor twice is harmless.
            for lin in (self.lin_l, self.lin_r, self.lin_v):
                lin.weight.fill_(w)

            if self.version == 'v1':
                S = torch.tensor([[1., 1.],
                                  [-1., -1.],
                                  [1., -1.],
                                  [-1., 1.]])
                R = torch.tensor([[1., 1., -1., -1.]])
                self.att_in.weight.copy_(S)
                self.att_out.weight.copy_(R)
                return

            # version == 'v2'
            p = self.p
            q = self.q
            # GCAT (convolve=True) scales the thresholds by (p - q) / (p + 2q);
            # plain GAT uses 1.
            b_cte = (p - q) / (p + 2 * q) if self.convolve else 1.0

            if not partial:
                if self.gcn_mode:
                    self.bias.zero_()
                else:
                    self.bias.fill_(-mu_norm * b_cte / 2.)

            S = torch.tensor([[1., 1.],
                              [-1., -1.],
                              [1., -1.],
                              [-1., 1.],
                              [0., 1.],
                              [1., 0.],
                              [0., -1.],
                              [-1., 0.]])
            self.att_in.weight.copy_(S)

            if not partial:
                # Hidden biases: -3/2 for the first four units, -1/2 for the last
                # four. Built 1-D to match the (mid_channels,) bias parameter.
                b = torch.tensor([-3. / 2.] * 4 + [-1. / 2.] * 4) * mu_norm * b_cte
                self.att_in.bias.copy_(b)

            cte = 7 / mu_norm
            R = cte * torch.tensor([[2., -2., -2., 2., -1., -1., -1., -1.]])
            self.att_out.weight.copy_(R)

    def get_x_r(self, x):
        """Project node features with ``lin_r``."""
        return self.lin_r(x)

    def get_x_l(self, x):
        """Project node features with ``lin_l``."""
        return self.lin_l(x)

    def get_x_v(self, x):
        """Project node features with ``lin_v`` (the value projection)."""
        return self.lin_v(x)

    def compute_score(self, x_i, x_j, index, ptr, size_i):
        """Return unnormalised attention scores for each edge.

        Concatenates the endpoint features ``x_i`` and ``x_j`` along the last
        dim and squeezes dim 1. That dim is assumed to be the single head, with
        inputs of shape (E, 1, C); TODO confirm. It then applies
        ``att_out(leaky_relu(att_in(.)))`` and restores dim 1 on the result.

        ``index``, ``ptr`` and ``size_i`` are unused here; presumably they are
        part of the hook signature expected by GeneralGATLayer.
        """
        pair = torch.cat([x_i, x_j], dim=-1).squeeze(1)
        hidden = F.leaky_relu(self.att_in(pair), self.negative_slope)
        return self.att_out(hidden).unsqueeze(1)

    def forward(self, x, edge_index, size_target: Optional[int] = None, return_attention_info: bool = False):
        """Optionally re-apply the ansatz, then delegate to GeneralGATLayer.forward."""
        if self.use_ansatz:
            self.fix_parameters(self.use_partial_ansatz)

        return super().forward(x, edge_index, size_target, return_attention_info)

