import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader
import math
from tqdm import tqdm
from time import perf_counter
import pickle


class Tuckernet(nn.Module):
    """Small 3-D CNN built from separable 1-D filters and a learned mixing matrix.

    Pipeline implemented by ``forward``:

    1. ``conv1``: a bias-free 1x1x1 convolution collapsing 3 input channels
       into one.
    2. ``convr3`` / ``convr2`` / ``convr1``: banks of ``r3`` / ``r2`` / ``r1``
       bias-free 1-D filters of length ``l3`` / ``l2`` / ``l1`` applied along
       the third / second / first spatial axis.  Every combination of one
       filter per axis is evaluated, giving ``r1 * r2 * r3`` feature maps.
    3. ``g``: a ``(K, r1 * r2 * r3)`` matrix that linearly mixes those feature
       maps into ``K`` channels.
    4. ``avg``: average pooling with kernel ``((d - l + 1) // p)`` per axis.
    5. ``fc1``: a bias-free linear layer mapping ``K * p1 * p2 * p3`` pooled
       values to one scalar.

    ``g`` is registered as an ``nn.Parameter`` so that it is returned by
    ``parameters()``, saved in ``state_dict()`` and moved together with the
    module by ``.cuda()`` / ``.cpu()`` / ``.to()``.  The module itself is
    created on the default device; move it explicitly if a GPU is wanted.

    Args:
        d1, d2, d3: spatial size of the input volume.
        l1, l2, l3: length of the 1-D filters along each axis.
        p1, p2, p3: number of pooled cells expected along each axis.
        r1, r2, r3: number of 1-D filters along each axis.
        K: number of mixed channels (rows of ``g``).
    """

    def __init__(self, d1, d2, d3, l1, l2, l3, p1, p2, p3, r1, r2, r3, K):
        super().__init__()
        self.d1 = d1
        self.d2 = d2
        self.d3 = d3
        self.l1 = l1
        self.l2 = l2
        self.l3 = l3
        self.p1 = p1
        self.p2 = p2
        self.p3 = p3
        self.r1 = r1
        self.r2 = r2
        self.r3 = r3
        self.K = K
        self.conv1 = nn.Conv3d(3, 1, kernel_size=(1, 1, 1), bias=False)
        self.convr1 = nn.ModuleList(
            [nn.Conv3d(1, 1, kernel_size=(l1, 1, 1), bias=False) for _ in range(r1)]
        )
        self.convr2 = nn.ModuleList(
            [nn.Conv3d(1, 1, kernel_size=(1, l2, 1), bias=False) for _ in range(r2)]
        )
        self.convr3 = nn.ModuleList(
            [nn.Conv3d(1, 1, kernel_size=(1, 1, l3), bias=False) for _ in range(r3)]
        )
        # A real Parameter (leaf tensor) instead of a detached-from-the-module
        # CUDA tensor: the optimizer sees it and it follows the module's device.
        self.g = nn.Parameter(torch.randn(K, r1 * r2 * r3))
        self.avg = nn.AvgPool3d(
            ((d1 - l1 + 1) // p1, (d2 - l2 + 1) // p2, (d3 - l3 + 1) // p3)
        )
        self.fc1 = nn.Linear(K * p1 * p2 * p3, 1, bias=False)

    def forward(self, x):
        """Map a batch of volumes to one scalar per sample.

        Args:
            x: tensor with 3 channels and three spatial axes; the pooling
                kernel and the final reshape are sized for a
                ``(N, 3, d1, d2, d3)`` input.

        Returns:
            Tensor of shape ``(N, 1)`` when the pooled output has exactly
            ``p1 * p2 * p3`` cells per channel.
        """
        x = self.conv1(x)
        # Feature maps are collected with the convr3 index varying slowest and
        # the convr1 index fastest; the columns of ``g`` follow this order.
        # Each intermediate convolution is computed once and reused.
        responses = []
        for conv_axis3 in self.convr3:
            after_axis3 = conv_axis3(x)
            for conv_axis2 in self.convr2:
                after_axis2 = conv_axis2(after_axis3)
                for conv_axis1 in self.convr1:
                    responses.append(torch.squeeze(conv_axis1(after_axis2), dim=1))
        stacked = torch.stack(responses, dim=0)  # (r1*r2*r3, N, D, H, W)
        # Mix the feature maps with every row of g at once -> (N, K, D, H, W).
        x = torch.einsum('abcde,ka->bkcde', stacked, self.g)
        x = self.avg(x)
        x = x.reshape(-1, self.K * self.p1 * self.p2 * self.p3)
        x = self.fc1(x)
        return x


def kronecker(a, b):
    """Return the Kronecker product of two 3-D tensors.

    For ``a`` of shape ``(a0, a1, a2)`` and ``b`` of shape ``(b0, b1, b2)`` the
    result has shape ``(a0*b0, a1*b1, a2*b2)`` with

        out[i0*b0 + j0, i1*b1 + j1, i2*b2 + j2] = a[i0, i1, i2] * b[j0, j1, j2]

    Implemented with a broadcasting product instead of the deprecated
    ``torch.ger`` so that non-contiguous inputs (slices, permuted views) are
    accepted as well.
    """
    # Interleave the axes as (a0, b0, a1, b1, a2, b2) so that merging each
    # adjacent pair yields the Kronecker layout directly.
    interleaved = a[:, None, :, None, :, None] * b[None, :, None, :, None, :]
    return interleaved.reshape(
        a.size(0) * b.size(0),
        a.size(1) * b.size(1),
        a.size(2) * b.size(2),
    )


def u_tensor1(i, l, d):
    """Build a ``(d, l)`` float32 matrix holding an ``l x l`` identity block.

    The identity occupies rows ``i-1 .. i+l-2`` (``i`` is 1-based); all other
    rows are zero.
    """
    rows_above = torch.zeros((i - 1, l), dtype=torch.float32)
    identity = torch.eye(l, dtype=torch.float32)
    rows_below = torch.zeros((d - l - i + 1, l), dtype=torch.float32)
    return torch.cat((rows_above, identity, rows_below), dim=0)


def u_tensor(d, l, p):
    """Build one averaged-shift matrix per axis for the first three axes.

    For axis ``a`` all ``d[a] - l[a] + 1`` shifted selection matrices from
    ``u_tensor1`` are generated, split into ``p[a]`` consecutive groups of
    ``int((d[a] - l[a] + 1) / p[a])`` matrices, each group is averaged, and
    the group means are concatenated along the column axis.

    Returns:
        List of three tensors; entry ``a`` has ``d[a]`` rows and
        ``l[a] * p[a]`` columns.
    """
    factors = []
    for axis in range(3):
        size, width, cells = d[axis], l[axis], p[axis]
        n_shifts = size - width + 1
        shifted = [u_tensor1(start, width, size) for start in range(1, n_shifts + 1)]
        pooled = []
        for cell in range(cells):
            group = int(n_shifts / cells)  # shifts averaged into one pooled cell
            first = cell * group
            pooled.append(sum(shifted[first:first + group]) / group)
        factors.append(torch.cat(pooled, dim=1))
    return factors


def gen_w(d, l, p, K, r):
    """Draw random factors and assemble the weight tensor they imply.

    Random draws happen in the order ``g, H1, H2, H3, H4, B`` (standard
    normal), so results are reproducible under a fixed seed.

    Returns:
        Tuple ``(w, g, H1, H2, H3, H4, B)`` where ``w`` stacks one
        ``(d[0], d[1], d[2])`` tensor per input channel (3 channels) and the
        remaining entries are the factors it was built from.
    """
    g = torch.randn(r[0], r[1], r[2], 1, K)
    H1 = torch.randn(l[0], r[0])
    H2 = torch.randn(l[1], r[1])
    H3 = torch.randn(l[2], r[2])
    H4 = torch.randn(3, 1)
    # Expand the core g with the four factor matrices -> (l0, l1, l2, 3, K).
    A = torch.einsum('abcde, fa, gb, hc, id -> fghie', g, H1, H2, H3, H4)
    B = torch.randn(p[0], p[1], p[2], K)
    shift_mats = u_tensor(d, l, p)

    per_channel = []
    for channel in range(3):
        # Sum over k of kron(B[..., k], A[..., channel, k]).
        ab = kronecker(B[:, :, :, 0], A[:, :, :, channel, 0])
        for k in range(1, K):
            ab = ab + kronecker(B[:, :, :, k], A[:, :, :, channel, k])
        # Map the (l*p)-sized axes onto the full d-sized axes.
        per_channel.append(
            torch.einsum('abc, da, eb, fc -> def', ab, shift_mats[0], shift_mats[1], shift_mats[2])
        )
    w = torch.stack(per_channel, dim=0)
    return (w, g, H1, H2, H3, H4, B)


def gen_y(w, x):
    """Return the inner product of ``x`` and ``w`` plus one standard-normal draw.

    The result is a tensor of shape ``(1,)``.
    """
    signal = (x * w).sum()
    noise = torch.randn(1)
    return signal + noise


def gen_y1(w, x):
    """Return the noise-free inner product of ``x`` and ``w`` as a 0-dim tensor."""
    return (x * w).sum()


def rmse(A_est, A_true):
    """Return the root-mean-square difference between two tensors as a float."""
    residual = A_est - A_true
    mean_sq = (residual ** 2).mean()
    return math.sqrt(float(mean_sq))


def distanceF(A_est, A_true):
    """Return the Frobenius norm of ``A_est - A_true`` as a float."""
    residual = A_est - A_true
    sum_sq = (residual ** 2).sum()
    return math.sqrt(float(sum_sq))


def netW(net, d, l, p, r, K):
    """Rebuild the weight tensor implied by a network's current parameters.

    Reads ``net.g``, the filters in ``net.convr1/2/3``, ``net.conv1`` and
    ``net.fc1``, rearranges them into the factor layout used by ``gen_w`` and
    assembles the result the same way.

    Only ``g`` is explicitly moved to the CPU; the remaining weights are used
    on whatever device they live on.
    NOTE(review): the matrices from ``u_tensor`` are created without a device
    argument, so callers presumably need to pass a CPU-resident ``net``.

    Returns:
        Tensor stacking one ``(d[0], d[1], d[2])`` slice per input channel.
    """
    # (K, r0*r1*r2) -> (r0*r1*r2, K) -> (r2, r1, r0, 1, K) -> (r0, r1, r2, 1, K)
    core = torch.transpose(net.g, 0, 1).view(r[2], r[1], r[0], 1, K).permute(2, 1, 0, 3, 4)

    # One column per learned 1-D filter along each axis.
    H1 = torch.stack([net.convr1[idx].weight.view(l[0]) for idx in range(r[0])], 1)
    H2 = torch.stack([net.convr2[idx].weight.view(l[1]) for idx in range(r[1])], 1)
    H3 = torch.stack([net.convr3[idx].weight.view(l[2]) for idx in range(r[2])], 1)
    H4 = net.conv1.weight.view(3, 1)

    A = torch.einsum('abcde, fa, gb, hc, id -> fghie', core.cpu(), H1, H2, H3, H4)
    B = net.fc1.weight.view(K, p[0], p[1], p[2]).permute(1, 2, 3, 0)
    shift_mats = u_tensor(d, l, p)

    per_channel = []
    for channel in range(3):
        # Sum over k of kron(B[..., k], A[..., channel, k]).
        ab = kronecker(B[:, :, :, 0], A[:, :, :, channel, 0])
        for k in range(1, K):
            ab = ab + kronecker(B[:, :, :, k], A[:, :, :, channel, k])
        per_channel.append(
            torch.einsum('abc, da, eb, fc -> def', ab, shift_mats[0], shift_mats[1], shift_mats[2])
        )
    return torch.stack(per_channel, dim=0)


def initnet(net, g,H1,H2,H3,H4,B,r,l,K,p):
    """Overwrite a network's parameters with the given factor tensors.

    Every parameter is created from a *copy* of the supplied data, so later
    in-place updates of the network (e.g. optimizer steps) never modify the
    caller's ``g``, ``H1``..``H4`` or ``B``.

    Args:
        net: module exposing ``g``, ``convr1``, ``convr2``, ``convr3``,
            ``conv1`` and ``fc1``.
        g: core of shape ``(r[0], r[1], r[2], 1, K)``.
        H1, H2, H3: filter matrices of shape ``(l[a], r[a])``; column ``i``
            becomes the weight of the ``i``-th 1-D filter along axis ``a``.
        H4: channel weights; column 0 (3 entries) becomes ``conv1.weight``.
        B: tensor whose last axis has size ``K``; flattened with that axis
            first to form ``fc1.weight``.
        r, l: ranks and filter lengths per axis.
        K: number of rows of ``net.g``.
        p: unused; kept for call compatibility.

    Returns:
        The same ``net`` object, modified in place.
    """
    # (r0, r1, r2, 1, K) -> (r2, r1, r0, 1, K) -> (r0*r1*r2, K) -> (K, r0*r1*r2):
    # column index = i3 * r1 * r0 + i2 * r0 + i1.
    core = g.permute(2, 1, 0, 3, 4).reshape(r[0] * r[1] * r[2], K)
    net.g = torch.nn.Parameter(torch.transpose(core, 0, 1).detach().clone())

    for i in range(r[0]):
        column = H1[:, i].detach().clone()
        net.convr1[i].weight = torch.nn.Parameter(column.reshape(1, 1, l[0], 1, 1))
    for i in range(r[1]):
        column = H2[:, i].detach().clone()
        net.convr2[i].weight = torch.nn.Parameter(column.reshape(1, 1, 1, l[1], 1))
    for i in range(r[2]):
        column = H3[:, i].detach().clone()
        net.convr3[i].weight = torch.nn.Parameter(column.reshape(1, 1, 1, 1, l[2]))

    net.conv1.weight = torch.nn.Parameter(H4[:, 0].detach().clone().reshape(1, 3, 1, 1, 1))

    flat_b = B.permute(3, 0, 1, 2).reshape(1, -1)
    net.fc1.weight = torch.nn.Parameter(flat_b.detach().clone())
    return net


class gen_dataset(Dataset):
    """In-memory synthetic regression data with noisy responses.

    Each sample is a standard-normal tensor of shape
    ``(3, d[0], d[1], d[2])`` paired with ``gen_y(w, x)``.  Inputs and
    responses are drawn alternately, one sample at a time.
    """

    def __init__(self, d, w, sample_size):
        self.X = []
        self.y = []
        for _ in range(sample_size):
            features = torch.randn(3, d[0], d[1], d[2])
            self.X.append(features)
            self.y.append(gen_y(w, features))

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]


class gen_datasetval(Dataset):
    """In-memory synthetic regression data with responses from ``gen_y1``.

    Each sample is a standard-normal tensor of shape
    ``(3, d[0], d[1], d[2])`` paired with ``gen_y1(w, x)``.
    """

    def __init__(self, d, w, sample_size):
        self.X = []
        self.y = []
        for _ in range(sample_size):
            features = torch.randn(3, d[0], d[1], d[2])
            self.X.append(features)
            self.y.append(gen_y1(w, features))

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]


# Simulation study: for every training-set size, repeatedly fit a Tuckernet
# that starts at the ground-truth parameters and record how far the fitted
# weight tensor and its predictions end up from the truth.
dlist = [10, 10, 8]  # input volume size (d1, d2, d3)
llist = [5, 5, 3]  # 1-D filter length per axis (l1, l2, l3)
plist = [2, 2, 2]  # pooled grid size per axis (p1, p2, p3)
r = [1, 1, 2]  # number of 1-D filters per axis (r1, r2, r3)
K = 3  # number of rows of the mixing matrix g

sam_list = [2000, 950, 500, 320, 230, 180, 150]  # training-set sizes to sweep

ee = []  # per sample size: Frobenius errors of the fitted weight tensor
pp = []  # per sample size: RMSE on noise-free validation data
tt = []  # per sample size: RMSE on noisy test data
from torch.autograd import Variable  # no longer used below; import retained
import torch.optim as optim

start_time = perf_counter()
for element in sam_list:
    sample_size = element
    w, g, H1, H2, H3, H4, B = gen_w(dlist, llist, plist, K, r)
    batch_size = sample_size  # one batch per epoch, i.e. full-batch updates
    names = {"estError": [], "predError": [], "WerrFro": [], "testerror": []}
    lr1 = 0.0001
    # ``rep`` (not ``iter``) so the builtin ``iter`` stays usable.
    for rep in tqdm(range(200)):
        dataset = gen_dataset(dlist, w, sample_size)
        dataset = DataLoader(dataset, batch_size=batch_size, shuffle=False)
        # Build the network from the same settings that generated the data.
        net = Tuckernet(*dlist, *llist, *plist, *r, K)
        net = initnet(net, g, H1, H2, H3, H4, B, r, llist, K, plist)
        net.cuda()
        criterion = nn.MSELoss()
        optimizer = optim.SGD(net.parameters(), lr=lr1, momentum=0.9)

        # Train until the loss changes by at most 1e-9 between two epochs.
        # The first comparison is against the sentinels 0 / 1000, so at least
        # one epoch is always run.
        # NOTE(review): there is no cap on the number of epochs.
        loss_last = 0
        loss_new = 1000
        i = 0
        while abs(loss_last - loss_new) > 0.000000001:
            if i > 0:
                loss_last = loss_new
            for _x, _y in dataset:
                _x = _x.float().cuda()
                _y = torch.squeeze(_y.float()).cuda()
                yhat = torch.squeeze(net(_x).float())
                loss = criterion(yhat, _y)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                loss_new = loss.item()
            i = i + 1

        # Noise-free validation set: 200 samples delivered as a single batch.
        valdata = gen_datasetval(dlist, w, 200)
        valdata = DataLoader(valdata, batch_size=200, shuffle=False)
        X_val, y_val = next(iter(valdata))
        X_val = X_val.float().cuda()
        y_val = torch.squeeze(y_val.float()).cuda()
        with torch.no_grad():
            # netW works on CPU tensors, so the net is moved there and back.
            estfro = distanceF(netW(net.cpu(), dlist, llist, plist, r, K), w)
            net.cuda()
            print("Frobenius error of W is {}.".format(estfro))
            predError = rmse(torch.squeeze(net(X_val).float()), y_val)
        print("PredError is {}.".format(predError))
        names["predError"].append(predError)
        names["WerrFro"].append(estfro)

        # Noisy test set drawn from the same generator as the training data.
        testdata = gen_dataset(dlist, w, 200)
        testdata = DataLoader(testdata, batch_size=200, shuffle=False)
        X_val, y_val = next(iter(testdata))
        X_val = X_val.float().cuda()
        y_val = torch.squeeze(y_val.float()).cuda()
        with torch.no_grad():
            testerr = rmse(torch.squeeze(net(X_val).float()), y_val)
        print("Testerror is {}.".format(testerr))
        names["testerror"].append(testerr)
        if rep % 50 == 0:
            print("aa")

    ee.append(names["WerrFro"])
    pp.append(names["predError"])
    tt.append(names["testerror"])

stop_time = perf_counter()
print(stop_time - start_time)

# Saved layout: [Frobenius errors, validation RMSEs, test RMSEs], each a list
# with one entry (a list of 200 values) per training-set size.
l = [ee, pp, tt]

with open('tucker.pkl', 'wb') as f:
    pickle.dump(l, f)
