

import argparse
import logging
import os
import time
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from spikingjelly.clock_driven import functional
#from WideResNet import *
from torch.optim.lr_scheduler import StepLR
from torch.autograd import Variable
from tqdm import tqdm
import torchvision
#from pymodel import *

#from snnwrntea import *
#from snnresnet66 import *

#from spiking_resnet import *
#from resnet import *
#from snn_wrn import *
from netnew import *
import utils

#import data_loader as data_loader

from evaluate import *
from torchvision import datasets, transforms
from spikingjelly.datasets.dvs128_gesture import DVS128Gesture
# Module-level ArgumentParser instance.
parser = argparse.ArgumentParser()

from tensorboardX import SummaryWriter

# Shared tensorboardX writer; the train-and-evaluate loops log their
# per-epoch 'test_accuracy' scalar to it.
_TB_LOG_DIR = './kdnet1try/'
writer = SummaryWriter(_TB_LOG_DIR)


def train(model, optimizer, loss_fn, dataloader, metrics, params):
    """Train ``model`` for one epoch using an MSE loss against one-hot labels.

    Args:
        model: network to train; switched to train mode here and its spiking
            state is cleared with ``functional.reset_net`` after every batch.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: accepted for interface compatibility but NOT used -- the loss
            is ``F.mse_loss(outputs, one_hot(labels))``.
        dataloader: iterable of ``(inputs, labels)`` batches.
        metrics: mapping ``name -> fn(outputs_np, labels_np)`` evaluated every
            ``params.save_summary_steps`` batches.
        params: namespace providing ``device`` and ``save_summary_steps``.
    """
    model.train()

    # Move batches to wherever the model lives instead of relying on a
    # ``device`` global that only exists when this file runs as a script.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else params.device

    summ = []
    loss_avg = utils.RunningAverage()

    with tqdm(total=len(dataloader)) as t:
        for i, (train_batch, labels_batch) in enumerate(dataloader):
            if params.device:
                train_batch = train_batch.to(device, dtype=torch.float)
                labels_batch = labels_batch.to(device)
            labels = labels_batch.long()

            output_batch = model(train_batch)
            # One-hot MSE target; the class count follows the model output
            # rather than a hard-coded 11.
            target = F.one_hot(labels, output_batch.shape[-1]).float()
            loss = F.mse_loss(output_batch, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Clear the spiking neurons' internal state before the next batch.
            functional.reset_net(model)

            if i % params.save_summary_steps == 0:
                outputs_np = output_batch.detach().cpu().numpy()
                labels_np = labels.cpu().numpy()
                summary_batch = {name: fn(outputs_np, labels_np)
                                 for name, fn in metrics.items()}
                summary_batch['loss'] = loss.item()
                summ.append(summary_batch)

            loss_avg.update(loss.item())
            t.set_postfix(loss='{:05.3f}'.format(loss_avg()))
            t.update()

    # An empty dataloader previously crashed here on ``summ[0]``.
    if not summ:
        logging.warning("- Train metrics: no batches were processed")
        return

    metrics_mean = {metric: np.mean([x[metric] for x in summ]) for metric in summ[0]}
    metrics_string = " ; ".join("{}: {:05.3f}".format(k, v) for k, v in metrics_mean.items())
    logging.info("- Train metrics: " + metrics_string)


def train_and_evaluate(model, train_dataloader, val_dataloader, optimizer,
                       loss_fn, metrics, params, model_dir, restore_file=None):
    """Train ``model`` for ``params.num_epochs`` epochs, validating after each.

    The best validation accuracy seen so far is printed every epoch, and the
    corresponding ``state_dict`` is saved to ``./bestnetdvs.pth.tar``.  The
    per-epoch validation accuracy is logged to the shared ``writer`` as
    ``test_accuracy``.

    Args:
        model, optimizer: network and its optimizer.
        train_dataloader, val_dataloader: training / validation batches.
        loss_fn: forwarded to ``train`` and ``evaluate``.
        metrics: metric functions forwarded to ``train`` and ``evaluate``.
        params: namespace providing at least ``num_epochs``.
        model_dir: directory holding the checkpoint named by ``restore_file``.
        restore_file: optional checkpoint name (without ``.pth.tar``) to load
            into ``model``/``optimizer`` before training.
    """
    if restore_file is not None:
        # Use the arguments; the old code read an undefined global ``args``.
        restore_path = os.path.join(model_dir, restore_file + '.pth.tar')
        logging.info("Restoring parameters from {}".format(restore_path))
        utils.load_checkpoint(restore_path, model, optimizer)

    best_val_acc = 0.0

    # NOTE(review): T_max=64 with the default 300 epochs makes the cosine
    # schedule cycle several times -- confirm this is intended.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=64)

    writer.flush()

    for epoch in range(params.num_epochs):
        print(epoch)
        logging.info("Epoch {}/{}".format(epoch + 1, params.num_epochs))

        train(model, optimizer, loss_fn, train_dataloader, metrics, params)
        # Step the schedule after the epoch's optimizer updates; stepping it
        # first (pre-PyTorch-1.1 order) skips the initial learning rate.
        scheduler.step()

        val_acc = evaluate(model, loss_fn, val_dataloader, metrics, params)
        print(val_acc)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({'state_dict': model.state_dict()},
                       os.path.join('./', 'bestnetdvs.pth.tar'))
        print(best_val_acc)

        writer.add_scalar('test_accuracy', val_acc, epoch)




def compute_kd_loss(student_logits, teacher_logits, labels, alpha, temperature,
                    kd_scale=1000.0):
    """Return the knowledge-distillation loss used by ``train_kd``.

    ``loss = KL(softmax(teacher/T) || softmax(student/T)) * alpha * T**2 * kd_scale
    + cross_entropy(student, labels) * (1 - alpha)``

    The KL term is averaged over every element (batch x classes).  This
    reproduces the legacy default ``reduction='mean'`` of ``nn.KLDivLoss``,
    which PyTorch deprecates in favour of ``'batchmean'``.  The ``kd_scale``
    factor presumably compensates for that extra per-class division.

    Args:
        student_logits: student outputs with classes on dim 1.
        teacher_logits: teacher outputs, same shape as ``student_logits``.
        labels: integer-valued class labels (any dtype; cast with ``.long()``).
        alpha: weight of the distillation term (``1 - alpha`` for CE).
        temperature: softmax temperature ``T``.
        kd_scale: extra multiplier on the distillation term (default 1000).
    """
    log_p_student = F.log_softmax(student_logits / temperature, dim=1)
    p_teacher = F.softmax(teacher_logits / temperature, dim=1)
    # 'sum' / numel == the old element-wise 'mean', without the deprecation.
    kd = F.kl_div(log_p_student, p_teacher, reduction='sum') / log_p_student.numel()
    ce = F.cross_entropy(student_logits, labels.long())
    return kd * (alpha * temperature * temperature) * kd_scale + ce * (1. - alpha)


def train_kd(model, teacher_model, optimizer, dataloader, metrics, params):
    """Run one epoch of knowledge distillation from ``teacher_model`` into ``model``.

    Args:
        model: student network; trained in place.
        teacher_model: frozen teacher; evaluated under ``torch.no_grad``.
        optimizer: optimizer over the student's parameters.
        dataloader: iterable of ``(inputs, labels)`` batches.
        metrics: unused; kept for interface compatibility.
        params: namespace providing ``alpha``, ``temperature`` and ``device``.
    """
    model.train()
    # The teacher only supplies soft targets: keep it in eval mode so any
    # BatchNorm/Dropout layers neither update statistics nor add noise.
    teacher_model.eval()

    # Use the student's device rather than the script-level ``device`` global.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else params.device

    alpha = params.alpha
    T = params.temperature
    # NOTE(review): presumably turns the networks' summed outputs over the 16
    # frames produced by the DVS loader into rates -- confirm against the model.
    output_scale = 16
    loss_avg = utils.RunningAverage()

    with tqdm(total=len(dataloader)) as t:
        for train_batch, labels_batch in dataloader:
            train_batch = train_batch.to(device, dtype=torch.float)
            labels_batch = labels_batch.to(device)

            with torch.no_grad():
                output_teacher_batch = teacher_model(train_batch) / output_scale
            # The teacher is also a spiking network: clear its state too, or it
            # carries over into the next batch's soft targets.
            functional.reset_net(teacher_model)

            output_batch = model(train_batch) / output_scale

            loss = compute_kd_loss(output_batch, output_teacher_batch,
                                   labels_batch, alpha, T)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            functional.reset_net(model)

            loss_avg.update(loss.item())
            t.set_postfix(loss='{:05.3f}'.format(loss_avg()))
            t.update()

       
    
       



def train_and_evaluate_kd(model, teacher_model, train_dataloader, val_dataloader, optimizer,
                          metrics, params, model_dir, restore_file=None):
    """Distil ``teacher_model`` into ``model`` for ``params.num_epochs`` epochs.

    The teacher's validation accuracy is printed once before training.  After
    every epoch, the student's validation accuracy is printed and logged to
    the shared ``writer`` as ``test_accuracy``, followed by the best so far.

    Args:
        model, teacher_model: student and teacher networks.
        train_dataloader, val_dataloader: training / validation batches.
        optimizer: optimizer over the student's parameters.
        metrics: metric functions forwarded to ``train_kd``/``evaluate_kd``.
        params: namespace providing at least ``num_epochs``.
        model_dir: directory holding the checkpoint named by ``restore_file``.
        restore_file: optional checkpoint name (without ``.pth.tar``) to load
            into the student and optimizer before training.
    """
    if restore_file is not None:
        # Use the arguments (the old code read an undefined global ``args``)
        # and actually load the checkpoint that is announced.
        restore_path = os.path.join(model_dir, restore_file + '.pth.tar')
        logging.info("Restoring parameters from {}".format(restore_path))
        utils.load_checkpoint(restore_path, model, optimizer)

    # NOTE(review): T_max=64 with the default 300 epochs makes the cosine
    # schedule cycle several times -- confirm this is intended.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=64)
    writer.flush()

    # Baseline: the teacher's own validation accuracy.
    teacher_acc = evaluate_kd(teacher_model, val_dataloader, metrics, params)
    print(teacher_acc)

    # NOTE(review): unlike train_and_evaluate, the student is never saved.
    best_val_acc = 0.0
    for epoch in range(params.num_epochs):
        print(epoch)
        logging.info("Epoch {}/{}".format(epoch + 1, params.num_epochs))

        train_kd(model, teacher_model, optimizer, train_dataloader, metrics, params)
        # Stepped after the epoch's optimizer updates (PyTorch >= 1.1 order).
        scheduler.step()

        val_acc = evaluate_kd(model, val_dataloader, metrics, params)
        writer.add_scalar('test_accuracy', val_acc, epoch)
        print(epoch + 1)
        print(val_acc)
        if val_acc > best_val_acc:
            best_val_acc = val_acc
        print(best_val_acc)







def accuracy(outputs, labels):
    """Return the fraction of rows whose arg-max over axis 1 equals the label.

    Args:
        outputs: NumPy array of scores with classes on axis 1.
        labels: NumPy array of class indices (int or float valued).

    Returns:
        NumPy float: number of matches divided by ``labels.size``.
    """
    predicted = np.argmax(outputs, axis=1)
    n_correct = np.count_nonzero(predicted == labels)
    return np.true_divide(n_correct, labels.size)

def parse_args(argv=None):
    """Parse the command-line options of this training script.

    Args:
        argv: optional list of argument strings; ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before.

    Returns:
        argparse.Namespace holding the hyper-parameters.
    """
    # A fresh parser per call: registering options on the shared module-level
    # parser made any second call fail with "conflicting option string".
    ap = argparse.ArgumentParser()
    ap.add_argument('--model_version', default='net',
                    help="model to train; names containing 'distill' use KD")
    ap.add_argument('--subset_percent', default=1.0, type=float)
    ap.add_argument('--augmentation', default="yes")
    ap.add_argument('--teacher', default="net", help="teacher model for KD runs")
    ap.add_argument('--alpha', default=0.5, type=float,
                    help="weight of the distillation term")
    # float so non-integer temperatures are accepted; integers parse the same.
    ap.add_argument('--temperature', default=6, type=float,
                    help="softmax temperature for distillation")
    ap.add_argument('--learning_rate', default=1e-3, type=float)
    ap.add_argument('--batch_size', default=128, type=int)
    ap.add_argument('--num_epochs', default=300, type=int)
    ap.add_argument('--dropout_rate', default=0.0, type=float)
    ap.add_argument('--num_channels', default=32, type=int)
    ap.add_argument('--save_summary_steps', default=100, type=int)
    ap.add_argument('--num_workers', default=1, type=int)
    ap.add_argument('--model_dir', default='./',
                    help="Directory containing params.json")
    ap.add_argument('--restore_file', default=None,
                    help="Optional, name of the file in --model_dir "
                         "containing weights to reload before training")
    ap.add_argument('--device', default='cuda:5')
    return ap.parse_args(argv)


if __name__ == '__main__':

    params = parse_args()

    # Module-level on purpose: train()/train_kd() may read this global.  It
    # now honours --device (whose default is still 'cuda:5').
    device = params.device

    # Without a handler every logging.info() call in this script is dropped.
    logging.basicConfig(level=logging.INFO)

    random.seed(230)
    torch.manual_seed(230)
    if params.device:
        torch.cuda.manual_seed(230)

    # Constructors are resolved lazily (inside the lambdas) so only the
    # architecture actually selected must be provided by the star imports.
    student_builders = {
        # Knowledge-distillation students (names contain "distill").
        'snn_resnet18_distill': lambda: spiking_resnet18(),
        'snn_wrn16-4_distill': lambda: SNNWideResNet(),
        'snn_net_distill': lambda: DVS128GestureNet(),
        # Plain supervised training.
        'snnresnet18': lambda: spiking_resnet18(),
        'snnvgg16': lambda: snnvgg16_bn(),
        'snnwrn16-4': lambda: SNNWideResNet(),
        'net': lambda: DVS128GestureNet(),
    }
    # Teacher constructor and its checkpoint (a dict with a 'state_dict').
    teacher_builders = {
        'snnresnet18': (lambda: spiking_resnet18(), './pruned.pth.tar'),
        'snnwrn16-4': (lambda: SNNWideResNet(), './pruned.pth.tar'),
        'net': (lambda: DVS128GestureNet(), './netdvsprune300/pruneddvsnet1.pth.tar'),
    }

    # Fail fast on an unknown name instead of a NameError after data loading.
    if params.model_version not in student_builders:
        raise ValueError("Unknown --model_version {!r}; expected one of {}".format(
            params.model_version, sorted(student_builders)))
    is_distill = "distill" in params.model_version
    if is_distill and params.teacher not in teacher_builders:
        raise ValueError("Unknown --teacher {!r}; expected one of {}".format(
            params.teacher, sorted(teacher_builders)))

    logging.info("Loading the datasets...")

    # DVS128 Gesture event streams integrated into 16 frames per sample.
    train_set = DVS128Gesture(root='./gdata', train=True, data_type='frame', frames_number=16, split_by='number')
    test_set = DVS128Gesture(root='./gdata', train=False, data_type='frame', frames_number=16, split_by='number')

    # NOTE: batch sizes are fixed here and do not follow --batch_size.
    train_dl = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=8,
        shuffle=True,
        num_workers=params.num_workers,
        drop_last=True
    )
    dev_dl = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=4,
        shuffle=False,
        num_workers=params.num_workers,
        drop_last=False
    )

    logging.info("- done.")

    metrics = {'accuracy': accuracy}

    model = student_builders[params.model_version]().to(device)
    print(params.model_version)
    # Every configuration uses Adam at --learning_rate (weight_decay=0 is the
    # Adam default).  The old code referenced an undefined ``lr`` in two cases.
    optimizer = optim.Adam(model.parameters(), lr=params.learning_rate)

    if is_distill:
        build_teacher, teacher_ckpt_path = teacher_builders[params.teacher]
        teacher_model = build_teacher().to(device)
        teacher_checkpoint = torch.load(teacher_ckpt_path, map_location='cpu')
        teacher_model.load_state_dict(teacher_checkpoint['state_dict'])
        # The teacher is only ever used for inference.
        teacher_model.eval()

        # Train the model with KD
        logging.info("Experiment - model version: {}".format(params.model_version))
        logging.info("Starting training for {} epoch(s)".format(params.num_epochs))
        logging.info("First, loading the teacher model and computing its outputs...")
        train_and_evaluate_kd(model, teacher_model, train_dl, dev_dl, optimizer,
                              metrics, params, params.model_dir, params.restore_file)
    else:
        loss_fn = nn.CrossEntropyLoss()

        # Train the model
        logging.info("Starting training for {} epoch(s)".format(params.num_epochs))
        train_and_evaluate(model, train_dl, dev_dl, optimizer, loss_fn, metrics, params,
                           params.model_dir, params.restore_file)
