import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader
import math
from tqdm import tqdm
from time import perf_counter
from scipy.linalg import svd
import pickle

#torch.cuda.set_device()


class ineffCNN(nn.Module):
    """Bias-free CNN: 1x1 conv -> depthwise conv -> 1x1 conv -> avg-pool -> linear.

    Args:
        d1, d2: spatial height and width the pooling window is derived from
            (inputs are expected to be ``(batch, C, d1, d2)``).
        l1, l2: height and width of the depthwise convolution kernel.
        p1, p2: number of pooled cells kept along each spatial axis; the
            linear layer expects ``K * p1 * p2`` features.
        R: number of channels produced by the first 1x1 convolution.
        C: number of input channels.
        K: number of channels produced by the last 1x1 convolution.
    """

    def __init__(self, d1, d2, l1, l2, p1, p2, R, C, K):
        super().__init__()
        self.d1, self.d2 = d1, d2
        self.l1, self.l2 = l1, l2
        self.p1, self.p2 = p1, p2
        self.R, self.C, self.K = R, C, K

        # Layers are created in a fixed order so that random initialisation
        # consumes the RNG stream in the same sequence every time.
        self.conv1 = nn.Conv2d(C, R, kernel_size=1, bias=False)
        # groups=R makes this a depthwise convolution: one l1 x l2 filter
        # per channel.
        self.conv2 = nn.Conv2d(R, R, kernel_size=(l1, l2), groups=R, bias=False)
        self.conv3 = nn.Conv2d(R, K, kernel_size=1, bias=False)

        # The convolved map has (d - l + 1) positions per axis; the window is
        # sized so that p cells fit along each axis.
        pool_height = int((d1 - l1 + 1) / p1)
        pool_width = int((d2 - l2 + 1) / p2)
        self.avg = nn.AvgPool2d((pool_height, pool_width))
        self.fc1 = nn.Linear(K * p1 * p2, 1, bias=False)

    def forward(self, x):
        """Return one scalar prediction per sample, shaped ``(batch, 1)``."""
        for layer in (self.conv1, self.conv2, self.conv3, self.avg):
            x = layer(x)
        flat = x.reshape(-1, self.K * self.p1 * self.p2)
        return self.fc1(flat)


def kronecker(a, b):
    """Return the Kronecker product of two 3-D tensors.

    Args:
        a: tensor of shape ``(a0, a1, a2)``.
        b: tensor of shape ``(b0, b1, b2)``.

    Returns:
        Tensor of shape ``(a0*b0, a1*b1, a2*b2)`` in which the block at block
        index ``(i, j, k)`` equals ``a[i, j, k] * b``.
    """
    # One einsum builds the outer product with the axes already interleaved
    # as (a0, b0, a1, b1, a2, b2). Unlike flattening with ``view(-1)`` this
    # also works for non-contiguous inputs (e.g. permuted tensors), and it
    # avoids the deprecated ``torch.ger``.
    interleaved = torch.einsum('abc,def->adbecf', a, b)
    return interleaved.reshape(
        a.size(0) * b.size(0),
        a.size(1) * b.size(1),
        a.size(2) * b.size(2),
    )


def u_tensor1(i, l, d):
    """Return a ``d x l`` float32 selection matrix for the window starting at ``i``.

    The matrix is an ``l x l`` identity preceded by ``i - 1`` zero rows and
    followed by ``d - l - i + 1`` zero rows, so ``i`` is a 1-based start row.
    """
    rows_above = i - 1
    rows_below = d - l - i + 1
    pieces = (
        torch.zeros((rows_above, l), dtype=torch.float32),
        torch.eye(l, dtype=torch.float32),
        torch.zeros((rows_below, l), dtype=torch.float32),
    )
    return torch.cat(pieces, dim=0)


def u_tensor(d, l, p):
    """Build the three averaged window-selection matrices, one per mode.

    For each of the first three modes, the ``d[m] - l[m] + 1`` window
    matrices from ``u_tensor1`` are split into ``p[m]`` consecutive groups of
    ``q = int((d[m] - l[m] + 1) / p[m])`` windows. Each group is averaged and
    the averages are concatenated column-wise.

    Returns:
        List of three tensors; entry ``m`` has shape ``(d[m], l[m] * p[m])``.
    """
    factors = []
    for mode in range(3):
        size, width, n_groups = d[mode], l[mode], p[mode]
        n_positions = size - width + 1
        # One selection matrix per admissible (1-based) window start.
        windows = [u_tensor1(start, width, size)
                   for start in range(1, n_positions + 1)]
        group_means = []
        for g in range(n_groups):
            q = int(n_positions / n_groups)
            group_means.append(sum(windows[g * q:(g + 1) * q]) / q)
        factors.append(torch.cat(group_means, dim=1))
    return factors


def gen_w(d, l, p, R, K):
    """Draw a random coefficient tensor with the model's low-rank structure.

    Orthonormal bases are taken from the left singular vectors of Gaussian
    matrices of sizes ``l[0]``, ``l[1]``, ``l[2]`` and ``K``. The first ``R``
    columns of each basis form rank-one terms that are summed into a kernel
    tensor. Each of its ``K`` slices is combined, via a Kronecker product,
    with a random pooling tensor and expanded to the full input size through
    the matrices returned by ``u_tensor``.

    Returns:
        Tuple ``(w, llist, plist, modllist, modllist1, modllist2, frow)``:
        ``w`` is the coefficient tensor scaled to unit Frobenius norm,
        ``llist`` the ``(K, l[0], l[1], l[2])`` kernel tensor, ``plist`` the
        random ``(K, p[0], p[1], p[2])`` tensor, the three ``modllist*``
        lists hold the per-component factors, and ``frow`` is the Frobenius
        norm of ``w`` before scaling.
    """
    # The random matrices are drawn in this exact order (l[0], l[1], l[2], K)
    # so the RNG stream is consumed the same way on every call.
    bases = [torch.from_numpy(svd(torch.randn(n, n))[0])
             for n in (l[0], l[1], l[2], K)]
    basis1, basis2, basis3, basis_k = bases

    modllist, modllist1, modllist2 = [], [], []
    res = u_tensor(d, l, p)
    for i in range(R):
        col_k = basis_k[:, i]
        col1, col2, col3 = basis1[:, i], basis2[:, i], basis3[:, i]
        term = torch.einsum('a, b, c, d->abcd', col_k, col1, col2, col3)
        llist = term if i == 0 else llist + term
        modllist.append(torch.einsum('a, b->ab', col1, col2))
        modllist1.append(col3)
        modllist2.append(col_k)

    plist = torch.randn(K, p[0], p[1], p[2])
    for i in range(K):
        block = kronecker(plist[i, :, :, :], llist[i, :, :, :])
        su = block if i == 0 else su + block

    w = torch.einsum('abc, da, eb, fc->def', su, res[0], res[1], res[2])
    frow = math.sqrt(torch.sum(w ** 2))
    return (w / frow, llist, plist, modllist, modllist1, modllist2, frow)


def gen_y(w, x):
    """Return the noisy response ``<x, w> + eps`` with ``eps ~ N(0, 1)``.

    The result is a tensor of shape ``(1,)`` because the noise term is drawn
    with ``torch.randn(1)``.
    """
    inner_product = (x * w).sum()
    noise = torch.randn(1)
    return inner_product + noise


def gen_y1(w, x):
    """Return the noise-free response ``<x, w>`` as a 0-dim tensor.

    Same inner product as ``gen_y`` but without the additive Gaussian noise.
    """
    inner_product = (x * w).sum()
    return inner_product


def rmse(A_est, A_true):
    """Return the root-mean-square difference between two tensors as a float."""
    residual = A_est - A_true
    mean_squared = torch.mean(residual ** 2)
    return math.sqrt(mean_squared)


def distanceF(A_est, A_true):
    """Return the Frobenius norm of ``A_est - A_true`` as a float."""
    residual = A_est - A_true
    squared_norm = torch.sum(residual ** 2)
    return math.sqrt(squared_norm)


def netW(net, d, l, p, R, K):
    """Reconstruct the coefficient tensor implied by the weights of ``net``.

    The weights of ``conv1``, ``conv2`` and ``conv3`` are combined into a
    kernel tensor with ``K`` slices. Each slice is Kronecker-multiplied with
    the matching slice of the reshaped ``fc1`` weight, the products are
    summed, and the sum is expanded through the ``u_tensor`` matrices, the
    same way ``gen_w`` assembles the ground-truth tensor.

    Returns:
        The tensor produced by the final einsum; it stays attached to the
        autograd graph of the network parameters.
    """
    pointwise_in = net.conv1.weight   # indexed below as [i, :, 0, 0]
    depthwise = net.conv2.weight      # indexed below as [i, 0, :, :]

    # Outer product of each depthwise filter with its input-channel weights.
    per_channel = [
        torch.einsum('ab,c->abc', depthwise[i, 0, :, :], pointwise_in[i, :, 0, 0])
        for i in range(R)
    ]
    ra1 = torch.stack(per_channel, dim=0)

    # Drop the two trailing singleton kernel axes of the 1x1 convolution.
    mixing = net.conv3.weight.squeeze(dim=2).squeeze(dim=2)
    ra2 = torch.einsum('ab, bcde->acde', mixing, ra1)

    b = net.fc1.weight.view(K, p[0], p[1], p[2])
    for i in range(K):
        block = kronecker(b[i, :, :, :], ra2[i, :, :, :])
        rr = block if i == 0 else rr + block

    res = u_tensor(d, l, p)
    return torch.einsum('abc, da, eb, fc->def', rr, res[0], res[1], res[2])



class gen_dataset(Dataset):
    """In-memory dataset of Gaussian inputs with noisy linear responses.

    Each input is drawn as a ``(d[2], d[0], d[1])`` tensor. Its response is
    computed by ``gen_y`` on the input permuted to ``(d[0], d[1], d[2])``,
    i.e. the inner product with ``w`` plus standard normal noise.
    """

    def __init__(self, d, w, sample_size):
        self.X = []
        self.y = []
        for _ in range(sample_size):
            sample = torch.randn(d[2], d[0], d[1])
            self.X.append(sample)
            # Move the leading axis last so the layout matches ``w``.
            response = gen_y(w, sample.permute(1, 2, 0))
            self.y.append(response)

    def __getitem__(self, idx):
        """Return the ``(input, response)`` pair stored at ``idx``."""
        return (self.X[idx], self.y[idx])

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.X)


class gen_datasetval(Dataset):
    """In-memory dataset of Gaussian inputs with noise-free linear responses.

    Each input is drawn as a ``(d[2], d[0], d[1])`` tensor. Its response is
    computed by ``gen_y1`` on the input permuted to ``(d[0], d[1], d[2])``,
    i.e. the exact inner product with ``w`` with no noise added.
    """

    def __init__(self, d, w, sample_size):
        self.X = []
        self.y = []
        for _ in range(sample_size):
            sample = torch.randn(d[2], d[0], d[1])
            self.X.append(sample)
            # Move the leading axis last so the layout matches ``w``.
            response = gen_y1(w, sample.permute(1, 2, 0))
            self.y.append(response)

    def __getitem__(self, idx):
        """Return the ``(input, response)`` pair stored at ``idx``."""
        return (self.X[idx], self.y[idx])

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.X)


# ---- Experiment configuration ----
# d: sizes of the three modes; samples are drawn as (d[2], d[0], d[1]) tensors.
# l: window sizes per mode, passed to gen_w / netW / u_tensor.
# p: number of pooled groups per mode.
d = [32, 32, 16]
l = [8, 8, 16]
p = [5, 5, 1]
# r: number of rank-one components (the ``R`` argument of gen_w and netW).
r = 8
# K: number of kernel slices / output channels of the last 1x1 convolution.
K = 8  # K = [8,16,24,32] -- presumably the values used across separate runs

# Training-set sizes to sweep: 4000 down to 1000 in steps of 500.
sam_list = list(range(4000, 999, -500))

# Per-sample-size result lists filled by the experiment loop:
# ee <- Frobenius errors of W, pp <- validation errors, tt <- test errors.
ee, pp, tt = [], [], []
from torch.autograd import Variable
import torch.optim as optim

# Run on the GPU when one is available; otherwise fall back to the CPU so the
# experiment does not crash on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

N_REPETITIONS = 50        # independent replications per sample size
N_EVAL_SAMPLES = 300      # size of the validation set and of the test set
LOSS_TOLERANCE = 0.00001  # stop once the epoch loss changes by at most this


def _as_single_batch(samples, target_device):
    """Collate every item of ``samples`` into one (X, y) batch on ``target_device``."""
    loader = DataLoader(samples, batch_size=len(samples), shuffle=False)
    features, targets = next(iter(loader))
    return (features.float().to(target_device),
            torch.squeeze(targets.float()).to(target_device))


start_time = perf_counter()
for sample_size in sam_list:
    # A fresh ground-truth tensor is drawn for every sample size; only the
    # normalised tensor (first element of the returned tuple) is needed.
    w = gen_w(d, l, p, r, K)[0]
    batch_size = sample_size  # one batch holds the whole training set
    lr1 = 0.02
    w_errors, val_errors, test_errors = [], [], []

    for rep in tqdm(range(N_REPETITIONS)):
        train_loader = DataLoader(gen_dataset(d, w, sample_size),
                                  batch_size=batch_size, shuffle=False)
        # Build the network from the configuration instead of repeating the
        # numbers, so changing d / l / p / r keeps data and model consistent.
        net = ineffCNN(d[0], d[1], l[0], l[1], p[0], p[1], r, d[2], K)
        net.to(device)
        criterion = nn.MSELoss()
        optimizer = optim.SGD(net.parameters(), lr=lr1, momentum=0.9)

        # The batches never change between epochs (shuffle=False), so collate
        # and move them to the device once rather than on every epoch.
        train_batches = [
            (xb.float().to(device), torch.squeeze(yb.float()).to(device))
            for xb, yb in train_loader
        ]

        # Train until the last-batch loss of two consecutive epochs differs
        # by no more than LOSS_TOLERANCE.
        prev_loss = 0
        cur_loss = 1000
        while True:
            for xb, yb in train_batches:
                yhat = torch.squeeze(net(xb).float())
                loss = criterion(yhat, yb)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                cur_loss = loss.item()
            # Written as "not >" so that a NaN loss also ends the loop, as
            # it did with the original ``while ... > tol`` condition.
            if not abs(prev_loss - cur_loss) > LOSS_TOLERANCE:
                break
            prev_loss = cur_loss

        # Evaluation needs no gradients.
        with torch.no_grad():
            # Validation responses are noise-free (gen_datasetval -> gen_y1).
            X_val, y_val = _as_single_batch(
                gen_datasetval(d, w, N_EVAL_SAMPLES), device)

            # w lives on the CPU, so compare there and then move the network
            # back to the working device.
            estfro = distanceF(netW(net.cpu(), d, l, p, r, K), w)
            net.to(device)
            print("Frobenius error of W is {}.".format(estfro))
            predError = rmse(torch.squeeze(net(X_val).float()), y_val)
            print("PredError is {}.".format(predError))
            val_errors.append(predError)
            w_errors.append(estfro)

            # Test responses carry observation noise (gen_dataset -> gen_y).
            X_test, y_test = _as_single_batch(
                gen_dataset(d, w, N_EVAL_SAMPLES), device)
            testerr = rmse(torch.squeeze(net(X_test).float()), y_test)
            print("Testerror is {}.".format(testerr))
            test_errors.append(testerr)

        if rep % 50 == 0:
            print("aa")

    ee.append(w_errors)
    pp.append(val_errors)
    tt.append(test_errors)

stop_time = perf_counter()
print(stop_time - start_time)



# Bundle the three result collections (Frobenius errors, validation errors,
# test errors) so that a single pickle file holds the whole experiment.
# NOTE: this rebinds the module-level name ``l``, which earlier held the
# window sizes; nothing after this point may rely on the old value.
l = [ee, pp, tt]

with open('thm2k=8.pkl', 'wb') as results_file:
    pickle.dump(l, results_file)
