#!/usr/bin/env python
# coding: utf-8

import torch
import torch.nn as nn

class dyn2DMat(nn.Module):
    """Learnable ``(num_input, num_output)`` matrix built from point-set distances.

    The module holds ``num_Qs`` pairs of learnable point sets:
    ``input_Qs`` of shape ``(num_Qs, num_input, q_dim)`` and ``output_Qs`` of
    shape ``(num_Qs, num_output, q_dim)``. ``forward`` computes the pairwise
    ``p``-norm distances within each pair (``torch.cdist``), multiplies them by
    ``_scale``, weights each of the ``num_Qs`` distance matrices by a learnable
    scalar from ``lambdas_io`` and sums them into a single matrix.

    Args:
        num_input: Points per input set (rows of the produced matrix).
        num_output: Points per output set (columns of the produced matrix).
        num_Qs: Number of input/output set pairs summed together.
        q_dim: Dimensionality of every point.
        device: Device on which the parameters are allocated.
        p: Norm order passed to ``torch.cdist`` (``1`` is the L1 distance).
        _scale: Constant multiplier applied to every distance.
    """

    def __init__(self, num_input: int, num_output: int, num_Qs: int, q_dim: int,
                 device, p: float = 1, _scale: float = 5):
        super().__init__()

        self.num_input = num_input
        self.num_output = num_output
        self.num_Qs = num_Qs
        self.q_dim = q_dim
        self.norm_p = p
        self._scale = _scale

        # Point coordinates are initialised uniformly in [0, 1).
        self.input_Qs = nn.Parameter(torch.rand(num_Qs, num_input, q_dim, device=device))
        self.output_Qs = nn.Parameter(torch.rand(num_Qs, num_output, q_dim, device=device))
        # Shape (num_Qs, 1, 1) so each weight broadcasts over one whole
        # (num_input, num_output) distance matrix.
        self.lambdas_io = nn.Parameter(torch.randn(num_Qs, 1, 1, device=device))

    def forward(self, _prec=-1) -> torch.Tensor:
        """Build the weighted distance matrix.

        Args:
            _prec: Quantization step. ``-1`` (default) or ``None`` uses the raw
                point coordinates. Any other value snaps each coordinate to
                ``_prec * floor(coord / _prec)`` before distances are computed.
                Floor rounding has zero gradient, so in that mode only
                ``lambdas_io`` receives gradients.

        Returns:
            Tensor of shape ``(num_input, num_output)``.

        Raises:
            ValueError: If ``_prec`` is zero, which would otherwise make every
                coordinate inf/NaN and silently produce a NaN matrix.
        """
        if _prec is None or _prec == -1:
            input_Qs = self.input_Qs
            output_Qs = self.output_Qs
        else:
            if _prec == 0:
                raise ValueError(
                    "_prec must be non-zero (use -1 or None to disable quantization)"
                )
            input_Qs = _prec * torch.div(self.input_Qs, _prec, rounding_mode='floor')
            output_Qs = _prec * torch.div(self.output_Qs, _prec, rounding_mode='floor')

        # (num_Qs, num_input, num_output) pairwise distances per set pair.
        dist_io = self._scale * torch.cdist(input_Qs, output_Qs, p=self.norm_p)
        # Weight each distance matrix by its lambda and collapse the num_Qs axis.
        W_io = torch.sum(dist_io * self.lambdas_io, 0)
        return W_io

