import math
import os
import random
from collections import deque
from datetime import datetime
from functools import partial

import torch
import torch.nn as nn
from torch.optim import Adam
from torchvision import transforms
from torchvision.datasets import MNIST

from utils import run_concurrent


# Images are resized to PICTURE_SIZE x PICTURE_SIZE and flattened into
# INPUT_SIZE features per sample.
PICTURE_SIZE = 16
INPUT_SIZE = PICTURE_SIZE * PICTURE_SIZE
# Number of output logits (one per MNIST digit class).
OUTPUT_SIZE = 10
# Width of the hidden layer.
HIDDEN_SIZE = 1000
# Number of training samples drawn from MNIST in main().
SAMPLES = 1000
# Budget of trainable first-layer weights handed out by create_mask().
WEIGHT_VARS = OUTPUT_SIZE * SAMPLES
# Initial-weight scale: rand() samples with variance G / order.
G = 1
# Number of tries passed to run_concurrent (presumably runs per k).
TRIES = 400

# Upper bound on training epochs per run.
EPOCHS = 10_000_000
# A run is declared stalled when its running-minimum loss improves by less than
# LOSS_FAILURE_THRESHOLD over (roughly) the last LOSS_THRESHOLD_EPOCHS epochs.
LOSS_FAILURE_THRESHOLD = 1e-8
LOSS_THRESHOLD_EPOCHS = 10_000


def time_now():
    """Return the current local time as a filesystem-friendly string.

    The format is ``YYYY-MM-DD_HH-MM-SS[.ffffff]``: the date/time separator is
    ``_`` and colons are replaced by ``-`` so the result is safe in file names.
    """
    stamp = datetime.now().isoformat(sep='_')
    return stamp.replace(':', '-')


# Per-invocation log directory: <script dir>/log/<timestamp>.
LOG_DIR = os.path.join(
    os.path.dirname(__file__),
    'log',
    time_now(),
)


class Net(nn.Module):
    """Two-layer ReLU network whose second layer is never trained.

    ``w`` (INPUT_SIZE x HIDDEN_SIZE) is a trainable parameter; ``d``
    (HIDDEN_SIZE x OUTPUT_SIZE) is created with ``requires_grad=False``.
    Both are initialised by ``rand`` with variance ``G / out_features``.
    """

    def __init__(self, device):
        super().__init__()
        # Creation order (w, then d) fixes the RNG draw order; keep it stable.
        self.w = self._linear(INPUT_SIZE, HIDDEN_SIZE, grad=True, device=device)
        self.relu = nn.ReLU()
        self.d = self._linear(HIDDEN_SIZE, OUTPUT_SIZE, grad=False, device=device)

    def _linear(self, *sizes, grad=False, device='cpu'):
        """Create a randomly initialised weight parameter of shape ``sizes``."""
        out_features = sizes[1]
        weights = rand(*sizes, order=out_features, device=device)
        return nn.Parameter(weights, requires_grad=grad)

    def forward(self, x):
        """Return logits ``relu(x @ w) @ d``."""
        hidden = self.relu(x @ self.w)
        return hidden @ self.d

    def mask_grad(self, mask):
        """Zero the gradient of ``w`` wherever ``mask`` is True (frozen entries)."""
        self.w.grad = self.w.grad.masked_fill(mask, 0)


def rand(*size, order=1, device='cpu'):
    """Sample a tensor of shape ``size`` from N(0, G / order), then move it to ``device``.

    ``order`` divides the variance; ``Net`` passes the layer's output width.
    """
    variance = G / order
    samples = torch.normal(0.0, math.sqrt(variance), size)
    return samples.to(device)


def get_device(thread_num):
    """Pick a device for worker ``thread_num``.

    Round-robins workers over the visible CUDA devices; falls back to
    ``'cpu'`` when CUDA is unavailable.
    """
    if torch.cuda.is_available():
        gpu_index = thread_num % torch.cuda.device_count()
        return f'cuda:{gpu_index}'
    return 'cpu'


def create_mask(k, device, *, weight_vars=None, rows=None, cols=None):
    """Build the boolean freeze mask for the first-layer weights.

    Exactly ``weight_vars`` entries are left trainable (``False``), spread over
    the first ``k`` columns (hidden units). Every one of those columns gets
    ``weight_vars // k`` entries, and the first ``weight_vars % k`` columns get
    one extra. Within a column the trainable rows are chosen with
    ``random.sample``. All other entries are ``True`` (frozen).

    Args:
        k: number of columns that receive trainable entries.
        device: torch device for the returned mask.
        weight_vars: total trainable entries (default ``WEIGHT_VARS``).
        rows: mask height (default ``INPUT_SIZE``).
        cols: mask width (default ``HIDDEN_SIZE``).

    Returns:
        A ``(rows, cols)`` bool tensor; ``True`` marks a frozen weight.

    Raises:
        ValueError: if ``k`` is outside ``1..cols``, ``weight_vars`` is
            negative, or a column would need more than ``rows`` entries.
    """
    weight_vars = WEIGHT_VARS if weight_vars is None else weight_vars
    rows = INPUT_SIZE if rows is None else rows
    cols = HIDDEN_SIZE if cols is None else cols

    if not 1 <= k <= cols:
        raise ValueError(f'k must be in [1, {cols}], got {k}')
    if weight_vars < 0:
        raise ValueError(f'weight_vars must be non-negative, got {weight_vars}')

    # Floor division alone would silently drop weight_vars % k entries, making the
    # trainable budget depend on k; hand the remainder out one per column instead.
    base, extra = divmod(weight_vars, k)
    largest = base + (1 if extra else 0)
    if largest > rows:
        raise ValueError(
            f'k={k} needs {largest} trainable entries per column, but only {rows} rows exist')

    mask = torch.ones((rows, cols), dtype=torch.bool, device=device)
    for col in range(k):
        count = base + (1 if col < extra else 0)
        if count:
            mask[random.sample(range(rows), k=count), col] = False
    return mask


def correct_count(preds, y):
    """Return how many ``preds`` equal ``y``, as a new 0-d int64 tensor.

    The count is pulled to a Python int first, so the result lives on the
    default device regardless of where ``preds``/``y`` are.
    """
    matches = int((preds == y).sum())
    return torch.tensor(matches)


def run_epoch(x, y, model, optimizer, freeze_mask, criterion):
    """Run one full-batch optimisation step.

    Accuracy is measured on the outputs computed *before* the step. Gradients
    of frozen weights are zeroed through ``model.mask_grad(freeze_mask)`` just
    before ``optimizer.step()``.

    Returns:
        ``(loss_value, correct)``: ``loss.item()`` divided by the number of
        samples, and the result of ``correct_count``.
    """
    logits = model(x)
    loss = criterion(logits, y)
    _, predicted = torch.max(logits, 1)
    correct = correct_count(predicted, y)

    optimizer.zero_grad()
    loss.backward()
    model.mask_grad(freeze_mask)
    optimizer.step()

    # NOTE(review): nn.CrossEntropyLoss() (as constructed in train) already averages
    # over the batch, so this extra division reports mean / n_samples. The stall
    # threshold in train is applied to this scaled value — confirm it is intended.
    n_samples = y.shape[0]
    return loss.item() / n_samples, correct


def train(k, i, x, y):
    """Train one freshly initialised ``Net`` until it fits the data or stalls.

    Only the first-layer weights left unfrozen by ``create_mask(k, ...)`` are
    updated. Every epoch is logged, one line per epoch, to a file under ``LOG_DIR``.

    Args:
        k: number of hidden units that receive trainable weights.
        i: worker index; selects the device via ``get_device``.
        x: input samples — presumably (n_samples, INPUT_SIZE) as built by main().
        y: integer class labels, one per sample.

    Returns:
        The epoch at which every sample is classified correctly, or ``None``
        if the running-minimum loss stalls or ``EPOCHS`` is exhausted.
    """
    device = get_device(i)
    model = Net(device)
    mask = create_mask(k, device)
    model.to(device)
    x = x.to(device)
    y = y.to(device)

    optimizer = Adam(model.parameters())
    criterion = nn.CrossEntropyLoss()
    n_samples = y.shape[0]

    # Running minimum of the loss. The stall check only looks LOSS_THRESHOLD_EPOCHS
    # entries back, so a bounded deque keeps memory O(window) instead of growing
    # with EPOCHS. Once full, window[0] is exactly what list[-LOSS_THRESHOLD_EPOCHS]
    # used to be.
    min_loss_by_epoch = deque([float('inf')], maxlen=LOSS_THRESHOLD_EPOCHS)
    with open(os.path.join(LOG_DIR, f'k{k}-{time_now()}'), 'w') as log_file:
        for epoch in range(EPOCHS):
            loss, correct = run_epoch(x, y, model, optimizer, mask, criterion)
            min_loss_by_epoch.append(min(loss, min_loss_by_epoch[-1]))
            n_correct = int(correct)

            # One record per line. Text-mode files translate '\n' to the platform
            # line ending; writing os.linesep would yield '\r\r\n' on Windows.
            log_file.write(
                f'epoch {epoch}, loss {loss}, '
                f'min by now {min_loss_by_epoch[-1]}, correct {n_correct}\n'
            )

            if n_correct == n_samples:
                print(f'k={k} success in {epoch} round with loss={min_loss_by_epoch[-1]}')
                return epoch

            # Stalled: the running minimum barely improved across the window.
            # When epoch >= LOSS_THRESHOLD_EPOCHS the deque is guaranteed full.
            if (epoch >= LOSS_THRESHOLD_EPOCHS and
                    min_loss_by_epoch[0] - min_loss_by_epoch[-1] < LOSS_FAILURE_THRESHOLD):
                print(f'k={k} failure in {epoch} round with loss={min_loss_by_epoch[-1]}')
                return None

    print(f'k={k} failure by rounds with loss={min_loss_by_epoch[-1]}')
    return None


def main():
    """Sample SAMPLES MNIST images and launch TRIES training runs per k value.

    Images are resized to PICTURE_SIZE and flattened to INPUT_SIZE features.
    The k grid is clipped to the values ``create_mask`` can satisfy.
    """
    root = os.path.join(os.path.dirname(__file__), 'data')
    transform = transforms.Compose([transforms.Resize(PICTURE_SIZE), transforms.ToTensor()])
    mnist_dataset = MNIST(root=root, train=True, transform=transform)
    indices = random.sample(range(len(mnist_dataset)), k=SAMPLES)
    data = [mnist_dataset[idx] for idx in indices]
    x = torch.cat([d[0] for d in data]).reshape(SAMPLES, INPUT_SIZE)
    y = torch.tensor([d[1] for d in data], dtype=torch.long)

    os.makedirs(LOG_DIR)
    print(
        f'running with n={HIDDEN_SIZE}, g={G}, m={SAMPLES}, b={INPUT_SIZE}, {TRIES} tries', flush=True)
    # Below min_k a hidden unit would need more than INPUT_SIZE trainable weights,
    # and above HIDDEN_SIZE there are no more hidden units, so neither can be
    # masked. Filter the hard-coded grid so it stays valid if the constants change.
    min_k = math.ceil(WEIGHT_VARS / INPUT_SIZE)
    candidate_ks = list(range(min_k, 100, 10)) + list(range(100, HIDDEN_SIZE + 1, 100))
    ks = [k for k in candidate_ks if min_k <= k <= HIDDEN_SIZE]
    run_concurrent(ks, TRIES, partial(train, x=x, y=y))
    # Single-run debugging alternative: train(HIDDEN_SIZE, 1, x, y)


# Script entry point: run the full experiment sweep when executed directly.
if __name__ == '__main__':
    main()
