"""
To train and test the model
"""

import os
import argparse
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from spikingjelly.activation_based import functional
import torch.nn.functional as F
from models import SCNN
from DVSLIP import *

# Compatibility shim: re-create the `np.int` alias (removed in NumPy 1.24),
# presumably because a dependency still references it — TODO confirm which one.
np.int = int

import random

_seed_ = 42


def _seed_everything(seed):
    """Seed the Python, NumPy and torch RNGs and force deterministic cuDNN.

    Args:
        seed: integer seed applied to every random number generator.
    """
    random.seed(seed)
    # torch.manual_seed seeds the RNG for all devices (CPU and CUDA).
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Prefer reproducible cuDNN kernels over autotuned (faster) ones.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)


_seed_everything(_seed_)

if __name__ == "__main__":
    # python main.py -e 100 -b 32 -T 90 -augS -augT --Tnbmask 6 --Tmaxmasklength 18 -opt Adam --front -CUDA_VISIBLE_DEVICES 0
    parser = argparse.ArgumentParser()
    parser.add_argument('-f', dest='filename', default='test', type=str, help='filename to store the model')
    parser.add_argument('-e', dest='epochs', default=100, type=int, help='number of training epochs')
    parser.add_argument('--actreg', default=0.0, type=float, help='activity regularization for SNNs')
    parser.add_argument('--finetune', action='store_true', default=False, help='restart training from the given model')
    parser.add_argument('-T', default=90, type=int, help='nb of frames for data pre-processing')
    parser.add_argument('-b', dest='batch_size', default=32, type=int, help='training batch_size')
    parser.add_argument('-K', default=16, type=int, help='SlidingPSN ')
    parser.add_argument('--downsample', default=88, type=int, help='The resize of image')
    parser.add_argument('--nowarmup', action='store_true', default=False, help='no warmup epoch')
    parser.add_argument('-tet', action='store_true')
    parser.add_argument('-T-train', default=-1, type=int)
    parser.add_argument('-shuffle', action='store_true')
    parser.add_argument('-j', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
    parser.add_argument('-aug', action='store_true')
    parser.add_argument('-opt', default='Adam', type=str)
    parser.add_argument('-resume', type=str, help='resume from the checkpoint path')
    parser.add_argument('-augS', action='store_true', default=False, help='spatial data augmentation (for training)')
    parser.add_argument('-augT', action='store_true', default=False, help='temporal data augmentation (for training)')
    parser.add_argument('--Tnbmask', default=6, type=int, help='nb of masks for temporal data augmentation')
    parser.add_argument('--Tmaxmasklength', default=18, type=int, help='maximale length of each mask for temporal data augmentation')
    parser.add_argument('-CUDA_VISIBLE_DEVICES', default=0)
    parser.add_argument('--front', action='store_true')

    args = parser.parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = args.CUDA_VISIBLE_DEVICES
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float
    SAVE_PATH_MODEL_BEST = os.getcwd() + '/' + args.filename + '.pt'

    downsample = args.downsample
    
    train_data_root = '/datasets/DVS-Lip/extract/DVS-Lip/train'
    test_data_root = '/datasets/DVS-Lip/extract/DVS-Lip/test'

    training_words = get_training_words()
    label_dct = {k:i for i,k in enumerate(training_words)}

    train_dataset = DVSLipDataset(train_data_root, label_dct, train=True, augment_spatial=args.augS, augment_temporal=args.augT, T=args.T, Tnbmask=args.Tnbmask, Tmaxmasklength=args.Tmaxmasklength)
    test_dataset = DVSLipDataset(test_data_root, label_dct, train=False, augment_spatial=False, augment_temporal=False, T=args.T)
    train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=2, shuffle=True, pin_memory=True) 
    test_data_loader = DataLoader(test_dataset, batch_size=args.batch_size, num_workers=2, shuffle=False, pin_memory=True) 

    model = SCNN(front=args.front)
    model.to(device)
    print(model)

    out_dir = f'T_{args.T}_e{args.epochs}_b_{args.batch_size}_HW_{args.downsample}_Tnbmask_{args.Tnbmask}_Tmaxmasklength{args.Tmaxmasklength}'

    if args.front:
        out_dir += '_front'
    else:
        out_dir += '_spikGRU2+'

    if args.finetune:
        out_dir += '_finetune'

    pt_dir = os.path.join("logs", 'pt', out_dir)
    if not os.path.exists(pt_dir):
        os.makedirs(pt_dir)

    start_epoch = 0
    writer = SummaryWriter(os.path.join('logs', out_dir), purge_step=start_epoch)

    max_test_acc = -1

    ## TRAINING PARAMETERS
    ########################################################################
    loss_fn = torch.nn.CrossEntropyLoss().to(device)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['net'])

    params = []

    if args.finetune:
        lr = 1e-4 * (args.batch_size / 32)
        final_lr = 5e-6 * (args.batch_size / 32)
    else:
        lr = 3e-4 * (args.batch_size / 32)
        final_lr = lr

    for name, param in model.named_parameters():
        if "bn" in name:
            params += [{'params':param, 'lr':lr}]
        else:
            params += [{'params':param, 'lr': lr, 'weight_decay': 1e-4}]

    if args.opt == 'AdamW':
        optimizer = torch.optim.AdamW(params)
    elif args.opt == 'Adam':
        optimizer = torch.optim.Adam(params)
    else:
        raise ValueError

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min=final_lr, last_epoch=-1)

    if args.nowarmup:
        warmup_epochs = 0
    else:
        warmup_epochs = 1

    if warmup_epochs > 0:
        for g in optimizer.param_groups:
            g['lr'] /= len(train_data_loader)*warmup_epochs
        warmup_itr = 1
    best_val = 0

    for epoch in range(start_epoch, args.epochs):
        model.train()

        train_loss = 0
        train_acc = 0
        train_samples = 0
        for ni, (img, label) in enumerate(train_data_loader):
            optimizer.zero_grad()
            img = img.to(device)
            label = label.to(device)

            y = model(img)
            loss = F.cross_entropy(y, label)
            
            loss.backward()
            optimizer.step()

            model.clamp()

            if epoch < warmup_epochs:
                for g in optimizer.param_groups:
                    g['lr'] *= (warmup_itr+1)/(warmup_itr)
                warmup_itr += 1

            train_samples += label.shape[0]
            train_loss += loss.item() * label.shape[0]

            train_acc += (y.argmax(1) == label).float().sum().item()
            functional.reset_net(model)

        train_acc /= train_samples
        train_loss /= train_samples


        writer.add_scalar("train_acc", train_acc, epoch)
        writer.add_scalar("train_loss", train_loss, epoch)
        scheduler.step()

        model.eval()
        test_loss = 0
        test_acc = 0
        test_samples = 0

        with torch.no_grad():
            for img, label in test_data_loader:
                img = img.to(device)
                label = label.to(device)
                y = model(img)

                loss = F.cross_entropy(y, label)
                test_samples += label.numel()
                test_loss += loss.item() * label.numel()
                test_acc += (y.argmax(1) == label).float().sum().item()
                functional.reset_net(model)
        test_loss /= test_samples
        test_acc /= test_samples
        writer.add_scalar('test_loss', test_loss, epoch)
        writer.add_scalar('test_acc', test_acc, epoch)
        print(f'epoch = {epoch}, train_acc = {train_acc}, train_loss = {train_loss}, test_acc = {test_acc}, test_loss = {test_loss}')

        save_max = False
        if test_acc > max_test_acc:
            max_test_acc = test_acc
            save_max = True


        checkpoint = {
            'net': model.state_dict(),
        }
        if save_max:
            torch.save(checkpoint, os.path.join(pt_dir, 'checkpoint_max.pth'))

        torch.save(checkpoint, os.path.join(pt_dir, 'checkpoint_latest.pth'))