import numpy as np
from torch import nn
import torch
from tqdm import tqdm
import utils
import os
import torch.nn.functional as F
from modules import ScaledNeuron

def eval_ann(test_dataloader, model, loss_fn, device):
    """
    Evaluate ``model`` on ``test_dataloader`` without tracking gradients.

    The model is switched to eval mode and moved to ``device``; every batch
    is moved to ``device`` as well before the forward pass.

    Args:
        test_dataloader: iterable of ``(img, label)`` batches; must support ``len()``.
        model: module called as ``model(img)``; the predicted class is the
            index of the maximum along dimension 1 of its output.
        loss_fn: callable ``loss_fn(out, label)`` whose result supports ``.item()``.
        device: target device for the model and the batches.

    Returns:
        ``(accuracy, mean_loss)`` where accuracy is correct predictions divided
        by the number of labels seen, and mean_loss is the sum of per-batch
        losses divided by the number of batches (not the number of samples).
    """
    model.eval()
    model.to(device)

    n_correct = 0.
    n_seen = 0
    loss_sum = 0
    with torch.no_grad():
        for batch_x, batch_y in test_dataloader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            logits = model(batch_x)
            loss_sum += loss_fn(logits, batch_y).item()
            # Index of the largest output along dim 1 is the predicted class.
            predicted = torch.max(logits, 1)[1]
            n_correct += (batch_y == predicted).sum().item()
            n_seen += len(batch_y)

    accuracy = n_correct / n_seen
    mean_loss = loss_sum / len(test_dataloader)
    return accuracy, mean_loss

def eval_snn(test_dataloader, model, device, sim_len, num_activations):
    '''
    Run ``model`` for ``sim_len`` time steps on every batch and collect
    per-step accuracy and per-layer firing rates.

    return: (accuracy, LASFR)
    accuracy is a 1D array of length sim_len; accuracy[t] is the accuracy at
    time step (t+1), computed from the model outputs summed over steps 0..t.
    LASFR is a 2D array [num_activations, sim_len]; LASFR[j][t] is the mean
    over steps 0..t of the j-th value returned by get_firing_rate, accumulated
    over all batches and divided by the total number of samples.

    utils.reset_net(model) is called once after each batch has been simulated.
    '''
    correct_at_step = [0.] * sim_len
    rate_sums = [[0.] * sim_len for _ in range(num_activations)]
    n_samples = 0

    model.to(device)
    model.eval()

    with torch.no_grad():
        for img, label in tqdm(test_dataloader, disable=False):
            n_samples += len(label)
            img, label = img.to(device), label.to(device)
            summed_out = 0  # running sum of the model outputs over time steps
            for step in range(sim_len):
                summed_out += model(img)

                # Firing rates are read right after the forward pass of this step.
                layer_rates = get_firing_rate(model)
                assert len(layer_rates) == num_activations
                for layer_idx, rate in enumerate(layer_rates):
                    rate_sums[layer_idx][step] += rate

                predicted = summed_out.max(1)[1]
                correct_at_step[step] += (label == predicted).sum().item()
            utils.reset_net(model)

        # Turn the per-step sums into running means over the steps seen so far.
        for layer_idx in range(num_activations):
            per_step = rate_sums[layer_idx]
            rate_sums[layer_idx] = [
                sum(per_step[:upto]) / upto for upto in range(1, len(per_step) + 1)
            ]

    return np.array(correct_at_step) / n_samples, np.array(rate_sums) / n_samples

def get_firing_rate(model):
    """
    Collect ``batch_fire_rate`` from every ScaledNeuron below ``model``.

    Child modules are visited depth-first in ``named_children()`` order.
    A ScaledNeuron contributes its ``batch_fire_rate`` and is not descended
    into; any other child is searched for nested ScaledNeurons. ``model``
    itself is never tested, only its descendants.

    Returns:
        list of the collected ``batch_fire_rate`` values, in visiting order.
    """
    rates = []
    # Explicit stack; children are pushed reversed so pop() yields them in
    # their original order (pre-order traversal).
    pending = [sub for _, sub in model.named_children()]
    pending.reverse()
    while pending:
        module = pending.pop()
        if isinstance(module, ScaledNeuron):
            rates.append(module.batch_fire_rate)
            continue
        nested = [sub for _, sub in module.named_children()]
        pending.extend(reversed(nested))
    return rates

def train_ann(train_dataloader, test_dataloader, model, optimizer, scheduler, epochs, device, save, activation_mode, work_directory):
    """
    Train ``model`` with cross-entropy loss and evaluate it after every epoch.

    Args:
        train_dataloader: iterable of ``(img, label)`` batches; must support ``len()``.
        test_dataloader: passed to ``eval_ann`` at the end of each epoch.
        model: module to train; moved to ``device``.
        optimizer: ``step()`` and ``zero_grad()`` are called once per batch.
        scheduler: ``step()`` is called once per epoch.
        epochs: number of epochs to run.
        device: device for the model and the batches.
        save: sub-directory name for the checkpoint; when ``None`` nothing is saved.
        activation_mode: when equal to ``'softplus'`` the softplus penalty on the
            values returned by ``utils.regular_set(model)`` is weighted by
            0.0005; for any other value the weight is 0.
        work_directory: root directory; the checkpoint is written to
            ``<work_directory>/bestmodel/<save>/best_model.pth``.

    Returns:
        Four 1D numpy arrays with one entry per epoch: validation accuracy,
        validation loss, training accuracy, and training loss. The training
        loss is the cross-entropy only (recorded before the penalty is added),
        averaged over batches.

    Raises:
        ValueError: if the total batch loss exceeds 1e8 or is NaN.
    """
    model.to(device)
    para1 = utils.regular_set(model)  # para1 holds the "up" values

    # Grouping parameters only matters when groups use different optimizer
    # settings; they currently all share the same settings, so grouping is
    # unnecessary and regular_set only needs to return para1 for the
    # regularization term.

    loss_fn = nn.CrossEntropyLoss()

    # When training Softplus, the penalty is applied to log(e^lambda + 1).
    # The weight does not depend on the batch, so it is chosen once here.
    lambda_l2 = 0.0005 if activation_mode == 'softplus' else 0

    best_acc = 0
    val_acc_list, val_loss_list, tr_acc_list, tr_loss_list = [], [], [], []

    for epoch in range(epochs):
        model.train()
        total, correct, epoch_loss = 0, 0, 0.
        for img, label in tqdm(train_dataloader, disable=True):  # DEBUG: progress bar disabled
            img = img.to(device)
            label = label.to(device)

            out = model(img)
            loss = loss_fn(out, label)

            total += len(label)
            correct += (label == out.max(1)[1]).sum().item()
            epoch_loss += loss.item()

            # NOTE(review): despite the "l2" name this is the plain sum of
            # softplus(para), not a squared norm - confirm this is intended.
            # The term is still built when lambda_l2 is 0 so that the autograd
            # graph is the same as before for every activation mode.
            l2_regularization = torch.tensor(0.).to(device)
            for para in para1:
                l2_regularization += F.softplus(para).sum()
            loss += lambda_l2 * l2_regularization

            # Single host transfer for the divergence check.
            loss_value = loss.item()
            if loss_value > 1e8 or np.isnan(loss_value):
                raise ValueError('Diverged...')

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        scheduler.step()

        tmp_acc, val_loss = eval_ann(test_dataloader, model, loss_fn, device)
        val_acc_list.append(tmp_acc)
        val_loss_list.append(val_loss)
        tr_acc_list.append(correct / total)
        tr_loss_list.append(epoch_loss / len(train_dataloader))
        print(f'Epoch {epoch} : Val_loss: {val_loss:.5f}, Acc: {tmp_acc*100:.3f}%', flush=True)

        # Keep the checkpoint with the best (or equal-best) validation accuracy.
        if save is not None and tmp_acc >= best_acc:
            best_acc = tmp_acc
            checkpoint_dir = os.path.join(work_directory, 'bestmodel', save)
            os.makedirs(checkpoint_dir, exist_ok=True)
            torch.save(model.state_dict(), os.path.join(checkpoint_dir, 'best_model.pth'))

    return np.array(val_acc_list), np.array(val_loss_list), np.array(tr_acc_list), np.array(tr_loss_list)

def all_params_buffers_on_device(model: nn.Module, device: torch.device) -> bool:
    """
    Check whether every parameter and buffer of ``model`` is on ``device``.

    Each tensor found on a different device is reported with ``print``.

    Args:
        model: module whose parameters and buffers are inspected recursively.
        device: expected device; a ``torch.device`` or anything
            ``torch.device()`` accepts, such as the string ``'cpu'``.

    Returns:
        True when no parameter or buffer is on another device, False otherwise.

    NOTE: the comparison is plain ``torch.device`` equality, so an expected
    device without an index (``'cuda'``) does not match ``cuda:0``.
    """
    # Normalize so that a string such as 'cpu' compares correctly against
    # the torch.device objects stored on tensors.
    expected = torch.device(device)
    all_match = True

    for name, param in model.named_parameters(recurse=True):
        if param.device != expected:
            print(f"Parameter {name} is on {param.device}, expected {device}")
            all_match = False

    for name, buf in model.named_buffers(recurse=True):
        if buf.device != expected:
            print(f"Buffer {name} is on {buf.device}, expected {device}")
            all_match = False

    return all_match

if __name__=='__main__':
    # Smoke test: build the 'vgg-16' model from Models.modelpool (second
    # argument 10), swap its activations via utils.replace_activation_by_neuron,
    # and check that get_firing_rate finds exactly 15 entries.
    import Models
    net = utils.replace_activation_by_neuron(Models.modelpool('vgg-16', 10))
    rates = get_firing_rate(net)
    assert len(rates)==15
    print(rates)





'''
def mp_test(test_dataloader, model, net_arch, presim_len, sim_len, device):
    new_tot = torch.zeros(sim_len).to(device)
    model = model.to(device)
    model.eval()
    
    with torch.no_grad():
        for img, label in tqdm(test_dataloader):
            new_spikes = 0
            img = img.to(device)
            label = label.to(device)
            
            for t in range(presim_len+sim_len):
                out = model(img)
                
                if t >= presim_len:
                    new_spikes += out
                    new_tot[t-presim_len] += (label==new_spikes.max(1)[1]).sum().item()
                   
    return new_tot
'''