# finetune.py
import torch
from torch import nn, optim
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
import logging
from resnet import ResNet18
from vgg import vgg16

def setup_logger(log_path):
    """Send root-logger records at INFO level and above to the file *log_path*.

    The file is opened in ``'w'`` mode, so any previous contents are
    truncated. Each record is written as ``<timestamp> - <message>``.
    Configuration goes through ``logging.basicConfig``, which does nothing if
    the root logger already has handlers attached.
    """
    config = {
        'filename': log_path,
        'filemode': 'w',
        'format': '%(asctime)s - %(message)s',
        'level': logging.INFO,
    }
    logging.basicConfig(**config)

def _train_loop(model, dataloader, device, optimizer, scheduler, num_epochs, verbose):
    """Run the shared supervised training loop and return *model*.

    Each epoch iterates ``dataloader`` once. Every item must unpack as
    ``(inputs, labels, _)``; the third element is ignored. The loop takes one
    optimizer step per batch on the cross-entropy loss, then advances
    ``scheduler`` once per epoch. When ``verbose`` is true, the per-sample
    mean loss, accuracy and current learning rate are printed and logged
    every 10 epochs and after the final epoch.

    Raises:
        ValueError: if an epoch sees no samples (e.g. an empty dataloader),
            because the per-epoch averages would otherwise divide by zero.
    """
    model.train()
    for epoch in range(num_epochs):
        running_loss, total, correct = 0.0, 0, 0
        for inputs, labels, _ in dataloader:
            inputs = inputs.to(device)
            # cross_entropy needs int64 class indices; cast in case the
            # dataset yields a narrower integer dtype.
            labels = labels.to(device).long()

            optimizer.zero_grad()
            logits = model(inputs)
            loss = F.cross_entropy(logits, labels)
            loss.backward()
            optimizer.step()

            batch_size = inputs.size(0)
            # cross_entropy returns the batch mean by default; re-weight by
            # batch size so the epoch average is per-sample even when the
            # last batch is smaller.
            running_loss += loss.item() * batch_size
            correct += logits.argmax(dim=1).eq(labels).sum().item()
            total += batch_size

        if total == 0:
            raise ValueError(
                "dataloader yielded no samples; cannot compute epoch metrics"
            )

        scheduler.step()
        epoch_loss = running_loss / total
        epoch_acc = correct / total
        if verbose and ((epoch + 1) % 10 == 0 or epoch == num_epochs - 1):
            log_msg = f"Epoch {epoch+1}/{num_epochs} - Loss: {epoch_loss:.4f} - Acc: {epoch_acc:.4f} - LR: {scheduler.get_last_lr()[0]:.5f}"
            print(log_msg)
            logging.info(log_msg)
    return model


def train_vgg(model, dataloader, device, num_epochs=200, lr=0.1, verbose=True):
    """Train a freshly constructed ``vgg16(num_classes=200)`` with SGD.

    Args:
        model: currently ignored; a new VGG16 is always built (see note).
        dataloader: iterable of ``(inputs, labels, _)`` batches.
        device: device the new model and every batch are moved to.
        num_epochs: currently ignored; training always runs 10 epochs.
        lr: initial SGD learning rate.
        verbose: print/log progress every 10 epochs and at the end.

    Returns:
        The newly constructed, trained VGG16 model.

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    # NOTE(review): the incoming ``model`` and ``num_epochs`` are overridden
    # below (fresh 200-class VGG16, 10 epochs). The LR milestones at 100/150
    # are therefore never reached. This is kept as-is to preserve current
    # behaviour for existing callers; confirm whether it is a leftover debug
    # override.
    model = vgg16(num_classes=200).to(device)
    num_epochs = 10
    optimizer = optim.SGD(model.parameters(), lr=lr)
    scheduler = MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)
    return _train_loop(model, dataloader, device, optimizer, scheduler, num_epochs, verbose)


def train_resnet(model, mode, dataloader, device, num_epochs=None, lr=0.1, verbose=True):
    """Train a ResNet either from scratch or by fine-tuning ``model``.

    Args:
        model: network to fine-tune when ``mode`` is not ``'scratch'``. It is
            used as given and not moved to ``device`` here, so the caller
            must place it. Ignored in ``'scratch'`` mode.
        mode: ``'scratch'`` trains a freshly constructed ``ResNet18()`` with
            LR drops at epochs 100 and 150. Any other value is treated as
            ``'transfer'``, with LR drops at epochs 25 and 35.
        dataloader: iterable of ``(inputs, labels, _)`` batches.
        device: target device for the scratch model and for every batch.
        num_epochs: number of epochs. ``None`` (the default) selects 200 for
            ``'scratch'`` and 40 for ``'transfer'``, matching the previous
            fixed schedules. Milestones are fixed per mode and are not
            rescaled when this is overridden.
        lr: initial SGD learning rate (multiplied by 0.1 at each milestone).
        verbose: print/log progress every 10 epochs and at the end.

    Returns:
        The trained model: the object passed in, unless ``mode`` is
        ``'scratch'``.

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    if mode == 'scratch':
        model = ResNet18().to(device)
        milestones, default_epochs = [100, 150], 200
    else:  # mode == 'transfer' (any non-'scratch' value lands here)
        milestones, default_epochs = [25, 35], 40

    # Previously the mode-specific epoch count silently overwrote any value
    # the caller passed; now it is only the fallback.
    if num_epochs is None:
        num_epochs = default_epochs

    optimizer = optim.SGD(model.parameters(), lr=lr)
    scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=0.1)
    return _train_loop(model, dataloader, device, optimizer, scheduler, num_epochs, verbose)