from torch import nn
import torch
from tqdm import tqdm
from utils import *

def eval_ann(test_dataloader, model, epoch, device=None):
    """Evaluate a classifier, print accuracy and mean validation loss.

    Args:
        test_dataloader: iterable of ``(img, label)`` batches.
        model: network to evaluate. It is moved to the evaluation device and
            switched to eval mode, and is left that way on return.
        epoch: epoch index; only used in the printed message.
        device: device to evaluate on. When ``None`` (default), the model's
            current CUDA device is reused if its parameters already live on a
            GPU; otherwise the default CUDA device is used.

    Returns:
        Top-1 accuracy as a float in ``[0, 1]``.

    Raises:
        ValueError: if ``test_dataloader`` yields no samples.
    """
    if device is None:
        # Reuse the GPU the model already lives on: a bare ``model.cuda()``
        # would move it to the *current* device, stranding a model that the
        # caller placed on e.g. cuda:1 (and its optimizer state).
        param = next(model.parameters(), None)
        if param is not None and param.is_cuda:
            device = param.device
        else:
            device = torch.device("cuda")

    # Summed (not averaged) loss, so dividing by the sample count below gives
    # the true per-sample mean even when the last batch is smaller.
    loss_fn = nn.CrossEntropyLoss(reduction="sum")
    model.eval()
    model.to(device)

    total_loss = 0.0
    correct = 0
    n_samples = 0
    with torch.no_grad():
        for img, label in test_dataloader:
            img = img.to(device)
            label = label.to(device)
            out = model(img)
            total_loss += loss_fn(out, label).item()
            correct += (out.argmax(1) == label).sum().item()
            n_samples += len(img)

    if n_samples == 0:
        raise ValueError("test_dataloader yielded no samples")

    acc = correct / n_samples
    print(f"Epoch {epoch}: Acc: {acc} Val_loss: {total_loss / n_samples}")
    return acc

def train_ann(train_dataloader, test_dataloader, model, epochs, lr, wd, device):
    """Train a classifier with SGD and a cosine LR schedule.

    After every epoch the model is evaluated on ``test_dataloader`` via
    ``eval_ann``.

    Args:
        train_dataloader: iterable of ``(img, label)`` training batches.
        test_dataloader: iterable of ``(img, label)`` evaluation batches.
        model: network to train; moved to ``device`` with ``.cuda(device)``.
        epochs: number of epochs; also the cosine schedule's ``T_max``.
        lr: initial learning rate for SGD (momentum 0.9).
        wd: weight decay for SGD.
        device: CUDA device passed to ``.cuda()``.

    Returns:
        The trained model.
    """
    model = model.cuda(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=wd)
    # Cosine annealing over the whole run; stepped once per epoch below.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    loss_fn = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        model.train()
        for img, label in tqdm(train_dataloader):
            img = img.cuda(device)
            label = label.cuda(device)
            optimizer.zero_grad()
            loss = loss_fn(model(img), label)
            loss.backward()
            optimizer.step()
        # eval_ann may relocate the model with a bare ``model.cuda()``, which
        # targets the *current* CUDA device. Make ``device`` current while it
        # runs so the model (and the optimizer's momentum buffers, which are
        # not moved) stay on the same GPU for the next epoch.
        with torch.cuda.device(device):
            eval_ann(test_dataloader, model, epoch)
        scheduler.step()
    return model

def eval_snn(test_dataloader, model, sim_len, device):
    """Evaluate a spiking network over ``sim_len`` simulation steps.

    Each batch is fed to ``model`` ``sim_len`` times. The outputs are summed
    cumulatively, and after every step ``t`` the prediction is the argmax of
    that running sum.

    Args:
        test_dataloader: iterable of ``(img, label)`` batches.
        model: spiking network; moved to ``device`` and put in eval mode.
        sim_len: number of simulation time steps per batch.
        device: CUDA device passed to ``.cuda()``.

    Returns:
        A list of length ``sim_len``. ``result[t]`` is the number of correctly
        classified samples (a raw count, not a fraction) when using the outputs
        summed over steps ``0..t``. Divide by the dataset size for accuracy.
    """
    correct_per_step = [0] * sim_len
    model = model.cuda(device)
    model.eval()
    with torch.no_grad():
        for img, label in tqdm(test_dataloader):
            img = img.cuda(device)
            label = label.cuda(device)
            accumulated = 0
            for t in range(sim_len):
                # ``0 + out`` on the first step creates a fresh tensor, so the
                # in-place accumulation never aliases the model's output.
                accumulated += model(img)
                correct_per_step[t] += (label == accumulated.argmax(1)).sum().item()
            # reset_net (from utils) presumably clears the network's stateful
            # neurons between batches -- confirm against its definition.
            reset_net(model)
    return correct_per_step