import torch
import torch.nn as nn
import torch.optim as optim

import math

from dataset import read_dataset
from model import SimpleLinear

def training_setup(X, lr, factor, patience):
    """Create the model, optimizer and LR scheduler for training on ``X``.

    The model is a ``SimpleLinear`` sized by ``X.shape[1]``. It is paired
    with an Adam optimizer over its parameters and a ``ReduceLROnPlateau``
    scheduler wrapping that optimizer.

    Args:
        X: Training inputs; only ``X.shape[1]`` is read here.
        lr: Initial learning rate for Adam.
        factor: Multiplicative LR decay passed to ``ReduceLROnPlateau``.
        patience: Patience passed to ``ReduceLROnPlateau``.

    Returns:
        Tuple ``(model, optimizer, scheduler)``.
    """
    n_features = X.shape[1]
    model = SimpleLinear(n_features)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        factor=factor,
        patience=patience,
    )
    return model, optimizer, scheduler

def train(model, optimizer, scheduler, X, y, epochs, batch_size, reg):
    """Train ``model`` on shuffled mini-batches and return the final epoch loss.

    Each epoch draws a fresh random permutation of the rows of ``X``/``y``.
    It then runs one optimizer step per mini-batch and calls
    ``scheduler.step`` with the epoch's mean batch loss. That is the
    metric-driven ``step`` signature used by ``ReduceLROnPlateau``.

    Args:
        model: Module exposing ``loss(data, targets)`` and
            ``log_loss(data, targets, reg)``, each returning a single-element
            tensor (``.item()`` is called on it).
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: Scheduler whose ``step`` accepts a metric value.
        X: Input tensor whose first dimension indexes samples.
        y: Target tensor with the same number of rows as ``X``.
        epochs: Number of passes over the data.
        batch_size: Positive number of rows per mini-batch. The last batch
            may be smaller.
        reg: If ``None``, ``model.loss`` is used. Otherwise ``reg`` is passed
            as the third argument to ``model.log_loss``.

    Returns:
        The mean of the per-batch losses in the last epoch, with each batch
        weighted equally. Returns ``nan`` when ``epochs`` is less than 1.

    Raises:
        ValueError: If ``batch_size`` is not positive, if ``X`` has no rows,
            or if ``X`` and ``y`` have different numbers of rows.
    """
    n_samples = X.shape[0]
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if n_samples == 0:
        raise ValueError("cannot train on an empty dataset")
    if y.shape[0] != n_samples:
        raise ValueError(
            f"X and y have different numbers of rows: {n_samples} vs {y.shape[0]}"
        )

    n_batches = math.ceil(n_samples / batch_size)
    # Initialised up front so that epochs < 1 returns a value instead of
    # failing with UnboundLocalError on the return statement.
    epoch_loss = float("nan")

    model.train()
    for _ in range(epochs):
        # Gather each batch through the permutation instead of rebinding X/y
        # to shuffled full copies. The batches produced are identical.
        perm = torch.randperm(n_samples)

        total_loss = 0.0
        for start in range(0, n_samples, batch_size):
            batch_idx = perm[start:start + batch_size]
            data = X[batch_idx]
            targets = y[batch_idx]

            optimizer.zero_grad()
            if reg is None:
                loss = model.loss(data, targets)
            else:
                loss = model.log_loss(data, targets, reg)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        epoch_loss = total_loss / n_batches
        scheduler.step(epoch_loss)

    return epoch_loss
