import numpy as np
import torch
import os
from models.utils import En_Decoding2
from tqdm import tqdm

def get_firing_rate(model):
    '''
    Gather ``batch_fire_rate`` from every ``En_Decoding2`` module nested in ``model``.

    Children are visited depth-first in registration order. An ``En_Decoding2``
    child contributes its own ``batch_fire_rate`` and is not searched further.
    Any other child is searched recursively.

    Returns:
        A numpy array built from the collected values (empty when none are found).
        NOTE(review): the element type/shape depends on what ``batch_fire_rate``
        holds, which is defined in ``models.utils``. Confirm there.
    '''
    rates = []
    for sub in model.children():
        if not isinstance(sub, En_Decoding2):
            # Ordinary container/layer: look inside it for nested En_Decoding2 modules.
            rates.extend(get_firing_rate(sub))
            continue
        rates.append(sub.batch_fire_rate)
    return np.array(rates)

def Test(test_loader, model, criterion, device, measure_fr: bool):
    '''
    Evaluate ``model`` on ``test_loader`` (test routine for AEC).

    Args:
        test_loader: Iterable yielding ``(img, label)`` batches. It is iterated
            exactly once; ``len()`` is not required, so plain iterators work too.
        model: torch module. It is switched to eval mode and moved to ``device``.
        criterion: Loss function called as ``criterion(out, label)``.
        device: Target device for the model and for every batch.
        measure_fr: If True, accumulate ``get_firing_rate(model)`` after each batch.

    Returns:
        ``(acc, loss)`` when ``measure_fr`` is False, otherwise ``(acc, loss, LASFR)``.
        ``acc`` is the fraction of samples whose argmax prediction equals the label.
        ``loss`` is the mean of the per-batch ``criterion`` values.
        ``LASFR`` is the sum of per-batch firing rates divided by the number of samples.
        NOTE(review): that is a per-sample average only if ``batch_fire_rate``
        is a sum over the batch. Confirm in ``En_Decoding2``.

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    '''
    epoch_loss = 0.
    tot = 0.
    length = 0
    n_batches = 0  # counted here instead of len(test_loader), so any iterable works
    fr = np.array(0.)
    model.eval()
    model.to(device)
    with torch.no_grad():
        for img, label in test_loader:
            img = img.to(device)
            label = label.to(device)
            out = model(img)
            epoch_loss += criterion(out, label).item()
            length += len(label)
            n_batches += 1
            tot += (label == out.argmax(1)).sum().item()
            if measure_fr:
                # Must rebind instead of using `+=`: fr starts as a 0-d array, and
                # an in-place add cannot broadcast it up to the returned array's shape.
                fr = fr + get_firing_rate(model)
    if length == 0:
        raise ValueError('test_loader yielded no samples')
    acc = tot / length
    loss = epoch_loss / n_batches
    if measure_fr:
        return acc, loss, fr / length
    return acc, loss



def Train(train_dataloader, test_dataloader, model, optimizer, scheduler, criterion, epochs, device, save, work_directory):
    '''
    Train ``model`` for ``epochs`` epochs, validating with ``Test`` after each one.

    Args:
        train_dataloader: Iterable of ``(img, label)`` training batches.
        test_dataloader: Loader passed to ``Test`` for per-epoch validation.
        model: torch module. It is moved to ``device``.
        optimizer: Optimizer stepped once per training batch.
        scheduler: LR scheduler stepped once per epoch.
        criterion: Loss function called as ``criterion(out, label)``.
        epochs: Number of training epochs.
        device: Target device for the model and each batch.
        save: Sub-directory name (under ``work_directory``) for the best
            checkpoint, or None to disable checkpointing.
        work_directory: Base directory for checkpoints.

    Returns:
        Four numpy arrays with one entry per epoch: validation accuracy,
        validation loss, training accuracy, training loss (mean per batch).

    Raises:
        ValueError: If ``train_dataloader`` yields no samples in an epoch.
    '''
    model.to(device)

    best_acc = 0
    val_acc_list, val_loss_list, tr_acc_list, tr_loss_list = [], [], [], []

    for epoch in range(epochs):
        model.train()
        total, correct, epoch_loss, n_batches = 0, 0, 0., 0
        # Progress bar is intentionally disabled (was a debugging aid).
        for img, label in tqdm(train_dataloader, disable=True):
            img = img.to(device)
            label = label.to(device)

            # Clear gradients before the backward pass so gradients already present
            # on the parameters (e.g. from an earlier call) never leak into a step.
            optimizer.zero_grad()
            out = model(img)
            loss = criterion(out, label)

            total += len(label)
            correct += (label == out.argmax(1)).sum().item()
            epoch_loss += loss.item()
            n_batches += 1

            loss.backward()
            optimizer.step()

        if total == 0:
            raise ValueError('train_dataloader yielded no samples')

        scheduler.step()
        # Validate after every epoch; firing rates are not measured here.
        tmp_acc, val_loss = Test(test_dataloader, model, criterion, device, False)

        val_acc_list.append(tmp_acc)
        val_loss_list.append(val_loss)
        tr_acc_list.append(correct / total)
        tr_loss_list.append(epoch_loss / n_batches)
        print(f'Epoch {epoch} : Val_loss: {val_loss:.5f}, Acc: {tmp_acc*100:.3f}%', flush=True)

        # `>=` means ties with the best accuracy overwrite the checkpoint.
        if save is not None and tmp_acc >= best_acc:
            best_acc = tmp_acc
            save_dir = os.path.join(work_directory, save)
            os.makedirs(save_dir, exist_ok=True)
            torch.save(model.state_dict(), os.path.join(save_dir, 'best_model.pth'))

    return np.array(val_acc_list), np.array(val_loss_list), np.array(tr_acc_list), np.array(tr_loss_list)

