import datetime
import os
import time
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torch.cuda import amp
from search import super_model
from search.architecture import Architecture
import argparse
from spikingjelly.clock_driven import functional
from spikingjelly.clock_driven import surrogate as surrogate_sj
from utils import Bar, Logger, AverageMeter, accuracy
import torch.utils.data as data
from utils.ShuffleSampler import *
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchtoolbox.transform import Cutout
from utils.cifar10_dvs import CIFAR10DVS, ToPILImage, Resize, ToTensor
from spikingjelly.datasets.dvs128_gesture import DVS128Gesture
import collections
import random
import numpy as np
# Reproducibility settings: ask cuDNN for deterministic algorithms and turn off
# its benchmark-based algorithm auto-selection. The RNG seeds are set in main().
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

'''
This script searches for the optimal neuron configuration.

Get started with:

    python train_search.py -dataset 'cifar10' -model 'super_resnet18' -b 128 -T 6

'''


def _build_trainset(args):
    """Create the training dataset selected by ``args.dataset``.

    Args:
        args: Parsed command-line namespace; ``dataset``, ``data_dir`` and
            ``T`` (frame count for the event-based datasets) are read.

    Returns:
        A ``(trainset, c_in, num_classes)`` tuple: the dataset object, the
        number of input channels and the number of target classes.

    Raises:
        NotImplementedError: If ``args.dataset`` is not a known dataset name.
    """
    data_dir = args.data_dir
    if args.dataset in ('cifar10', 'cifar100'):
        c_in = 3
        if args.dataset == 'cifar10':
            dataset_cls = datasets.CIFAR10
            num_classes = 10
            normalization_mean = (0.4914, 0.4822, 0.4465)
            normalization_std = (0.2023, 0.1994, 0.2010)
        else:
            dataset_cls = datasets.CIFAR100
            num_classes = 100
            normalization_mean = (0.5071, 0.4867, 0.4408)
            normalization_std = (0.2675, 0.2565, 0.2761)
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            Cutout(),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(normalization_mean, normalization_std),
        ])
        trainset = dataset_cls(root=data_dir, train=True, download=True, transform=transform_train)
    elif args.dataset == 'dvscifar10':
        c_in = 2
        num_classes = 10
        data_dir = os.path.join(data_dir, 'dvscifar10')
        transform_train = transforms.Compose([
            ToPILImage(),
            Resize(48),
            ToTensor(),
        ])
        trainset = CIFAR10DVS(data_dir, train=True, use_frame=True, frames_num=args.T, split_by='number', normalization=None, transform=transform_train)
    elif args.dataset == 'dvsgesture':
        c_in = 2
        num_classes = 11
        data_dir = os.path.join(data_dir, 'dvsgesture')
        trainset = DVS128Gesture(root=data_dir, train=True, data_type='frame', frames_number=args.T, split_by='number')
    elif args.dataset == 'imagenet':
        c_in = 3
        data_dir = os.path.join(data_dir, 'imagenet')
        num_classes = 1000
        traindir = os.path.join(data_dir, 'train')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        transform_train = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        trainset = datasets.ImageFolder(traindir, transform_train)
    elif args.dataset == 'tinyimagenet':
        c_in = 3
        data_dir = os.path.join(data_dir, 'tiny-imagenet-200')
        num_classes = 200
        traindir = os.path.join(data_dir, 'train')
        normalize = transforms.Normalize(mean=[0.4802, 0.4481, 0.3975],
                                         std=[0.2770, 0.2691, 0.2821])
        transform_train = transforms.Compose([
            transforms.RandomCrop(64, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        trainset = datasets.ImageFolder(traindir, transform_train)
    else:
        raise NotImplementedError(args.dataset)
    return trainset, c_in, num_classes


def main():
    """Run the neuron-configuration search from the command line.

    Parses the CLI arguments, seeds the RNGs, builds the training set and the
    super-network, then alternates an architecture-parameter step and a
    weight step on every batch pair drawn from two samplers over the training
    set. After each epoch the scalars are written to TensorBoard, the latest
    checkpoint is saved and the current arg-max of the first alpha tensor is
    printed.

    Raises:
        NotImplementedError: For an unsupported dataset, optimizer or
            scheduler name, or when AMP is disabled (the full-precision
            training path is not implemented).
    """
    parser = argparse.ArgumentParser(description='Searching')
    parser.add_argument('-seed', default=2024, type=int, help='hope you have good luck in the year of 2024 :D')
    parser.add_argument('-name', default='', type=str, help='')
    parser.add_argument('-T', default=6, type=int, help='time step setting')
    parser.add_argument('-tau', default=1.1, type=float, help='membrane time constant')
    parser.add_argument('-epochs', default=100, type=int, metavar='N', help='epoch')
    parser.add_argument('-j', default=4, type=int, metavar='N', help='number of loading workers')
    parser.add_argument('-b', default=128, type=int, help='batchsize')
    parser.add_argument('-data_dir', type=str, default='./data', help='directory of the dataset')
    parser.add_argument('-dataset', default='cifar10', type=str,help='name of the dataset, cifar10, cifar100, tinyimagenet dvscifar10, or dvsgesture')
    parser.add_argument('-model', type=str, default='super_resnet18', help='network model')
    parser.add_argument('-out_dir', type=str, default='./logs', help='root dir for logs and checkpoint')
    parser.add_argument('-surrogate', default='triangle', type=str, help='surrogate function')
    parser.add_argument('-resume', type=str, help='resume path')
    # NOTE: store_false means AMP is ON by default and passing `-amp` turns it
    # OFF; the non-AMP path below raises NotImplementedError.
    parser.add_argument('-amp', action='store_false', help='automatic mixed precision training')
    parser.add_argument('-opt', type=str, help='optimizer SGD or AdamW', default='SGD')
    parser.add_argument('-w_lr', default=0.1, type=float, help='learning rate')
    parser.add_argument('-momentum', default=0.9, type=float, help='momentum for SGD')
    parser.add_argument('-w_lr_scheduler', default='CosALR', type=str, help='scheduler')
    parser.add_argument('-step_size', default=300, type=float, help='step size for StepLR')
    parser.add_argument('-gamma', default=0.1, type=float, help='gamma for StepLR')
    parser.add_argument('-w_T_max', default=300, type=int, help='w_T_max for CosineAnnealingLR')
    parser.add_argument('-w_weight_decay', type=float, default=5e-5)
    parser.add_argument('-alpha_lr', default=3e-4, type=float, help='learning rate for architecture parameters')
    parser.add_argument('-alpha_weight_decay', type=float, default=1e-4, help='weight decay for architecture parameters')
    parser.add_argument('-drop_rate', type=float, default=0.0, help='dropout rate')
    parser.add_argument('-loss_lambda', type=float, default=0.05, help='constant for TET loss')
    parser.add_argument('-save_init', action='store_true', help='save the initialization of parameters')
    args = parser.parse_args()
    print(args)

    # Seed every RNG the script touches.
    seed = args.seed
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)

    # data processing & LWNS-net construct
    trainset, c_in, num_classes = _build_trainset(args)

    # NOTE: 'imagenet' has a dataset branch but no super-network is built for
    # it, so it still ends in NotImplementedError here.
    searchable_datasets = ('cifar10', 'cifar100', 'tinyimagenet', 'dvscifar10', 'dvsgesture')
    if args.dataset not in searchable_datasets:
        raise NotImplementedError(args.dataset)
    super_net = super_model.__dict__[args.model](num_classes=num_classes, neuron_dropout=args.drop_rate, tau=args.tau, c_in=c_in)
    super_net.cuda()

    # optimizer
    if args.opt == 'SGD':
        w_optimizer = torch.optim.SGD(super_net.parameters(), lr=args.w_lr, momentum=args.momentum, weight_decay=args.w_weight_decay)
    elif args.opt == 'AdamW':
        w_optimizer = torch.optim.AdamW(super_net.parameters(), lr=args.w_lr, weight_decay=args.w_weight_decay)
    else:
        raise NotImplementedError(args.opt)
    alpha_optim = torch.optim.Adam(super_net.alphas(), lr=args.alpha_lr, betas=(0.5, 0.999),
                                   weight_decay=args.alpha_weight_decay)
    if args.w_lr_scheduler == 'StepLR':
        w_lr_scheduler = torch.optim.lr_scheduler.StepLR(w_optimizer, step_size=args.step_size, gamma=args.gamma)
    elif args.w_lr_scheduler == 'CosALR':
        w_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(w_optimizer, T_max=args.w_T_max)
    else:
        raise NotImplementedError(args.w_lr_scheduler)
    scaler = amp.GradScaler() if args.amp else None

    # architecture creation
    archi = Architecture(super_net, w_momentum=args.momentum, w_weight_decay=args.w_weight_decay)
    start_epoch = 0

    # resume
    if args.resume:
        print('resuming...')
        # NOTE: torch.load unpickles the file; only resume from trusted checkpoints.
        checkpoint = torch.load(args.resume, map_location='cpu')
        super_net.load_state_dict(checkpoint['super_net'])
        w_optimizer.load_state_dict(checkpoint['w_optimizer'])
        alpha_optim.load_state_dict(checkpoint['alpha_optim'])
        w_lr_scheduler.load_state_dict(checkpoint['w_lr_scheduler'])
        # Older checkpoints carry no scaler state; restore it only when present
        # so the loss scale does not restart from its initial value.
        scaler_state = checkpoint.get('scaler')
        if scaler is not None and scaler_state:
            scaler.load_state_dict(scaler_state)
        start_epoch = checkpoint['epoch'] + 1
        print('start epoch:', start_epoch)

    # output setting
    out_dir = os.path.join(args.out_dir, f'Search_{args.dataset}_{args.model}_T{args.T}_tau{args.tau}_e{args.epochs}_bs{args.b}_alr{args.alpha_lr}_awd{args.alpha_weight_decay}_drop{args.drop_rate}')

    if args.w_lr_scheduler == 'CosALR':
        out_dir += f'CosALR_{args.w_T_max}'
    elif args.w_lr_scheduler == 'StepLR':
        out_dir += f'StepLR_{args.step_size}_{args.gamma}'
    else:
        raise NotImplementedError(args.w_lr_scheduler)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
        print(f'Mkdir {out_dir}.')
    else:
        print('out dir already exists:', out_dir)

    # save the initialization of parameters
    if args.save_init:
        checkpoint = {
            'super_net': super_net.state_dict(),
            'epoch': 0,
            'max_test_acc': 0.0
        }
        torch.save(checkpoint, os.path.join(out_dir, 'checkpoint_0.pth'))
    with open(os.path.join(out_dir, 'args.txt'), 'w', encoding='utf-8') as args_txt:
        args_txt.write(str(args))
    writer = SummaryWriter(os.path.join(out_dir, 'logs'), purge_step=start_epoch)

    # searching
    t_step = args.T
    # The event-based datasets are indexed per time step; the static ones feed
    # the same batch at every step.
    is_event_data = args.dataset in ('dvscifar10', 'dvsgesture')
    try:
        for epoch in range(start_epoch, args.epochs):
            start_time = time.time()
            super_net.train()
            batch_time = AverageMeter()
            data_time = AverageMeter()
            losses = AverageMeter()
            top1 = AverageMeter()
            top5 = AverageMeter()
            end = time.time()
            train_loss = 0
            train_acc = 0
            train_samples = 0
            # Hand the architecture step the learning rate the weight optimizer
            # is using right now (also correct after a resume), not the
            # initial `-w_lr` value.
            lr = w_lr_scheduler.get_last_lr()[0]
            samplerA = SamplerA(trainset)
            samplerB = SamplerB(trainset, samplerA.samplerA_indices)
            A = DataLoader(trainset, batch_size=args.b, sampler=samplerA, num_workers=args.j, pin_memory=True, drop_last=True)
            B = DataLoader(trainset, batch_size=args.b, sampler=samplerB, num_workers=args.j, pin_memory=True, drop_last=True)
            bar = Bar('Searching', max=len(A))
            for batch_idx, ((trn_X, trn_Y), (val_X, val_Y)) in enumerate(zip(A, B)):
                # Time spent waiting for the two loaders to deliver this batch pair.
                data_time.update(time.time() - end)
                if args.dataset != 'dvscifar10':
                    trn_X = trn_X.float().cuda()
                    val_X = val_X.float().cuda()
                    if args.dataset == 'dvsgesture':
                        # Swap the first two axes so that indexing by t works below.
                        trn_X = trn_X.transpose(0, 1)
                        val_X = val_X.transpose(0, 1)
                else:
                    # The dvscifar10 loader yields a sequence of tensors; stack
                    # them along a new leading axis that is indexed by t below.
                    trn_X = torch.stack(trn_X, dim=0).float().cuda()
                    val_X = torch.stack(val_X, dim=0).float().cuda()
                trn_Y = trn_Y.cuda()
                val_Y = val_Y.cuda()
                batch_loss = 0

                # stage 2. architecture step (runs before the weight step)
                alpha_optim.zero_grad()
                archi.unrolled_backward(trn_X, trn_Y, val_X, val_Y, lr, w_optimizer)
                alpha_optim.step()

                # stage 1. network step: gradients are accumulated over the
                # T time steps, then a single optimizer step is taken.
                w_optimizer.zero_grad()
                for t in range(t_step):
                    step_input = trn_X[t] if is_event_data else trn_X
                    if not args.amp:
                        raise NotImplementedError('Please use amp.')
                    with amp.autocast():
                        output = super_net(step_input)
                        if t == 0:
                            total_output = output.clone().detach()
                        else:
                            total_output += output.clone().detach()
                    loss = F.cross_entropy(output, trn_Y) / t_step
                    scaler.scale(loss).backward()
                    batch_loss += loss.item()
                    train_loss += loss.item() * trn_Y.numel()
                if args.amp:
                    scaler.step(w_optimizer)
                    scaler.update()
                else:
                    w_optimizer.step()

                # accuracy and loss (predictions use the outputs summed over time)
                prec1, prec5 = accuracy(total_output.data, trn_Y.data, topk=(1, 5))
                batch_size = step_input.size(0)
                losses.update(batch_loss, batch_size)
                top1.update(prec1.item(), batch_size)
                top5.update(prec5.item(), batch_size)
                train_samples += trn_Y.numel()
                train_acc += (total_output.argmax(1) == trn_Y).float().sum().item()
                functional.reset_net(super_net)

                # time
                batch_time.update(time.time() - end)
                end = time.time()

                # plot progress
                bar.suffix  = '({batch}/{size}) Data: {data:.3f}s | Batch: {bt:.3f}s | Total: {total:} | ETA: {eta:} | Loss: {loss:.4f} | top1: {top1: .4f} | top5: {top5: .4f}'.format(
                            batch=batch_idx,
                            size=len(A),
                            data=data_time.avg,
                            bt=batch_time.avg,
                            total=bar.elapsed_td,
                            eta=bar.eta_td,
                            loss=losses.avg,
                            top1=top1.avg,
                            top5=top5.avg,
                            )
                bar.next()
            bar.finish()
            train_loss /= train_samples
            train_acc /= train_samples
            writer.add_scalar('train_loss', train_loss, epoch)
            writer.add_scalar('train_acc', train_acc, epoch)
            w_lr_scheduler.step()
            checkpoint = {
                'super_net': super_net.state_dict(),
                'w_optimizer': w_optimizer.state_dict(),
                'alpha_optim': alpha_optim.state_dict(),
                'w_lr_scheduler': w_lr_scheduler.state_dict(),
                'epoch': epoch,
            }
            if scaler is not None:
                checkpoint['scaler'] = scaler.state_dict()
            torch.save(checkpoint, os.path.join(out_dir, 'checkpoint_latest.pth'))
            total_time = time.time() - start_time
            # Epoch `epoch` has just finished, so epochs - epoch - 1 remain.
            remaining_epochs = args.epochs - epoch - 1
            finish_at = (datetime.datetime.now() + datetime.timedelta(seconds=total_time * remaining_epochs)).strftime('%Y-%m-%d %H:%M:%S')
            print(f'epoch={epoch}, train_loss={train_loss}, train_acc={train_acc}, total_time={total_time}, escape_time={finish_at}')
            # max_memory_reserved replaces the deprecated max_memory_cached.
            print("after one epoch: %fGB" % (torch.cuda.max_memory_reserved(0) / 1024 / 1024 / 1024))
            optimal_neuron = super_net.alphas()[0].argmax(dim=1)
            print('current optimal architecture:', optimal_neuron)
    finally:
        # Flush and release the TensorBoard event file even if training aborts.
        writer.close()

# Start the search only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
