import math
import os
from collections import deque
from datetime import datetime

import torch
import torch.nn as nn

from utils import randomize_locations, run_concurrent
from ada_hessian import AdaHessian


# Problem dimensions; main() reports them as n, d, b and m respectively.
HIDDEN_SIZE = 8  # rows of the trainable first-layer matrix Net.w (n)
OUTPUT_SIZE = 4  # rows of the fixed second-layer matrix Net.d (d)
INPUT_SIZE = 6  # columns of Net.w and rows of the input batch x (b)
SAMPLES = INPUT_SIZE  # columns of the input batch x produced by create_labels (m)
# Passed to randomize_locations as the number of locations create_mask picks;
# presumably the count of trainable (unfrozen) entries of Net.w — TODO confirm.
WEIGHT_VARS = OUTPUT_SIZE * SAMPLES
G = 1  # init scale: rand() samples with variance G / order
TRIES = 400  # passed to run_concurrent; main() reports it as the number of tries

EPOCHS = int(1e7)  # hard cap on optimisation steps per train() call
LOSS_SUCCESS_THRESHOLD = 1e-2  # train() succeeds once the running-minimum loss drops below this
LOSS_FAILURE_THRESHOLD = 1e-8  # train() fails if the running minimum improves less than this over the window
LOSS_THRESHOLD_EPOCHS = 10000  # length of train()'s plateau-detection window, in epochs


def time_now():
    """Return the current local time as a filesystem-friendly string.

    The text is the same as ``str(datetime.now())``, except that the
    date/time separator becomes ``_`` and every ``:`` becomes ``-``.
    """
    # For a naive datetime, str(dt) is exactly dt.isoformat(sep=' ').
    stamp = datetime.now().isoformat(sep='_')
    return stamp.replace(':', '-')


# Log directory for this invocation: <script directory>/log/<start timestamp>.
# main() creates it; each train() call writes its own log file inside it.
LOG_DIR = os.path.join(os.path.dirname(__file__), 'log', time_now())


class Net(nn.Module):
    """Bias-free two-layer ReLU network computing ``d @ relu(w @ x)``.

    ``w`` has shape (HIDDEN_SIZE, INPUT_SIZE) and is trainable. ``d`` has
    shape (OUTPUT_SIZE, HIDDEN_SIZE) and is registered with
    ``requires_grad=False``. Both are drawn by ``rand`` with ``order`` equal
    to the layer's input width.
    """

    def __init__(self, device):
        super().__init__()
        # w is drawn before d; keep this order so the random draws line up.
        self.w = self._linear(HIDDEN_SIZE, INPUT_SIZE, grad=True, device=device)
        self.relu = nn.ReLU()
        self.d = self._linear(OUTPUT_SIZE, HIDDEN_SIZE, grad=False, device=device)

    def _linear(self, *sizes, grad=False, device='cpu'):
        """Return a ``sizes``-shaped Parameter sampled by ``rand`` with order ``sizes[1]``."""
        fan_in = sizes[1]
        weights = rand(*sizes, order=fan_in, device=device)
        return nn.Parameter(weights, requires_grad=grad)

    def forward(self, x):
        """Apply ``w``, then ReLU, then ``d`` to ``x``."""
        hidden = self.relu(self.w @ x)
        return self.d @ hidden

    def mask_grad(self, mask):
        """Replace ``w.grad`` with a copy that is zero wherever ``mask`` is True."""
        self.w.grad = torch.masked_fill(self.w.grad, mask, 0)


def rand(*size, order=1, device='cpu'):
    """Sample a ``size``-shaped tensor with mean 0 and variance ``G / order``.

    Sampling happens on torch's default device, before the result is moved,
    so ``device`` only controls where the returned tensor lives.
    """
    std = math.sqrt(G / order)
    sample = torch.normal(0.0, std, size)
    return sample.to(device)


def create_labels(device):
    """Generate a random input batch and the matching teacher targets.

    The inputs are drawn by ``rand`` with shape (INPUT_SIZE, SAMPLES). The
    targets are the output of a freshly initialised ``Net`` (the teacher)
    applied to them with autograd disabled.

    Returns:
        Tuple ``(x, y)`` of inputs and targets.
    """
    inputs = rand(INPUT_SIZE, SAMPLES, device=device)
    with torch.no_grad():
        teacher = Net(device)
        targets = teacher(inputs)
    return inputs, targets


def get_device(thread_num):
    """Pick a device for ``thread_num``.

    Threads are spread round-robin across the visible CUDA devices. When
    CUDA is unavailable, ``'cpu'`` is returned.
    """
    if torch.cuda.is_available():
        gpu_index = thread_num % torch.cuda.device_count()
        return f'cuda:{gpu_index}'
    return 'cpu'


def create_mask(k, device):
    """Build the boolean freeze mask for ``Net.w``.

    Returns a (HIDDEN_SIZE, INPUT_SIZE) bool tensor on ``device``. It is True
    everywhere except at the locations chosen by ``randomize_locations``.
    The True entries are the ones whose gradient ``Net.mask_grad`` zeroes.
    """
    chosen, _ = randomize_locations(INPUT_SIZE, k, WEIGHT_VARS, max_per_row=SAMPLES)
    frozen = torch.full((HIDDEN_SIZE, INPUT_SIZE), True, device=device)
    # Transpose the location tuples into per-axis index sequences
    # (assumes randomize_locations yields (row, col) pairs — TODO confirm).
    index = list(zip(*chosen))
    frozen[index] = False
    return frozen


def run_epoch(x, y, model, optimizer, freeze_mask, criterion):
    """Run one full-batch optimisation step on ``(x, y)``.

    The gradient of frozen entries (True in ``freeze_mask``) is zeroed
    before ``optimizer.step()``.

    Returns:
        ``criterion(model(x), y)`` as a Python float, divided by ``y.shape[0]``.
        NOTE(review): train() passes ``nn.MSELoss()``, which already averages by
        default, so this extra division rescales the value that train()
        compares against its thresholds.
    """
    batch_loss = criterion(model(x), y)

    optimizer.zero_grad()
    # create_graph=True is needed for AdaHessian.
    batch_loss.backward(create_graph=True)
    model.mask_grad(freeze_mask)
    optimizer.step()

    num_rows = y.shape[0]
    return batch_loss.item() / num_rows


def train(k, i):
    """Train one masked student network for ``k`` and report the outcome.

    Builds a student ``Net`` and a freeze mask from ``create_mask(k, ...)``.
    It takes targets from a separate random teacher (``create_labels``) and
    optimises the student with AdaHessian until one of three things happens:

    * the running-minimum loss drops below ``LOSS_SUCCESS_THRESHOLD``
      (success);
    * after at least ``LOSS_THRESHOLD_EPOCHS`` epochs, the running minimum
      has improved by less than ``LOSS_FAILURE_THRESHOLD`` across the window
      of the last ``LOSS_THRESHOLD_EPOCHS`` recorded values (plateau failure);
    * ``EPOCHS`` epochs run out (failure).

    Every epoch is logged to a per-run file under ``LOG_DIR``.

    Args:
        k: Passed to ``create_mask``; also used in log file names and messages.
        i: Run index; used as ``thread_num`` for ``get_device`` and added to
            the log file name.

    Returns:
        The epoch at which the success threshold was reached, or None on failure.
    """
    device = get_device(i)
    model = Net(device)
    mask = create_mask(k, device)
    model.to(device)

    x, y = create_labels(device)

    optimizer = AdaHessian(model.parameters())
    criterion = nn.MSELoss()

    best_loss = float('inf')  # running minimum of the per-epoch loss
    # The plateau test only ever reads the running minimum from the start of
    # the window. A bounded deque keeps memory constant instead of storing
    # one float per epoch (up to EPOCHS values per concurrent run).
    recent_best = deque([best_loss], maxlen=LOSS_THRESHOLD_EPOCHS)
    # Include the run index: concurrent runs with the same k can start within
    # the clock's resolution, and 'w' mode would make them clobber one file.
    log_path = os.path.join(LOG_DIR, f'k{k}-{time_now()}-{i}')
    with open(log_path, 'w') as log_file:
        for epoch in range(EPOCHS):
            loss = run_epoch(x, y, model, optimizer, mask, criterion)
            best_loss = min(loss, best_loss)
            recent_best.append(best_loss)

            # Text-mode files already translate '\n' to the platform newline;
            # writing os.linesep would produce '\r\r\n' on Windows.
            log_file.write(f'epoch {epoch}, loss {loss}, min by now {best_loss}\n')

            if best_loss < LOSS_SUCCESS_THRESHOLD:
                print(f'k={k} success in {epoch} round with loss={best_loss}')
                return epoch

            # Once the deque is full, recent_best[0] is the running minimum
            # from LOSS_THRESHOLD_EPOCHS - 1 epochs ago.
            if (epoch >= LOSS_THRESHOLD_EPOCHS and
                    recent_best[0] - best_loss < LOSS_FAILURE_THRESHOLD):
                print(f'k={k} failure in {epoch} round with loss={best_loss}')
                return None

    print(f'k={k} failure by rounds with loss={best_loss}')
    return None


def main():
    """Create the log directory and launch the training runs.

    Runs ``train`` through ``run_concurrent`` for every k from
    ``ceil(WEIGHT_VARS / INPUT_SIZE)`` to ``HIDDEN_SIZE`` inclusive, with
    ``TRIES`` passed along. ``os.makedirs`` raises FileExistsError if
    ``LOG_DIR`` already exists.
    """
    os.makedirs(LOG_DIR)
    settings = (
        f'running with n={HIDDEN_SIZE}, d={OUTPUT_SIZE}, b={INPUT_SIZE}, '
        f'g={G}, m={SAMPLES}, {TRIES} tries'
    )
    print(settings, flush=True)
    lowest_k = math.ceil(WEIGHT_VARS / INPUT_SIZE)
    run_concurrent(range(lowest_k, HIDDEN_SIZE + 1), TRIES, train)


# Launch the experiment only when executed as a script, not on import.
if __name__ == '__main__':
    main()