

import argparse
import logging
import os
import time
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from spikingjelly.clock_driven import functional
#from WideResNet import *
from torch.optim.lr_scheduler import StepLR
from torch.autograd import Variable
from tqdm import tqdm
import torchvision
#from pymodel import *
from spiking_resnet import *

#from snnresnet66 import *

#from spiking_resnet import *
#from resnet import *
from snnwrn import *
from snnvgg import *
#from WideResNet import *
import utils

#import data_loader as data_loader

from evaluate import *
from torchvision import datasets, transforms

# Shared command-line parser; its options are registered later by parse_args().
parser = argparse.ArgumentParser()




from tensorboardX import SummaryWriter

# TensorBoard writer, created at import time with './' as its log directory.
# The training loops in this file record the 'test_accuracy' scalar through it.
writer = SummaryWriter('./')


def train(model, optimizer, loss_fn, dataloader, metrics, params):
    """Run one training pass of ``model`` over ``dataloader``.

    Args:
        model: network to optimise. It is switched to training mode and its
            state is cleared with ``functional.reset_net`` after every step.
        optimizer: optimiser whose ``zero_grad``/``step`` apply the update.
        loss_fn: callable ``loss_fn(outputs, labels)`` returning a scalar
            loss tensor.
        dataloader: sized iterable yielding ``(inputs, labels)`` batches.
        metrics: dict mapping a metric name to ``fn(outputs, labels)``; the
            functions receive NumPy arrays.
        params: namespace providing ``save_summary_steps``. ``params.device``
            is only used as a fallback when the model has no parameters.

    Returns:
        None. The mean of the sampled metrics is reported via ``logging.info``.
    """
    model.train()

    # Send each batch to wherever the model lives instead of relying on a
    # module-level ``device`` name that only exists when the file is run as a
    # script.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        target_device = first_param.device
    else:
        target_device = params.device or None

    summ = []
    loss_avg = utils.RunningAverage()

    with tqdm(total=len(dataloader)) as t:
        for i, (train_batch, labels_batch) in enumerate(dataloader):
            if target_device is not None:
                train_batch = train_batch.to(target_device)
                labels_batch = labels_batch.to(target_device)

            output_batch = model(train_batch)
            loss = loss_fn(output_batch, labels_batch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Clear the network state so it does not carry over into the
            # next batch.
            functional.reset_net(model)

            # Read the scalar once; it is used for both the summary and the
            # running average.
            loss_value = loss.item()

            # Metrics are only sampled every ``save_summary_steps`` batches.
            if i % params.save_summary_steps == 0:
                outputs_np = output_batch.detach().cpu().numpy()
                labels_np = labels_batch.detach().cpu().numpy()

                summary_batch = {name: metric_fn(outputs_np, labels_np)
                                 for name, metric_fn in metrics.items()}
                summary_batch['loss'] = loss_value
                summ.append(summary_batch)

            loss_avg.update(loss_value)

            t.set_postfix(loss='{:05.3f}'.format(loss_avg()))
            t.update()

    # An empty dataloader produces no summaries; report that instead of
    # failing on ``summ[0]``.
    if not summ:
        logging.warning("- Train metrics: no batches were processed")
        return

    metrics_mean = {metric: np.mean([x[metric] for x in summ]) for metric in summ[0]}
    metrics_string = " ; ".join("{}: {:05.3f}".format(k, v) for k, v in metrics_mean.items())
    logging.info("- Train metrics: " + metrics_string)


def train_and_evaluate(model, train_dataloader, val_dataloader, optimizer,
                       loss_fn, metrics, params, model_dir, restore_file=None):
    """Train ``model`` for ``params.num_epochs`` epochs, evaluating after each.

    Args:
        model: network to train.
        train_dataloader: batches passed to ``train``.
        val_dataloader: batches passed to ``evaluate``.
        optimizer: optimiser used for training; also wrapped by a ``StepLR``
            schedule (``step_size=100``, ``gamma=0.1``).
        loss_fn: loss callable forwarded to ``train`` and ``evaluate``.
        metrics: metric dict forwarded to ``train`` and ``evaluate``.
        params: namespace providing ``num_epochs`` (and whatever ``train`` and
            ``evaluate`` read from it).
        model_dir: directory that holds the checkpoint named by
            ``restore_file``.
        restore_file: optional checkpoint name without the ``.pth.tar``
            suffix; when given, it is loaded before training starts.

    Returns:
        None. The best-scoring weights are written to ``./best.pth.tar`` and
        the per-epoch score is logged to TensorBoard as ``test_accuracy``.
    """
    if restore_file is not None:
        # Build the path from this function's own arguments; there is no
        # module-level ``args`` object to read them from.
        restore_path = os.path.join(model_dir, restore_file + '.pth.tar')
        logging.info("Restoring parameters from {}".format(restore_path))
        utils.load_checkpoint(restore_path, model, optimizer)

    best_val_acc = 0.0

    scheduler = StepLR(optimizer, step_size=100, gamma=0.1)

    writer.flush()

    for epoch in range(params.num_epochs):
        logging.info("Epoch {}/{}".format(epoch + 1, params.num_epochs))

        train(model, optimizer, loss_fn, train_dataloader, metrics, params)

        # The schedule is advanced after the epoch's optimizer steps, which is
        # the order PyTorch expects and the one train_and_evaluate_kd uses.
        scheduler.step()

        # NOTE(review): ``evaluate`` comes from the evaluate module; its result
        # is treated as a scalar score here - confirm it returns accuracy.
        val_acc = evaluate(model, loss_fn, val_dataloader, metrics, params)

        print(val_acc)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                'state_dict': model.state_dict(),
            }, os.path.join('./', 'best.pth.tar'))
            print('saved')
        print(best_val_acc)

        writer.add_scalar('test_accuracy', val_acc, epoch)




def train_kd(model, teacher_model, optimizer, dataloader, metrics, params):
    """Run one knowledge-distillation training pass over ``dataloader``.

    The loss is a weighted sum of the KL divergence between the
    temperature-softened student and teacher distributions and the
    cross-entropy of the student against the labels.

    Args:
        model: student network being optimised.
        teacher_model: network providing the soft targets; it is evaluated
            without gradients and never updated.
        optimizer: optimiser for the student.
        dataloader: iterable yielding ``(inputs, labels)`` batches.
        metrics: unused here; kept so the call signature stays unchanged.
        params: namespace providing ``alpha`` (weight of the distillation
            term) and ``temperature``. ``params.device`` is only used as a
            fallback when the student has no parameters.

    Returns:
        None.
    """
    model.train()
    # The teacher only supplies targets, so keep it in inference mode rather
    # than whatever mode it happened to be left in.
    teacher_model.eval()

    # Send each batch to wherever the student lives instead of relying on a
    # module-level ``device`` name that only exists when the file is run as a
    # script.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        target_device = first_param.device
    else:
        target_device = params.device or None

    alpha = params.alpha
    T = params.temperature

    # Built once instead of on every step.
    # NOTE(review): the default reduction averages over every element, not per
    # sample; it is kept because it sets the scale of the distillation term.
    kl_div = nn.KLDivLoss()

    for train_batch, labels_batch in tqdm(dataloader):
        if target_device is not None:
            train_batch = train_batch.to(target_device)
            labels_batch = labels_batch.to(target_device)

        with torch.no_grad():
            output_teacher_batch = teacher_model(train_batch)
        # Clear the teacher's state as well, so it does not leak into the next
        # batch (the student is reset after its update below).
        functional.reset_net(teacher_model)

        output_batch = model(train_batch)

        soft_loss = kl_div(F.log_softmax(output_batch / T, dim=1),
                           F.softmax(output_teacher_batch / T, dim=1)) * (alpha * T * T)
        hard_loss = F.cross_entropy(output_batch, labels_batch) * (1. - alpha)
        loss = soft_loss + hard_loss

        optimizer.zero_grad()
        loss.backward()

        optimizer.step()
        functional.reset_net(model)


       
    
       



def train_and_evaluate_kd(model, teacher_model, train_dataloader, val_dataloader, optimizer,
                          metrics, params, model_dir, restore_file=None):
    """Distil ``teacher_model`` into ``model`` for ``params.num_epochs`` epochs.

    Args:
        model: student network to train.
        teacher_model: network whose outputs are used as soft targets.
        train_dataloader: batches passed to ``train_kd``.
        val_dataloader: batches passed to ``evaluate_kd``.
        optimizer: optimiser for the student; also wrapped by a ``StepLR``
            schedule (``step_size=100``, ``gamma=0.1``).
        metrics: metric dict forwarded to ``train_kd`` and ``evaluate_kd``.
        params: namespace providing ``num_epochs`` (and whatever ``train_kd``
            and ``evaluate_kd`` read from it).
        model_dir: directory that holds the checkpoint named by
            ``restore_file``.
        restore_file: optional checkpoint name without the ``.pth.tar``
            suffix; when given, it is loaded before training starts.

    Returns:
        None. The per-epoch score is printed and logged to TensorBoard as
        ``test_accuracy``; the best score so far is printed after each epoch.
    """
    if restore_file is not None:
        # Build the path from this function's own arguments; there is no
        # module-level ``args`` object to read them from.
        restore_path = os.path.join(model_dir, restore_file + '.pth.tar')
        logging.info("Restoring parameters from {}".format(restore_path))
        # Actually load the checkpoint, as train_and_evaluate does; logging
        # the path alone restores nothing.
        utils.load_checkpoint(restore_path, model, optimizer)

    best_val_acc = 0.0

    scheduler = StepLR(optimizer, step_size=100, gamma=0.1)
    writer.flush()

    for epoch in range(params.num_epochs):
        logging.info("Epoch {}/{}".format(epoch + 1, params.num_epochs))

        train_kd(model, teacher_model, optimizer, train_dataloader,
                 metrics, params)
        # Advance the schedule only after the epoch's optimizer steps.
        scheduler.step()

        # NOTE(review): ``evaluate_kd`` comes from the evaluate module; its
        # result is treated as a scalar score here - confirm it returns accuracy.
        ta = evaluate_kd(model, val_dataloader, metrics, params)
        writer.add_scalar('test_accuracy', ta, epoch)
        print(ta)

        if ta > best_val_acc:
            best_val_acc = ta
        print(best_val_acc)







def accuracy(outputs, labels):
    """Return the fraction of rows in ``outputs`` whose arg-max equals the label.

    Args:
        outputs: 2-D array-like of per-class scores, one row per sample.
        labels: NumPy array of integer class indices.

    Returns:
        Number of matching predictions divided by ``labels.size``.
    """
    predicted_classes = np.argmax(outputs, axis=1)
    num_correct = np.sum(predicted_classes == labels)
    total = float(labels.size)
    return num_correct / total

def parse_args():
    """Register the command-line options on the module-level ``parser`` and parse them.

    The options are added to the shared ``parser`` object, so this function is
    meant to be called once per process; registering the same option twice
    makes argparse raise a conflict error.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser.add_argument('--model_version', default='resnet18')
    parser.add_argument('--subset_percent', default=1.0, type=float)
    parser.add_argument('--augmentation', default="yes")
    parser.add_argument('--teacher', default="resnet18")
    parser.add_argument('--alpha', default=0.95, type=float)
    # Parsed as float so fractional temperatures (e.g. 2.5) are accepted;
    # integer values given on the command line still work.
    parser.add_argument('--temperature', default=6, type=float)
    parser.add_argument('--learning_rate', default=0.01, type=float)
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--num_epochs', default=2000, type=int)
    parser.add_argument('--dropout_rate', default=0.0, type=float)
    parser.add_argument('--num_channels', default=32, type=int)
    parser.add_argument('--save_summary_steps', default=100, type=int)
    parser.add_argument('--num_workers', default=1, type=int)
    parser.add_argument('--model_dir', default='./',
                        help="Directory containing params.json")
    parser.add_argument('--restore_file', default=None,
                        help="Optional, name of the file in --model_dir \
                        containing weights to reload before training")
    parser.add_argument('--device', default='cuda:0')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    # Script entry point: parse options, build the CIFAR-100 loaders, create
    # the requested model(s) and start plain or distillation training.

    # Kept as a module-level name because code outside this block may refer
    # to it.
    # NOTE(review): placement is hard-coded here; ``--device`` is only
    # consulted for the CUDA seeding below - confirm whether it should also
    # drive placement.
    device = 'cuda:0'

    params = parse_args()

    # Without a configured handler the root logger drops INFO records, so the
    # logging.info(...) progress messages in this file would never be shown.
    logging.basicConfig(level=logging.INFO,
                        format='%(asctime)s:%(levelname)s: %(message)s')

    # Fixed seeds so runs are repeatable.
    random.seed(230)
    torch.manual_seed(230)
    if params.device:
        torch.cuda.manual_seed(230)

    logging.info("Loading the datasets...")

    # Per-channel normalisation constants shared by both transforms.
    mean = [0.5070751592371323, 0.48654887331495095, 0.4409178433670343]
    std = [0.2673342858792401, 0.2564384629170883, 0.27615047132568404]

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    transform_test = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize(mean, std)])

    # NOTE(review): the batch size is fixed at 64 for both loaders;
    # ``--batch_size`` and ``--num_workers`` are not consulted here.
    # drop_last=True discards the final incomplete batch of each split.
    cifar100_training = torchvision.datasets.CIFAR100(root='./',
                                                      train=True, download=True,
                                                      transform=transform_train)
    train_dl = torch.utils.data.DataLoader(cifar100_training, batch_size=64, shuffle=True, drop_last=True)

    cifar100_testing = torchvision.datasets.CIFAR100(root='./',
                                                     train=False, download=True,
                                                     transform=transform_test)
    dev_dl = torch.utils.data.DataLoader(cifar100_testing, batch_size=64, shuffle=False, drop_last=True)

    logging.info("- done.")

    if "distill" in params.model_version:

        # Student network.
        if params.model_version == "snn_resnet18_distill":
            model = spiking_resnet18().to(device)
        elif params.model_version == 'snn_wrn16-4_distill':
            model = SNNWideResNet().to(device)
            # NOTE(review): this banner says "16-2" while the option is named
            # "16-4"; left unchanged because the actual width is not visible here.
            print('snn_wrn16-2_distill')
            print(params.learning_rate)
        else:
            # Fail with a clear message instead of a NameError on ``model`` later.
            raise ValueError(
                "Unsupported --model_version {!r} for distillation; expected "
                "'snn_resnet18_distill' or 'snn_wrn16-4_distill'".format(params.model_version))

        optimizer = optim.SGD(model.parameters(), lr=params.learning_rate,
                              momentum=0.9, weight_decay=0)
        metrics = {'accuracy': accuracy}

        # Teacher network; its weights come from a local checkpoint.
        if params.teacher == "snnresnet18":
            teacher_model = spiking_resnet18().to(device)
        elif params.teacher == "snnwrn16-4":
            teacher_model = SNNWideResNet().to(device)
        else:
            # Fail with a clear message instead of a NameError on
            # ``teacher_model`` later.
            raise ValueError(
                "Unsupported --teacher {!r}; expected 'snnresnet18' or "
                "'snnwrn16-4'".format(params.teacher))

        teacher_checkpoint = torch.load('./pruned.pth.tar', map_location='cpu')
        teacher_model.load_state_dict(teacher_checkpoint['state_dict'])
        if params.teacher == "snnresnet18":
            # Leftover progress marker, kept so the script's stdout is unchanged.
            print('x')

        # Train the model with KD
        logging.info("Experiment - model version: {}".format(params.model_version))
        logging.info("Starting training for {} epoch(s)".format(params.num_epochs))
        logging.info("First, loading the teacher model and computing its outputs...")
        train_and_evaluate_kd(model, teacher_model, train_dl, dev_dl, optimizer,
                              metrics, params, params.model_dir, params.restore_file)

    else:

        if params.model_version == "snnresnet18":
            model = spiking_resnet18().to(device)
            print('snnresnet18')
        elif params.model_version == "snnvgg16":
            model = snnvgg16_bn().to(device)
            print('snn_vgg16')
        elif params.model_version == "snnwrn16-4":
            model = SNNWideResNet().to(device)
            # This branch builds the wide ResNet, so announce that rather than
            # the 'vgg16' label it previously printed.
            print('snnwrn16-4')
        else:
            # Fail with a clear message instead of a NameError on ``model`` later.
            raise ValueError(
                "Unsupported --model_version {!r}; expected 'snnresnet18', "
                "'snnvgg16' or 'snnwrn16-4'".format(params.model_version))

        optimizer = optim.SGD(model.parameters(), lr=params.learning_rate,
                              momentum=0.9, weight_decay=0)
        loss_fn = nn.CrossEntropyLoss()
        metrics = {'accuracy': accuracy}

        # Train the model
        logging.info("Starting training for {} epoch(s)".format(params.num_epochs))
        train_and_evaluate(model, train_dl, dev_dl, optimizer, loss_fn, metrics, params,
                           params.model_dir, params.restore_file)
