import numpy as np
import torch
from torch import nn

class Sheaf_NNet(nn.Module):
    """Network that predicts a per-sample ``dimy x dimx`` matrix and scores it.

    ``fc_smat`` is an MLP mapping an input row of length ``dimx`` to
    ``dimy * dimx`` values, which ``forward`` reshapes into one
    ``(dimy, dimx)`` matrix ``smat`` per sample.  ``regressor`` is a bias-free
    linear map from ``dimy`` features to a scalar.

    Args:
        dimx: Size of each input row (last dimension of ``x``).
        dimy: Number of rows of the predicted matrix.
        nnet_list: Optional list stored as ``self.nnet_list``; it is not used
            anywhere else in this class.  Defaults to a fresh empty list per
            instance.
        nsmat: Width of the hidden layers of ``fc_smat``.
    """

    # Number of hidden (nsmat -> nsmat) Linear+ReLU stages in ``fc_smat``.
    _N_HIDDEN_LAYERS = 6

    def __init__(self, dimx, dimy, nnet_list=None, nsmat=60):
        super().__init__()
        self.dimx = dimx
        self.dimy = dimy
        # A ``[]`` default would be one list shared by every instance.
        self.nnet_list = [] if nnet_list is None else nnet_list

        # Layers are created in the same order (and end up at the same
        # Sequential indices) as a hand-unrolled definition, so state_dict
        # keys and seeded initialisation are unchanged.
        layers = [nn.Linear(self.dimx, nsmat), nn.ReLU()]
        for _ in range(self._N_HIDDEN_LAYERS):
            layers += [nn.Linear(nsmat, nsmat), nn.ReLU()]
        layers.append(nn.Linear(nsmat, self.dimy * self.dimx))
        self.fc_smat = nn.Sequential(*layers)

        self.regressor = nn.Linear(self.dimy, 1, bias=False)

    def forward(self, w, x, f):
        """Compute the four loss terms for one batch.

        Args:
            w: Tensor whose dimension 1 is contracted with the batch
                dimension of the projected inputs, i.e. ``(m, batch)`` for a
                2-D ``w``.  ``m`` should equal ``batch`` for the
                reconstruction term to be an element-wise comparison.
            x: Input tensor of shape ``(batch, dimx)``.
            f: Target subtracted from ``regressor(y)``, which has shape
                ``(m, 1)``; presumably ``f`` has that shape too (confirm
                against the caller).

        Returns:
            Tuple ``(loss_orth, loss_cons, loss_smap, loss_sres)`` of scalar
            tensors:

            * ``loss_orth`` -- root of the scaled mean squared deviation of
              the Gram matrix of ``smat`` from the identity.
            * ``loss_cons`` -- scaled mean squared difference between
              ``smat`` and the matrix predicted from the reconstruction
              ``xmap``.
            * ``loss_smap`` -- scaled mean squared difference between
              ``xmap`` and ``x``.
            * ``loss_sres`` -- mean squared residual of ``regressor(y)``
              against ``f``.
        """
        smat = self.fc_smat(x).reshape(-1, self.dimy, self.dimx)
        smat_t = smat.transpose(1, 2)

        # Project every sample through its own matrix: q[b] = smat[b] @ x[b].
        xi = x.reshape(x.shape[0], x.shape[1], 1)
        q = torch.bmm(smat, xi).reshape(-1, self.dimy)

        # Mix the projections across the batch with w.
        y = torch.tensordot(w, q, dims=([1], [0]))

        # Map back: xmap[k] = smat[k]^T @ y[k].  This equals the (batch, m)
        # diagonal of tensordot(smat, y) without materialising the full
        # (batch, dimx, m) tensor; like that diagonal, only the first
        # min(batch, m) samples are paired.
        n_pairs = min(smat.shape[0], y.shape[0])
        xmap = torch.bmm(smat_t[:n_pairs], y[:n_pairs].unsqueeze(-1)).squeeze(-1)
        recon_err = xmap - x
        loss_smap = torch.mean(recon_err * recon_err) * self.dimx

        # Use whichever Gram matrix is the smaller one: smat @ smat^T
        # (dimy x dimy) when dimy < dimx, otherwise smat^T @ smat
        # (dimx x dimx).
        if self.dimy < self.dimx:
            gram = torch.bmm(smat, smat_t)
        else:
            gram = torch.bmm(smat_t, smat)
        # Build the identity on gram's device/dtype and let it broadcast over
        # the batch, so the module also works off-CPU and in other precisions.
        identity = torch.eye(gram.shape[-1], dtype=gram.dtype, device=gram.device)
        orth_err = gram - identity
        # NOTE(review): the scale is dimy * dimy in both branches, although
        # the Gram matrix is dimx x dimx when dimy >= dimx.  Kept as-is to
        # preserve existing loss values -- confirm whether this is intended.
        loss_orth = torch.sqrt(torch.mean(orth_err * orth_err) * self.dimy * self.dimy)

        # The matrix predicted from the reconstruction should match smat.
        smat_proj = self.fc_smat(xmap).reshape(-1, self.dimy, self.dimx)
        cons_err = smat_proj - smat
        loss_cons = torch.mean(cons_err * cons_err) * self.dimy * self.dimx

        df = self.regressor(y) - f
        loss_sres = torch.mean(df * df)
        return (loss_orth, loss_cons, loss_smap, loss_sres)

# TODO: What further ideas are worth trying here?
# There are several things that could be done next.





