import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class Conv_DyN(nn.Module):
    """2-D convolution whose kernel is generated from learnable point sets.

    The layer does not store a kernel.  It stores, for each of ``kernel_size``
    positions and each of ``num_CHs`` components, one ``q_dim``-dimensional
    point per input channel and one per output channel.  The kernel entry
    ``W[o, c, j, i]`` is::

        sum_h coeff[h] * dist_p(Conv_outputQs[i, h, o], Conv_inputQs[j, h, c])

    where ``coeff[h] = SCALE_FACTOR_conv * (-1) ** h`` is a fixed
    (non-trainable) alternating-sign coefficient and ``dist_p`` is the
    ``norm_p`` distance computed by ``torch.cdist``.

    Args:
        kernel_size: Side length of the square kernel.
        input_channels: Number of input channels of the convolution.
        output_channels: Number of output channels of the convolution.
        stride: Forwarded unchanged to ``F.conv2d``.
        padding: Forwarded unchanged to ``F.conv2d``.
        num_CHs: Number of signed components summed into each kernel entry.
        q_dim: Dimensionality of every learnable point.
        norm_p: ``p`` of the p-norm used by ``torch.cdist``.
        SCALE_FACTOR_conv: Magnitude of the alternating-sign coefficients.
    """

    def __init__(
        self,
        kernel_size,
        input_channels,
        output_channels,
        stride,
        padding,
        num_CHs,
        q_dim,
        norm_p,
        SCALE_FACTOR_conv=0.01
    ):
        super(Conv_DyN, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.norm_p = norm_p
        self.num_CHs = num_CHs

        # Learnable points, uniformly initialised in [0, 1).
        self.Conv_inputQs = torch.nn.Parameter(torch.rand(kernel_size, num_CHs, input_channels, q_dim))
        self.Conv_outputQs = torch.nn.Parameter(torch.rand(kernel_size, num_CHs, output_channels, q_dim))

        # Fixed coefficients +s, -s, +s, ... shaped (num_CHs, 1, 1) so they
        # broadcast over a stack of per-component distance matrices.  Stored
        # as a frozen Parameter so it follows .to()/.double() and state_dict.
        self.shared_coeff_conv = torch.nn.Parameter(
            SCALE_FACTOR_conv * torch.tensor([(-1) ** h_id for h_id in range(self.num_CHs)]).unsqueeze(1).unsqueeze(2),
            requires_grad=False,
        )

    def pathIntegrals(self, _inputQs, _outputQs, _coeffs):
        """Return the coefficient-weighted sum of pairwise distance matrices.

        Args:
            _inputQs: Batched point sets of shape ``(H, P, D)``.
            _outputQs: Batched point sets of shape ``(H, R, D)``.
            _coeffs: Per-component weights broadcastable to ``(H, P, R)``.

        Returns:
            Tensor of shape ``(P, R)``: the ``norm_p`` distance matrices
            multiplied by ``_coeffs`` and summed over the leading dimension.
        """
        dist_W = torch.cdist(_inputQs, _outputQs, p=self.norm_p)
        return torch.sum(dist_W * _coeffs, 0)

    def forward_(self):
        """Build the convolution kernel from the current points.

        Returns:
            Tensor of shape ``(output_channels, input_channels, kernel_size,
            kernel_size)`` with the same dtype and device as the parameters.
        """
        input_ch = self.Conv_inputQs.shape[2]
        output_ch = self.Conv_outputQs.shape[2]
        kernel_size = self.Conv_inputQs.shape[0]

        # Allocate with the parameters' dtype as well as their device;
        # otherwise a module converted with .double()/.half() would produce a
        # float32 kernel that F.conv2d rejects against the input dtype.
        res_conv = torch.zeros(
            output_ch,
            input_ch,
            kernel_size,
            kernel_size,
            dtype=self.Conv_outputQs.dtype,
            device=self.Conv_outputQs.device,
        )

        for _i in range(kernel_size):
            for _j in range(kernel_size):
                # Output points are passed first so the result is already
                # laid out as (output_channels, input_channels).
                res_conv[:, :, _j, _i] = self.pathIntegrals(
                    self.Conv_outputQs[_i], self.Conv_inputQs[_j], self.shared_coeff_conv
                )
        return res_conv

    def forward(self, x):
        """Convolve ``x`` with the generated kernel (no bias term).

        Args:
            x: Input tensor passed to ``F.conv2d``.

        Returns:
            The result of ``F.conv2d`` with this layer's stride and padding.
        """
        W_conv = self.forward_()
        out = F.conv2d(x, W_conv, stride=self.stride, padding=self.padding)
        return out


class Linear_DyN(nn.Module):
    """Fully connected layer whose weight matrix is generated from point sets.

    For each of ``num_Hs`` components the layer holds one ``q_dim``-dimensional
    point per input feature and one per output feature.  The effective weight
    is ``W[o, i] = sum_h coeff[h] * dist_p(In_Qs[h, i], Out_Qs[h, o])`` with
    fixed coefficients ``coeff[h] = SCALE_FACTOR_fc * (-1) ** h``.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        num_Hs: Number of signed components summed into each weight.
        q_dim: Dimensionality of every learnable point.
        norm_p: ``p`` of the p-norm used by ``torch.cdist``.
        SCALE_FACTOR_fc: Magnitude of the alternating-sign coefficients.
        bias: Accepted but not read; a zero-initialised bias is always created.
    """

    def __init__(self, in_features, out_features, num_Hs, q_dim, norm_p, SCALE_FACTOR_fc=0.01, bias=None):
        super().__init__()
        self.norm_p = norm_p
        self.num_Hs = num_Hs

        # Learnable points, drawn uniformly from [0, 1).
        self.In_Qs = nn.Parameter(torch.rand(num_Hs, in_features, q_dim))
        self.Out_Qs = nn.Parameter(torch.rand(num_Hs, out_features, q_dim))

        # Frozen +s, -s, +s, ... coefficients with shape (num_Hs, 1, 1).
        signs = torch.tensor([1 if h_id % 2 == 0 else -1 for h_id in range(num_Hs)])
        self.shared_coeff_fc = nn.Parameter(
            SCALE_FACTOR_fc * signs[:, None, None],
            requires_grad=False,
        )

        # The random draw is overwritten with zeros right away; it is kept so
        # the global RNG stream advances by the same amount during construction.
        self.bias = nn.Parameter(torch.rand(out_features))
        nn.init.zeros_(self.bias)

    def pathIntegrals(self, _inputQs, _outputQs, _coeffs):
        """Weight the pairwise ``norm_p`` distances by ``_coeffs`` and sum over dim 0."""
        pairwise = torch.cdist(_inputQs, _outputQs, p=self.norm_p)
        weighted = pairwise * _coeffs
        return weighted.sum(dim=0)

    def forward_(self):
        """Return the generated weight matrix, shape ``(out_features, in_features)``."""
        in_by_out = self.pathIntegrals(self.In_Qs, self.Out_Qs, self.shared_coeff_fc)
        return in_by_out.T

    def forward(self, x):
        """Apply ``F.linear`` with the generated weight and the stored bias."""
        return F.linear(x, self.forward_(), self.bias)
    

class Linear_DyN_NoMat(nn.Module):
    """Linear layer with a distance-generated weight that is never materialised.

    The implied weight is
    ``W[o, i] = sum_h coeff[h] * ||In_Qs[h, i] - Out_Qs[h, o]||_2 ** 2``.
    ``forward`` expands the squared distance into
    ``||a||^2 + ||b||^2 - 2 * a.b`` and contracts each term with the input
    directly, so the ``(out_features, in_features)`` matrix is not built.
    ``forward_`` builds that matrix explicitly for inspection or testing.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        num_Hs: Number of signed components.
        q_dim: Dimensionality of every learnable point.
        norm_p: Stored on the module; the computation here always uses the
            squared Euclidean distance.
        SCALE_FACTOR_fc: Initial magnitude of the alternating-sign coefficients.
        bias: Accepted but not read; a zero-initialised bias is always created.
    """

    def __init__(self, in_features, out_features, num_Hs, q_dim, norm_p, SCALE_FACTOR_fc=0.001, bias=None):
        super(Linear_DyN_NoMat, self).__init__()
        self.num_Hs = num_Hs
        self.norm_p = norm_p
        self.in_features = in_features
        self.out_features = out_features

        self.In_Qs = torch.nn.Parameter(torch.rand(num_Hs, in_features, q_dim))
        self.Out_Qs = torch.nn.Parameter(torch.rand(num_Hs, out_features, q_dim))
        # Alternating-sign coefficients of shape (num_Hs,).  Unlike the other
        # layers in this file this Parameter is left trainable.
        self.shared_coeff = torch.nn.Parameter(
            SCALE_FACTOR_fc * torch.tensor([(-1) ** h_id for h_id in range(self.num_Hs)], dtype=torch.float32)
        )
        self.bias = torch.nn.Parameter(torch.rand(out_features))
        nn.init.constant_(self.bias, 0)
        # Replace the uniform [0, 1) draws above with Kaiming-uniform values.
        nn.init.kaiming_uniform_(self.In_Qs, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.Out_Qs, a=math.sqrt(5))

    def pathIntegrals(self):
        """Return the coefficient-weighted sum of squared Euclidean distances.

        Returns:
            Tensor of shape ``(in_features, out_features)``.
        """
        dist_io = torch.cdist(self.In_Qs, self.Out_Qs) ** 2
        dist_io = torch.sum(self.shared_coeff.unsqueeze(-1).unsqueeze(-1) * dist_io, dim=0)
        return dist_io

    def forward_(self):
        """Return the explicit weight matrix, shape ``(out_features, in_features)``."""
        return self.pathIntegrals().T

    def forward(self, x):
        """Compute ``x @ W.T + bias`` without building ``W``.

        Args:
            x: Tensor whose last dimension equals ``in_features``; all leading
                dimensions are flattened for the computation and restored in
                the result.

        Returns:
            Tensor with the shape of ``x`` except that the last dimension is
            ``out_features``.

        Raises:
            AssertionError: If the last dimension of ``x`` is not ``in_features``.
        """
        in_shape = x.shape
        assert in_shape[-1] == self.in_features, (
            'expected last dimension {}, got {}'.format(self.in_features, in_shape[-1])
        )
        x = x.reshape(-1, self.in_features)  # (B, in)
        coeff = self.shared_coeff

        # Term 1: sum_h c_h * sum_i x[b, i] * ||In_Qs[h, i]||^2  -> (B, 1)
        in_sq = (self.In_Qs ** 2).sum(-1) * coeff.unsqueeze(1)  # (H, in)
        s_in_sq = x @ in_sq.permute(1, 0)  # (B, H)
        s_in_sq = s_in_sq.sum(-1).unsqueeze(-1)

        # Term 2: (sum_i x[b, i]) * sum_h c_h * ||Out_Qs[h, o]||^2  -> (B, out)
        out_sq = (self.Out_Qs ** 2).sum(-1) * coeff.unsqueeze(1)  # (H, out)
        x_sum = x.sum(-1).unsqueeze(-1).repeat(1, self.num_Hs)  # (B, H)
        s_out_sq = x_sum @ out_sq

        # Term 3: sum_h c_h * Out_Qs[h, o] . (sum_i x[b, i] * In_Qs[h, i])
        inner = self.In_Qs.permute(0, 2, 1) @ x.T  # (H, q, B)
        inner = self.Out_Qs @ inner  # (H, out, B)
        inner = inner * coeff.unsqueeze(-1).unsqueeze(-1)
        inner = inner.sum(0).permute(1, 0)  # (B, out)

        # The bias is a registered, learnable parameter reported by
        # extra_repr, so it must take part in the output; it starts at zero,
        # so a freshly constructed layer yields the bias-free result.
        x_out = s_in_sq + s_out_sq - 2 * inner + self.bias

        out_shape = list(in_shape)
        out_shape[-1] = self.out_features
        return x_out.reshape(*tuple(out_shape))

    def extra_repr(self) -> str:
        """Return the parameter shapes and bias flag shown in the module repr."""
        return 'In_Qs={}, Out_Qs={}, bias={}'.format(
            self.In_Qs.shape,
            self.Out_Qs.shape,
            self.bias is not None
        )