import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader
import math
from tqdm import tqdm
from time import perf_counter
import pickle


class CnnTensor(nn.Module):
    """Single-filter 3-D convolution, average pooling, then a linear read-out.

    All three layers are bias-free.

    Args:
        d1, d2, d3: extents used to size the pooling window (presumably the
            spatial sizes of the input volume -- confirm against callers).
        l1, l2, l3: convolution kernel size along each axis.
        p1, p2, p3: divisor applied per axis when sizing the pooling window;
            their product is the input width of the linear layer.
    """

    def __init__(self, d1, d2, d3, l1, l2, l3, p1, p2, p3):
        super().__init__()
        self.d1, self.d2, self.d3 = d1, d2, d3
        self.l1, self.l2, self.l3 = l1, l2, l3
        self.p1, self.p2, self.p3 = p1, p2, p3

        # One input channel, one output channel.
        self.conv1 = nn.Conv3d(1, 1, kernel_size=(l1, l2, l3), bias=False)

        # Pooling window per axis: (d - l + 1) / p, truncated to an integer.
        pool_window = tuple(
            int((extent - kernel + 1) / parts)
            for extent, kernel, parts in ((d1, l1, p1), (d2, l2, p2), (d3, l3, p3))
        )
        self.avg = nn.AvgPool3d(pool_window)

        self.fc1 = nn.Linear(1 * p1 * p2 * p3, 1, bias=False)

    def forward(self, x):
        """Apply conv -> pool -> flatten -> linear.

        Args:
            x: input batch; ``Conv3d(1, 1, ...)`` requires a single channel.

        Returns:
            Tensor with one column per row of the flattened pooled features.
        """
        n_features = 1 * self.p1 * self.p2 * self.p3
        pooled = self.avg(self.conv1(x))
        flattened = pooled.reshape(-1, n_features)
        return self.fc1(flattened)


def kronecker(a, b):
    return torch.ger(a.view(-1), b.view(-1)).reshape(*(a.size() + b.size())).permute([0, 3, 1, 4, 2, 5]).reshape(
        a.size(0) * b.size(0), a.size(1) * b.size(1), a.size(2) * b.size(2))


def u_tensor1(i, l, d):
    """Build a (d, l) float32 matrix with an l x l identity starting at row i.

    Args:
        i: 1-based index of the first row of the identity block.
        l: size of the identity block (and number of columns).
        d: total number of rows.

    Returns:
        Tensor with i - 1 zero rows, then the identity, then
        d - l - i + 1 zero rows.
    """
    rows_above = i - 1
    rows_below = d - l - i + 1
    stacked = [
        torch.zeros((rows_above, l), dtype=torch.float32),
        torch.eye(l, dtype=torch.float32),
        torch.zeros((rows_below, l), dtype=torch.float32),
    ]
    return torch.cat(stacked, dim=0)


def u_tensor(d, l, p):
    """Build one block-averaged shift matrix for each of the three axes.

    For axis i, every shift matrix ``u_tensor1(j, l[i], d[i])`` with
    j = 1 .. d[i] - l[i] + 1 is generated. Consecutive shifts are grouped
    into p[i] blocks of int((d[i] - l[i] + 1) / p[i]) matrices, each block
    is averaged, and the averages are concatenated column-wise. Shifts left
    over after the last full block are not used.

    Args:
        d, l, p: indexable triples; only entries 0, 1 and 2 are read.

    Returns:
        List of three tensors, one per axis.
    """
    per_axis = []
    for axis in range(3):
        n_shifts = d[axis] - l[axis] + 1
        shifts = [
            u_tensor1(start, l[axis], d[axis])
            for start in range(1, n_shifts + 1)
        ]
        block_means = []
        for block in range(p[axis]):
            block_len = int(n_shifts / p[axis])
            members = shifts[block * block_len:(block + 1) * block_len]
            block_means.append(sum(members) / block_len)
        per_axis.append(torch.cat(block_means, dim=1))
    return per_axis


def gen_w(d, l, p):
    """Draw a random weight tensor with the Kronecker/shift structure.

    Two standard-normal tensors are sampled, one of shape
    (l[0], l[1], l[2]) and one of shape (p[0], p[1], p[2]). Their Kronecker
    product is contracted along each axis with the matrices returned by
    ``u_tensor(d, l, p)``.

    Returns:
        The contracted 3-D tensor.
    """
    kernel = torch.randn(l[0], l[1], l[2])
    pool_weights = torch.randn(p[0], p[1], p[2])
    u_first, u_second, u_third = u_tensor(d, l, p)
    core = kronecker(pool_weights, kernel)
    return torch.einsum('abc, da, eb, fc->def', core, u_first, u_second, u_third)


def gen_y(w, x):
    """Return the noisy response: sum of the elementwise product plus N(0, 1) noise.

    The result has shape (1,) because the noise is drawn with ``torch.randn(1)``.
    """
    signal = (x * w).sum()
    noise = torch.randn(1)
    return signal + noise


def gen_y1(w, x):
    """Return the noise-free response: the sum of the elementwise product of x and w."""
    weighted = x * w
    return weighted.sum()


def rmse(A_est, A_true):
    """Return the root-mean-square error between two tensors as a Python float."""
    residual = A_est - A_true
    mean_squared = torch.mean(residual ** 2)
    return math.sqrt(float(mean_squared))


def distanceF(A_est, A_true):
    """Return the Frobenius distance (root of the summed squared differences) as a float."""
    residual = A_est - A_true
    sum_squared = torch.sum(residual ** 2)
    return math.sqrt(float(sum_squared))


def netW(net, d, l, p):
    """Rebuild the full weight tensor implied by a network's parameters.

    The convolution kernel ``net.conv1.weight[0, 0]`` and the linear weights
    ``net.fc1.weight`` (viewed as a (p[0], p[1], p[2]) tensor) are combined
    with a Kronecker product and contracted with the matrices from
    ``u_tensor(d, l, p)``, using the same einsum as ``gen_w``.

    Args:
        net: object exposing ``conv1.weight`` and ``fc1.weight``.
        d, l, p: triples forwarded to ``u_tensor``; p also fixes the view
            of the linear weights.
    """
    kernel = net.conv1.weight[0, 0, :, :, :]
    readout = net.fc1.weight.view(1, p[0], p[1], p[2])[0, :, :, :]
    u_first, u_second, u_third = u_tensor(d, l, p)
    core = kronecker(readout, kernel)
    return torch.einsum('abc, da, eb, fc->def', core, u_first, u_second, u_third)


class gen_dataset(Dataset):
    """In-memory dataset of standard-normal inputs with noisy responses from ``gen_y``.

    Args:
        d: triple giving the shape (d[0], d[1], d[2]) of every input.
        w: weight tensor passed to ``gen_y``.
        sample_size: number of (input, response) pairs to generate.
    """

    def __init__(self, d, w, sample_size):
        pairs = []
        for _ in range(sample_size):
            # Draw the input first, then its response, one sample at a time.
            features = torch.randn(d[0], d[1], d[2])
            pairs.append((features, gen_y(w, features)))
        self.X = [features for features, _ in pairs]
        self.y = [response for _, response in pairs]

    def __getitem__(self, idx):
        """Return the (input, response) pair stored at position ``idx``."""
        features = self.X[idx]
        response = self.y[idx]
        return features, response

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.X)


class gen_datasetval(Dataset):
    """In-memory dataset of standard-normal inputs with responses from ``gen_y1``.

    Args:
        d: triple giving the shape (d[0], d[1], d[2]) of every input.
        w: weight tensor passed to ``gen_y1``.
        sample_size: number of (input, response) pairs to generate.
    """

    def __init__(self, d, w, sample_size):
        self.X = [torch.randn(d[0], d[1], d[2]) for _ in range(sample_size)]
        self.y = [gen_y1(w, features) for features in self.X]

    def __getitem__(self, idx):
        """Return the (input, response) pair stored at position ``idx``."""
        features = self.X[idx]
        response = self.y[idx]
        return features, response

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.X)


# Experiment driver: for each sample size, repeatedly fit a freshly
# initialised CnnTensor to synthetic data and record (a) the Frobenius error
# of the recovered weight tensor and (b) the RMSE on noise-free validation
# data. The results are pickled to 'time_dep.pkl'.
from torch.autograd import Variable  # no longer used below; import retained
import torch.optim as optim

dlist = [7, 5, 7]  # passed as d to gen_w / netW / the datasets and as d1..d3 to CnnTensor
llist = [2, 2, 2]  # passed as l (kernel sizes)
plist = [3, 2, 3]  # passed as p (pooling divisors)
sam_list = [1200, 560, 300, 200, 130, 100, 75]

ee = []  # per sample size: list of Frobenius errors, one per repetition
pp = []  # per sample size: list of prediction RMSEs, one per repetition

# Use the GPU when one is visible; otherwise run on the CPU instead of
# failing inside .cuda().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

start_time = perf_counter()
for element in sam_list:
    sample_size = element
    w = gen_w(dlist, llist, plist)
    batch_size = sample_size  # one batch per epoch (full-batch updates)
    names = {}
    names["estError"] = []  # never filled below
    names["predError"] = []
    names["WerrFro"] = []
    lr1 = 0.01
    for rep in tqdm(range(200)):
        train_loader = DataLoader(
            gen_dataset(dlist, w, sample_size), batch_size=batch_size, shuffle=False)
        # Build the network from the configured sizes so it cannot drift from
        # the values used to generate the data.
        net = CnnTensor(*dlist, *llist, *plist)
        net.to(device)
        criterion = nn.MSELoss()
        optimizer = optim.SGD(net.parameters(), lr=lr1, momentum=0.9)

        # Train until the last-batch loss changes by at most 1e-9 between
        # consecutive epochs. The first epoch's loss is compared against 0.
        loss_last = 0
        loss_new = 1000
        epoch = 0
        while abs(loss_last - loss_new) > 0.000000001:
            if epoch > 0:
                loss_last = loss_new
            for _x, _y in train_loader:
                # Insert the channel axis expected by Conv3d.
                _x = torch.unsqueeze(_x.float(), 1).to(device)
                _y = torch.squeeze(_y.float()).to(device)
                yhat = torch.squeeze(net(_x).float())
                loss = criterion(yhat, _y)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                loss_new = loss.item()
            epoch += 1

        # 100 validation samples with batch_size=100 -> exactly one batch.
        val_loader = DataLoader(
            gen_datasetval(dlist, w, 100), batch_size=100, shuffle=False)
        X_val, y_val = next(iter(val_loader))
        X_val = torch.unsqueeze(X_val.float(), 1).to(device)
        y_val = torch.squeeze(y_val.float()).to(device)

        # Evaluation only: no autograd graph is needed.
        with torch.no_grad():
            # netW mixes the parameters with CPU tensors from u_tensor, so the
            # network is moved to the CPU for this step and back afterwards.
            estfro = distanceF(netW(net.cpu(), dlist, llist, plist), w)
            net.to(device)
        print("Frobenius error of W is {}.".format(estfro))
        with torch.no_grad():
            predError = rmse(torch.squeeze(net(X_val).float()), y_val)
        print("PredError is {}.".format(predError))
        names["predError"].append(predError)
        names["WerrFro"].append(estfro)
        if rep % 50 == 0:
            print("aa")

    ee.append(names["WerrFro"])
    pp.append(names["predError"])

stop_time = perf_counter()
print(stop_time - start_time)

l = [ee, pp]

with open('time_dep.pkl', 'wb') as f:
    pickle.dump(l, f)
