import os
import pathlib
import torch
import torchvision

from . import common


# Absolute directory containing this file. The MNIST download cache ('data/MNIST') and the
# processed weight-trajectory cache ('processed_data/weights') are both stored relative to it.
_here = pathlib.Path(__file__).resolve().parent


def generate_weights_data(num_epochs=100, device=None):
    """Train a small CNN on MNIST and record its flattened parameters after every epoch.

    Each call initialises and trains a fresh model, so repeated calls produce independent
    trajectories.

    Arguments:
        num_epochs: Number of passes over the MNIST training set; also the length of each
            returned time series. Defaults to 100, the length `weights_data` has historically
            assumed.
        device: Device to train on. Defaults to 'cuda' when available, otherwise 'cpu'.

    Returns:
        A tensor of shape (parameters, num_epochs, 1) on `device`: one univariate time series
        per model parameter, tracing that parameter's value at the end of each epoch.

    Raises:
        ValueError: if `num_epochs` is not positive (there would be nothing to stack).
    """
    if num_epochs < 1:
        raise ValueError(f"num_epochs must be positive, got {num_epochs}")
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    train_dataset = torchvision.datasets.MNIST(_here / 'data' / 'MNIST', download=True,
                                               transform=torchvision.transforms.ToTensor())
    train_dataloader = common.dataloader(train_dataset, batch_size=2500)
    # 28x28 input -> 24x24 after the first 5x5 conv -> 20x20 after the second; 8 * 20 * 20 = 3200.
    model = torch.nn.Sequential(torch.nn.Conv2d(1, 8, 5),
                                torch.nn.ReLU(),
                                torch.nn.Conv2d(8, 8, 5),
                                torch.nn.ReLU(),
                                torch.nn.Flatten(),
                                torch.nn.Linear(3200, 10)).to(device)
    optimiser = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
    snapshots = []
    for _ in range(num_epochs):
        for x, y in train_dataloader:
            x = x.to(device)
            y = y.to(device)
            pred_y = model(x)
            loss = torch.nn.functional.cross_entropy(pred_y, y)
            loss.backward()
            optimiser.step()
            optimiser.zero_grad()

        with torch.no_grad():
            # Flattens and concatenates in model.parameters() order into a new tensor, so later
            # optimiser steps do not modify this snapshot.
            snapshots.append(torch.nn.utils.parameters_to_vector(model.parameters()))

    data = torch.stack(snapshots)  # shape (length=num_epochs, parameters)
    # shape (batch=parameters, length=num_epochs, channels=1)
    return data.transpose(0, 1).unsqueeze(-1)


def get_weights_data():
    """Load the cached weight-trajectory dataset, generating and caching it on a cache miss.

    On a miss this calls `generate_weights_data` 10 times (10 independently trained models),
    concatenates the results along the first (series) dimension, moves them to the CPU, and
    saves them under ``processed_data/weights/`` next to this file.

    Returns:
        A tuple ``(data, labels)``. ``data`` holds one time series per trained parameter, with
        every run contributing its own rows. ``labels`` has the same first dimension and zero
        columns, because this dataset is unlabelled.
    """
    data_folder = _here / 'processed_data' / 'weights'
    data_file = data_folder / 'weights.pt'
    labels_file = data_folder / 'labels.pt'
    # Require both files. If a previous run was interrupted between the two saves, only
    # weights.pt exists, and loading labels.pt would fail forever; regenerate instead.
    if data_file.exists() and labels_file.exists():
        data = torch.load(data_file)
        labels = torch.load(labels_file)
    else:
        data_folder.mkdir(parents=True, exist_ok=True)
        data = torch.cat([generate_weights_data() for _ in range(10)]).cpu()
        labels = torch.empty(data.size(0), 0, dtype=data.dtype)
        torch.save(data, data_file)
        torch.save(labels, labels_file)
    return data, labels


def weights_data(batch_size):
    """Build the weight-trajectory dataset and a dataloader over it.

    Arguments:
        batch_size: Batch size passed to `common.dataloader`.

    Returns:
        A tuple ``(t, dataloader, input_channels, label_channels)``:
        - ``t`` holds the observation times ``0, 1, ..., length - 1`` (one per recorded epoch).
        - ``dataloader`` yields ``(data, labels)`` batches.
        - The channel counts are read from the trailing dimensions of the data and labels
          tensors, which are 1 and 0 for this dataset.
    """
    data, labels = get_weights_data()
    # Derive the time grid and channel counts from the loaded tensors rather than hard-coding
    # them, so they cannot drift out of sync with how the (possibly cached) series were made.
    length = data.size(1)
    t = torch.linspace(0, length - 1, length)
    dataset = torch.utils.data.TensorDataset(data, labels)
    dataloader = common.dataloader(dataset, batch_size=batch_size)
    input_channels = data.size(-1)
    label_channels = labels.size(-1)
    return t, dataloader, input_channels, label_channels
