

import argparse
import logging
import os
import time
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from spikingjelly.clock_driven import functional
#from WideResNet import *
from torch.optim.lr_scheduler import StepLR
from torch.autograd import Variable
from tqdm import tqdm
import torchvision
#from pymodel import *
from spiking_resnet import *
from spiking_resnettea import *
#from snnresnet66 import *

#from spiking_resnet import *
#from resnet import *
#from snn_wrn import *
#from WideResNet import *
import utils

#import data_loader as data_loader

from evaluate import *
from torchvision import datasets, transforms

# Module-level argparse.ArgumentParser instance.
parser = argparse.ArgumentParser()




from tensorboardX import SummaryWriter

# TensorBoard writer shared by the training loops (they log 'test_accuracy'
# through it); event files are written under ./img2/.
writer = SummaryWriter('./img2/')


def train(model, optimizer, loss_fn, dataloader, metrics, params):
    """Train ``model`` for one epoch with a plain supervised loss.

    Args:
        model: network being trained; its spiking-neuron state is cleared with
            ``functional.reset_net`` after every optimisation step.
        optimizer: optimiser over ``model``'s parameters.
        loss_fn: callable ``loss_fn(outputs, labels)`` returning a scalar tensor.
        dataloader: iterable of ``(inputs, labels)`` batches that supports ``len``.
        metrics: mapping of metric name -> ``fn(outputs_np, labels_np)``,
            evaluated on NumPy copies every ``params.save_summary_steps`` batches.
        params: namespace providing ``device`` (when truthy, batches are moved
            to the module-level ``device``) and ``save_summary_steps``.

    Logs the mean of every sampled metric (plus ``loss``) at INFO level, or a
    warning if the dataloader produced no batches.
    """
    model.train()

    summ = []
    loss_avg = utils.RunningAverage()

    with tqdm(total=len(dataloader)) as t:
        for i, (train_batch, labels_batch) in enumerate(dataloader):
            if params.device:
                # Move data to the module-level ``device`` the model was placed on.
                train_batch = train_batch.to(device)
                labels_batch = labels_batch.to(device)

            output_batch = model(train_batch)
            loss = loss_fn(output_batch, labels_batch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Spiking neurons keep state between forward passes; clear it so
            # the next batch starts fresh.
            functional.reset_net(model)

            loss_value = loss.item()

            if i % params.save_summary_steps == 0:
                # Metrics are computed on detached NumPy copies of this batch.
                outputs_np = output_batch.detach().cpu().numpy()
                labels_np = labels_batch.detach().cpu().numpy()
                summary_batch = {name: fn(outputs_np, labels_np)
                                 for name, fn in metrics.items()}
                summary_batch['loss'] = loss_value
                summ.append(summary_batch)

            loss_avg.update(loss_value)
            t.set_postfix(loss='{:05.3f}'.format(loss_avg()))
            t.update()

    if not summ:
        # An empty dataloader yields no summaries; avoid IndexError on summ[0].
        logging.warning("- Train metrics: no batches were processed")
        return

    metrics_mean = {metric: np.mean([x[metric] for x in summ]) for metric in summ[0]}
    metrics_string = " ; ".join("{}: {:05.3f}".format(k, v) for k, v in metrics_mean.items())
    logging.info("- Train metrics: " + metrics_string)


def train_and_evaluate(model, train_dataloader, val_dataloader, optimizer,
                       loss_fn,  metrics,params, model_dir, restore_file=None):
    """Train ``model`` for ``params.num_epochs`` epochs, validating after each.

    Each epoch prints the validation accuracy and the best accuracy so far.
    Whenever the best improves, the model weights are saved to
    ``./bestrescifarimg2.pth.tar``. Validation accuracy is also logged to the
    module-level TensorBoard ``writer``.

    Args:
        model: network to train.
        train_dataloader: training batches, passed to ``train``.
        val_dataloader: validation batches, passed to ``evaluate``.
        optimizer: optimiser over ``model``'s parameters; its learning rate is
            decayed 10x every 100 epochs.
        loss_fn: supervised loss used by ``train`` and ``evaluate``.
        metrics: metric functions forwarded to ``train`` and ``evaluate``.
        params: namespace providing at least ``num_epochs``.
        model_dir: directory containing the checkpoint named by ``restore_file``.
        restore_file: optional checkpoint basename (without ``.pth.tar``) to
            load into ``model``/``optimizer`` before training.
    """
    if restore_file is not None:
        # Use the function arguments: ``args`` is not a module-level name, so
        # the previous ``args.model_dir`` lookup raised NameError.
        restore_path = os.path.join(model_dir, restore_file + '.pth.tar')
        logging.info("Restoring parameters from {}".format(restore_path))
        utils.load_checkpoint(restore_path, model, optimizer)

    scheduler = StepLR(optimizer, step_size=100, gamma=0.1)

    writer.flush()
    best_val_acc = 0.0

    for epoch in range(params.num_epochs):
        logging.info("Epoch {}/{}".format(epoch + 1, params.num_epochs))

        train(model, optimizer, loss_fn, train_dataloader, metrics, params)
        # Since PyTorch 1.1 the scheduler must step *after* the optimiser.
        # Stepping first skipped the initial learning rate (and matches
        # train_and_evaluate_kd now).
        scheduler.step()

        # NOTE(review): assumes evaluate() returns a scalar accuracy.
        val_acc = evaluate(model, loss_fn, val_dataloader, metrics, params)
        print(val_acc)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({'state_dict': model.state_dict()},
                       os.path.join('./', 'bestrescifarimg2.pth.tar'))
        print(best_val_acc)

        writer.add_scalar('test_accuracy', val_acc, epoch)




def train_kd(model, teacher_model,optimizer,  dataloader, metrics, params):
    """Train ``model`` for one epoch with knowledge distillation from ``teacher_model``.

    The loss has two terms:

    * a KL divergence between the softened student and teacher outputs, both
      divided by ``T = params.temperature`` and weighted by ``alpha * T * T``;
    * a hard-label cross-entropy weighted by ``1 - alpha``, where
      ``alpha = params.alpha``.

    Args:
        model: student network being optimised.
        teacher_model: teacher network; only run forward under ``no_grad``.
        optimizer: optimiser over the student's parameters.
        dataloader: iterable of ``(inputs, labels)`` batches.
        metrics: accepted for signature parity with ``train``; unused here.
        params: namespace providing ``alpha`` and ``temperature``.
    """
    model.train()
    # The teacher only supplies targets. Eval mode keeps its BatchNorm/Dropout
    # deterministic and stops the running statistics drifting on student data.
    teacher_model.eval()

    alpha = params.alpha
    T = params.temperature
    # NOTE(review): KLDivLoss's default reduction='mean' averages over batch
    # *and* class dimensions; 'batchmean' gives the mathematical KL. Kept
    # as-is because alpha/T were presumably tuned against this scale.
    kl_div = nn.KLDivLoss()
    loss_avg = utils.RunningAverage()

    progress = tqdm(dataloader)
    for train_batch, labels_batch in progress:
        train_batch, labels_batch = train_batch.to(device), labels_batch.to(device)

        with torch.no_grad():
            output_teacher_batch = teacher_model(train_batch)
        # Spiking neurons are stateful. Reset the teacher so the next batch
        # (possibly of a different size) starts from a clean state.
        functional.reset_net(teacher_model)

        output_batch = model(train_batch)

        soft_loss = kl_div(F.log_softmax(output_batch / T, dim=1),
                           F.softmax(output_teacher_batch / T, dim=1))
        hard_loss = F.cross_entropy(output_batch, labels_batch)
        loss = soft_loss * (alpha * T * T) + hard_loss * (1. - alpha)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        functional.reset_net(model)

        loss_avg.update(loss.item())
        progress.set_postfix(loss='{:05.3f}'.format(loss_avg()))

       
    
       



def train_and_evaluate_kd(model,teacher_model, train_dataloader, val_dataloader, optimizer,
                           metrics, params, model_dir, restore_file=None):
    """Distil ``teacher_model`` into ``model`` for ``params.num_epochs`` epochs.

    After each epoch the student is evaluated with ``evaluate_kd``. The
    accuracy is logged to the module-level TensorBoard ``writer``, and both it
    and the best accuracy so far are printed.

    Args:
        model: student network to train.
        teacher_model: teacher network passed to ``train_kd``.
        train_dataloader: training batches.
        val_dataloader: validation batches.
        optimizer: optimiser over the student's parameters; its learning rate
            is decayed 10x every 100 epochs.
        metrics: metric functions forwarded to ``train_kd`` / ``evaluate_kd``.
        params: namespace providing at least ``num_epochs``.
        model_dir: directory containing the checkpoint named by ``restore_file``.
        restore_file: optional checkpoint basename (without ``.pth.tar``) to
            load into ``model``/``optimizer`` before training.
    """
    if restore_file is not None:
        # ``args`` is not a module-level name (NameError before). Also actually
        # load the checkpoint instead of only logging its path, matching
        # train_and_evaluate.
        restore_path = os.path.join(model_dir, restore_file + '.pth.tar')
        logging.info("Restoring parameters from {}".format(restore_path))
        utils.load_checkpoint(restore_path, model, optimizer)

    scheduler = StepLR(optimizer, step_size=100, gamma=0.1)
    writer.flush()

    best_val_acc = 0.0
    for epoch in range(params.num_epochs):
        logging.info("Epoch {}/{}".format(epoch + 1, params.num_epochs))

        train_kd(model, teacher_model, optimizer, train_dataloader,
                 metrics, params)
        # Step the LR schedule after the epoch's optimiser updates.
        scheduler.step()

        # NOTE(review): assumes evaluate_kd() returns a scalar accuracy.
        ta = evaluate_kd(model, val_dataloader, metrics, params)
        writer.add_scalar('test_accuracy', ta, epoch)
        print(ta)
        if ta > best_val_acc:
            best_val_acc = ta
        print(best_val_acc)







def accuracy(outputs, labels):
    """Return the fraction of rows whose arg-max prediction equals the label.

    Args:
        outputs: array of per-class scores, one row per sample (arg-max is
            taken over axis 1).
        labels: array of integer class labels; its ``size`` is the divisor.
    """
    predicted = np.argmax(outputs, axis=1)
    hits = np.sum(predicted == labels)
    return hits / float(labels.size)

def parse_args(argv=None):
    """Parse the training script's command-line options.

    Args:
        argv: optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before.

    Returns:
        argparse.Namespace holding the model/teacher choice, distillation
        hyper-parameters (``alpha``, ``temperature``), optimisation settings,
        ``model_dir``/``restore_file`` and ``device``.
    """
    # Build a fresh parser per call. Registering options on a shared
    # module-level parser made a second call fail with a conflicting-option
    # error.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('--model_version', default='resnet18')
    arg_parser.add_argument('--subset_percent', default=1.0, type=float)
    arg_parser.add_argument('--augmentation', default="yes")
    arg_parser.add_argument('--teacher', default="resnet18")
    arg_parser.add_argument('--alpha', default=0.95, type=float)
    # float so fractional distillation temperatures (e.g. 4.5) are accepted;
    # the default stays the integer 6.
    arg_parser.add_argument('--temperature', default=6, type=float)
    arg_parser.add_argument('--learning_rate', default=0.003, type=float)
    arg_parser.add_argument('--batch_size', default=128, type=int)
    arg_parser.add_argument('--num_epochs', default=320, type=int)
    arg_parser.add_argument('--dropout_rate', default=0.0, type=float)
    arg_parser.add_argument('--num_channels', default=32, type=int)
    arg_parser.add_argument('--save_summary_steps', default=100, type=int)
    arg_parser.add_argument('--num_workers', default=1, type=int)
    arg_parser.add_argument('--model_dir', default='./',
                            help="Directory containing params.json")
    arg_parser.add_argument('--restore_file', default=None,
                            help="Optional, name of the file in --model_dir "
                                 "containing weights to reload before training")
    arg_parser.add_argument('--device', default='cuda:1')
    return arg_parser.parse_args(argv)


if __name__ == '__main__':

    params = parse_args()
    # Honour --device instead of a hard-coded GPU. This module-level name is
    # also what train()/train_kd() move batches to.
    device = params.device

    # Without configuration the root logger drops every logging.info call.
    logging.basicConfig(level=logging.INFO)

    # Fixed seeds for reproducibility.
    random.seed(230)
    torch.manual_seed(230)
    if params.device: torch.cuda.manual_seed(230)

    logging.info("Loading the datasets...")

    # Dataset: ImageFolder trees under ./tiny-imagenet-200/{train,val}.
    # NOTE(review): ImageFolder needs one sub-directory per class, so val/ is
    # assumed to have been reorganised that way — confirm.
    num_label = 200
    data_dir = './tiny-imagenet-200/'
    # NOTE(review): batch size is fixed at 64 here; --batch_size (default 128)
    # and --num_workers are not used by these loaders.
    batch_size = 64
    normalize = transforms.Normalize((0.4802, 0.4481, 0.3975), (0.2770, 0.2691, 0.2821))
    transform_train = transforms.Compose(
        [transforms.RandomHorizontalFlip(), transforms.ToTensor(), normalize])
    transform_test = transforms.Compose([transforms.ToTensor(), normalize])

    trainset = datasets.ImageFolder(root=os.path.join(data_dir, 'train'), transform=transform_train)
    testset = datasets.ImageFolder(root=os.path.join(data_dir, 'val'), transform=transform_test)
    train_dl = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, pin_memory=True)
    dev_dl = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, pin_memory=True)

    logging.info("- done.")

    if "distill" in params.model_version:
        # ---- Student model + optimiser ----
        if params.model_version == "snn_resnet18_distill":
            model = spiking_resnet18().to(device)
            print(params.learning_rate)
            optimizer = optim.Adam(model.parameters(), lr=params.learning_rate, weight_decay=1e-5)
        elif params.model_version == 'snn_wrn16-4_distill':
            model = SNNWideResNet().to(device)
            print('snn_wrn16-4_distill')
            optimizer = optim.SGD(model.parameters(), lr=params.learning_rate,
                                  momentum=0.9, weight_decay=1e-5)
            print(params.learning_rate)
        else:
            # Previously fell through and failed later with NameError on `model`.
            raise ValueError("Unknown model_version: {}".format(params.model_version))

        metrics = {'accuracy': accuracy}

        # ---- Teacher model ----
        if params.teacher == "resnet18":
            teacher_model = spiking_resnet18tea().to(device)
            # NOTE(review): the checkpoint name suggests a CIFAR-100 model while
            # this script trains on tiny-imagenet-200 — confirm it is the right one.
            teacher_checkpoint = torch.load('./bestrescifar100.pth.tar', map_location='cpu')
        elif params.teacher == "snnwrn16-4":
            teacher_model = SNNWideResNet().to(device)
            teacher_checkpoint = torch.load('./pruned.pth.tar', map_location='cpu')
        else:
            raise ValueError("Unknown teacher: {}".format(params.teacher))
        teacher_model.load_state_dict(teacher_checkpoint['state_dict'])

        # Train the model with KD
        logging.info("Experiment - model version: {}".format(params.model_version))
        logging.info("Starting training for {} epoch(s)".format(params.num_epochs))
        logging.info("First, loading the teacher model and computing its outputs...")
        train_and_evaluate_kd(model, teacher_model, train_dl, dev_dl, optimizer,
                              metrics, params, params.model_dir, params.restore_file)

    else:
        # ---- Plain supervised training ----
        if params.model_version == "snnresnet18":
            model = spiking_resnet18().to(device)
            print('resnet18')
            optimizer = optim.Adam(model.parameters(), lr=params.learning_rate, weight_decay=1e-5)
        elif params.model_version == "snnvgg16":
            model = snnvgg16_bn().to(device)
            print('snn_vgg16')
            optimizer = optim.SGD(model.parameters(), lr=params.learning_rate,
                                  momentum=0.9, weight_decay=0)
        elif params.model_version == "snnwrn16-4":
            model = SNNWideResNet().to(device)
            print('wrn16-4')
            optimizer = optim.SGD(model.parameters(), lr=params.learning_rate,
                                  momentum=0.9, weight_decay=1e-5)
        else:
            # The default --model_version ('resnet18') lands here; fail clearly
            # instead of with NameError on `model`.
            raise ValueError("Unknown model_version: {}".format(params.model_version))

        loss_fn = nn.CrossEntropyLoss()
        metrics = {'accuracy': accuracy}

        # Train the model
        logging.info("Starting training for {} epoch(s)".format(params.num_epochs))
        train_and_evaluate(model, train_dl, dev_dl, optimizer, loss_fn, metrics, params,
                           params.model_dir, params.restore_file)

