from datasets import SHD_dataloaders
from config import Config
from snn_delays import SnnDelays
import torch
from snn import SNN
import utils
import argparse

def _build_parser():
    """Return the command-line parser for an SHD classification run.

    Options are registered in the order listed below, which is also the
    order in which they appear in ``--help`` output.
    """
    cli = argparse.ArgumentParser(description='Classify SHD')
    option_specs = [
        ('-K', {'type': int, 'default': 2, 'help': 'the kernel size of the sliding PSN'}),
        ('-seed', {'type': int, 'default': 0, 'help': 'the random seed'}),
        ('-device', {'default': 'cuda:0', 'help': 'device'}),
        ('-model_type', {'default': 'snn_delays', 'help': 'model type'}),
        ('-n_hidden_layers', {'type': int, 'default': 2, 'help': 'number of hidden layers'}),
        ('-n_hidden_neurons', {'type': int, 'default': 256, 'help': 'number of hidden neurons'}),
        ('-epochs', {'type': int, 'default': 150, 'help': 'number of epochs'}),
        ('-batch_size', {'type': int, 'default': 128, 'help': 'batch size'}),
        ('-lr_w', {'type': float, 'default': 5e-4, 'help': 'learning rate for weights'}),
        ('-lr_gpsn', {'type': float, 'default': 5e-4, 'help': 'learning rate for psn'}),
        ('-loss', {'default': 'sum', 'help': 'loss function'}),
        ('-outdir', {'default': 'logs', 'help': 'output directory'}),
        ('-weight_decay', {'type': float, 'default': 1e-5, 'help': 'weight decay'}),
    ]
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli


parser = _build_parser()

# Parsed from sys.argv when the script starts.
args = parser.parse_args()

device = args.device
print(f"\n=====> Device = {device} \n\n")

# Start from the project defaults, then overlay the CLI values that map
# one-to-one onto Config attributes (same names on both objects).
config = Config()
for _attr in ('K', 'seed', 'model_type', 'n_hidden_layers', 'epochs',
              'batch_size', 'n_hidden_neurons', 'weight_decay',
              'lr_w', 'lr_gpsn'):
    setattr(config, _attr, getattr(args, _attr))

# The position learning rate is non-zero only for model_type 'snn_delays';
# every other model type (including 'snn_delays_lr0') gets 0.
if config.model_type == 'snn_delays':
    config.lr_pos = 1e-1
else:
    config.lr_pos = 0

# Each max_lr_* is 5x its base learning rate (presumably the peak of an LR
# schedule inside train_model -- confirm there).
config.max_lr_w = 5 * config.lr_w
config.max_lr_pos = 5 * config.lr_pos
config.max_lr_gpsn = 5 * config.lr_gpsn

config.loss = args.loss

# Checkpoint path encodes the main hyperparameters. NOTE: there is no
# separator between the batch size and "epochs"; it is kept as-is so that
# existing checkpoint names stay stable.
config.save_model_path = (
    f'{args.outdir}/k{config.K}_aug{config.augment}_seed{config.seed}'
    f'_lrw_{config.lr_w}_lrgpsn_{config.lr_gpsn}'
    f'_batch_size_{config.batch_size}epochs_{config.epochs}_REPL.pt'
)

# Reject an unsupported dataset before doing any seeded work, so no model is
# built and moved to `device` only to be thrown away. NotImplementedError is
# an Exception subclass, so existing `except Exception` handlers still match.
if config.dataset != 'shd':
    raise NotImplementedError(f'dataset {config.dataset} not implemented')

utils.set_seed(config.seed)

# 'snn' selects the plain SNN; every other model_type builds SnnDelays.
if config.model_type == 'snn':
    model = SNN(config).to(device)
else:
    model = SnnDelays(config).to(device)

print(model)
if config.model_type == 'snn_delays_lr0':
    # Presumably rounds the delay positions once before training (their
    # learning rate is 0 for this model type) -- see SnnDelays.round_pos.
    model.round_pos()

print(f"===> Dataset    = {config.dataset}")
print(f"===> Model type = {config.model_type}")
print(f"===> Model size = {utils.count_parameters(model)}\n\n")

# Loaders are still built after the model, keeping the post-seed order of
# work (model init, then loader construction) identical to earlier runs so
# results stay reproducible.
train_loader, valid_loader = SHD_dataloaders(config)
test_loader = None  # train_model receives no test loader for SHD here

model.train_model(train_loader, valid_loader, test_loader, device, args)