import argparse
import time
import os
import sys
import numpy as np
import random
import shutil

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data

from LayerAct import LA_HardSiLU, LA_SiLU
import data_augmentation 
from train_validate import train, validate

from ResNet import resnet18, resnet50, resnet101 
from ResNet_small import resnet20, resnet32, resnet44


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename=None):
    """Serialize a training checkpoint and optionally keep a copy as the best one.

    Args:
        state: Object handed to ``torch.save``; this script passes a dict with the
            iteration counter, model/optimizer/scheduler state and best weights.
        is_best: When true, the freshly written ``filename`` is also copied to
            ``best_filename``.
        filename: Destination path of the checkpoint.
        best_filename: Destination of the best-checkpoint copy. ``None`` keeps the
            historical default ``'model_best.pth.tar'`` (relative to the current
            working directory); pass an explicit path so runs sharing a working
            directory do not overwrite each other's best copy.
    """
    torch.save(state, filename)
    if is_best:
        if best_filename is None:
            best_filename = 'model_best.pth.tar'
        # Copy the file just written instead of serializing ``state`` a second time.
        shutil.copyfile(filename, best_filename)

def resnet_set(name):
    """Return the ResNet constructor registered under *name*.

    The returned callable is invoked by ``model_loader`` as
    ``ctor(activation, activation_params, rs=..., num_classes=...)``.

    Args:
        name: 'resnet18', 'resnet50', 'resnet101' (from ``ResNet``) or
            'resnet20', 'resnet32', 'resnet44' (from ``ResNet_small``).

    Raises:
        ValueError: if *name* is not a supported architecture, instead of
            returning ``None`` and failing later with an opaque TypeError.
    """
    # Explicit comparisons (rather than a dict) so a constructor name is only
    # looked up when it is actually selected.
    if name == 'resnet18':
        return resnet18
    if name == 'resnet50':
        return resnet50
    if name == 'resnet101':
        return resnet101

    if name == 'resnet20':
        return resnet20
    if name == 'resnet32':
        return resnet32
    if name == 'resnet44':
        return resnet44

    raise ValueError(
        'Unknown model {!r}; expected one of resnet18, resnet50, resnet101, '
        'resnet20, resnet32, resnet44'.format(name)
    )

def activation_set(name):
    """Map an activation name from the command line to its module class.

    The class itself (not an instance) is returned; ``model_loader`` hands it to
    the ResNet constructor together with ``activation_params``.

    Args:
        name: One of 'relu', 'leakyrelu', 'prelu', 'mish', 'silu',
            'hardsilu' (mapped to ``nn.Hardswish``), 'la_silu', 'la_hardsilu'.

    Raises:
        ValueError: if *name* is not recognised, rather than returning ``None``
            and failing later when the model is built.
    """
    if name == 'relu':
        return nn.ReLU
    if name == 'leakyrelu':
        return nn.LeakyReLU
    if name == 'prelu':
        return nn.PReLU
    if name == 'mish':
        return nn.Mish
    if name == 'silu':
        return nn.SiLU
    if name == 'hardsilu':
        return nn.Hardswish
    if name == 'la_silu':
        return LA_SiLU
    if name == 'la_hardsilu':
        return LA_HardSiLU
    raise ValueError(
        'Unknown activation {!r}; expected one of relu, leakyrelu, prelu, mish, '
        'silu, hardsilu, la_silu, la_hardsilu'.format(name)
    )

def model_loader(model_name, activation, activation_params, rs, out_num) :
    """Instantiate the ResNet variant named *model_name*.

    Args:
        model_name: Architecture key understood by ``resnet_set``.
        activation: Activation class (as returned by ``activation_set``).
        activation_params: Keyword parameters for the activation (e.g. LayerAct alpha).
        rs: Random seed forwarded to the constructor as ``rs``.
        out_num: Number of output classes, forwarded as ``num_classes``.

    Returns:
        Whatever the selected constructor builds.
    """
    constructor = resnet_set(model_name)
    network = constructor(activation, activation_params, rs=rs, num_classes=out_num)
    return network

def folder_check(path, data_name, model_name) :
    """Ensure ``<path>/<data_name>/<model_name>/`` exists and return it.

    The returned string ends with a path separator because callers append file
    names to it directly (e.g. ``save_path + file_name``).

    Missing parents, including *path* itself, are created. An existing directory
    is left untouched, so repeated calls are harmless. *path* may be given with
    or without a trailing separator.

    Args:
        path: Root directory for saved models.
        data_name: Dataset sub-directory name.
        model_name: Model sub-directory name.

    Returns:
        The model directory path, terminated by a separator.
    """
    # Joining with a final '' component yields the trailing separator callers rely on.
    path_m = os.path.join(path, data_name, model_name, '')
    os.makedirs(path_m, exist_ok=True)
    return path_m
    
# One fixed seed per trial index (trial t uses random_seed[t-1]): 11, 22, ..., 220.
random_seed = list(range(11, 11 * 21, 11))

#######################################################################################################

def build_arg_parser():
    """Return the command-line parser for this training script.

    Every numeric hyper-parameter declares ``type=`` so values given on the
    command line are converted. Without it argparse returns them as strings:
    ``-lr 0.05`` would reach ``torch.optim.SGD`` as ``'0.05'``, and
    ``--max_iter 1000`` would break the ``while`` comparison against an int.
    Defaults are unchanged. ``--resume``/``--duplicate``/``--save`` stay strings
    and are compared against the literal ``'True'`` by the caller.
    """
    parser = argparse.ArgumentParser(description='')

    parser.add_argument('--data', '-d', default='CIFAR10')
    parser.add_argument('--model', '-m', default='resnet20')
    parser.add_argument('--activation', '-a', default='relu')

    parser.add_argument('--device_ids', default='0')
    parser.add_argument('--output_device', default=0, type=int)
    parser.add_argument('--crop', default='center')
    parser.add_argument('--start_trial', default=1, type=int)
    parser.add_argument('--end_trial', default=5, type=int)

    parser.add_argument('--alpha', default=1e-1, type=float)
    parser.add_argument('--batch_size', '-bs', default=256, type=int)
    parser.add_argument('--num_workers', '-nw', default=16, type=int)
    parser.add_argument('--learning_rate', '-lr', default=0.1, type=float)
    parser.add_argument('--momentum', default=0.9, type=float)
    parser.add_argument('--weight_decay', '-wd', default=0.0001, type=float)
    parser.add_argument('--max_iter', default=600000, type=int)
    parser.add_argument('--milestones', default='180000,360000,540000')

    parser.add_argument('--data_path', '-dp', default='')
    parser.add_argument('--save_path', default='trained_models/')

    parser.add_argument('--resume', default="True", type=str)
    parser.add_argument('--duplicate', default="False", type=str)
    parser.add_argument('--save', default="True", type=str)

    return parser


if __name__ == '__main__' :
    import copy  # snapshot of the best weights (state_dict() returns live references)

    parser = build_arg_parser()
    args = parser.parse_args()
    # random_seed holds one seed per trial. Out-of-range trials would raise IndexError,
    # or for trial 0 silently reuse the last seed via random_seed[-1].
    if args.start_trial < 1 or args.end_trial > len(random_seed) :
        parser.error('--start_trial/--end_trial must lie in [1, {}]'.format(len(random_seed)))

    activation = activation_set(args.activation)
    # parameter alpha of LayerAct functions for stable training
    activation_params = {'alpha' : args.alpha} if 'la_' in args.activation else {}

    milestones = [int(m) for m in args.milestones.split(',')]

    device = torch.device('cuda')
    device_ids = [int(d) for d in args.device_ids.split(',')]
    output_device = torch.device('cuda:{}'.format(args.output_device))
    save_path = folder_check(args.save_path, args.data, args.model)

    # Boolean-like flags arrive as strings; only the exact literal 'True' enables them.
    resume = args.resume == 'True'
    duplicate = args.duplicate == 'True'
    save = args.save == 'True'

    for trial in range(args.start_trial, args.end_trial+1) :
        rs = random_seed[trial-1]
        random.seed(rs)
        np.random.seed(rs)
        torch.manual_seed(rs)
        cudnn.deterministic = True
        cudnn.benchmark = False

        file_name = '{}_{}'.format(args.activation, trial)
        checkpoint_path = save_path + file_name + '_checkpoint.pth.tar'

        if not duplicate and '{}.pth.tar'.format(file_name) in os.listdir(save_path) :
            sys.exit('Model ({} | {} | {}) exists'.format(args.data, args.model, args.activation))

        if args.data == 'CIFAR10' :
            train_loader, val_loader, test_loader = data_augmentation.load_CIFAR10(args.data_path, 'None', '', '', args.batch_size, args.num_workers, rs)
            in_channel, H, W, out_num = 3, 32, 32, 10
        elif args.data == 'CIFAR100' :
            train_loader, val_loader, test_loader = data_augmentation.load_CIFAR100(args.data_path, 'None', '', '', args.batch_size, args.num_workers, rs)
            in_channel, H, W, out_num = 3, 32, 32, 100
        elif args.data == 'ImageNet' :
            # NOTE(review): the ImageNet branch calls the CIFAR-100 loader (with an extra crop
            # argument), which looks like a copy-paste slip. Confirm which loader
            # data_augmentation provides for ImageNet.
            train_loader, val_loader, test_loader = data_augmentation.load_CIFAR100(args.data_path, 'None', '', '', args.batch_size, args.num_workers, rs, args.crop)
            in_channel, H, W, out_num = 3, 224, 224, 1000
        else :
            raise Exception('Dataset should be "CIFAR10", "CIFAR100", and "ImageNet"')

        model = model_loader(args.model, activation, activation_params, rs, out_num)
        model.to(device)
        model = nn.DataParallel(model, device_ids=device_ids, output_device=output_device)

        criterion = nn.CrossEntropyLoss().to(device)
        optimizer = torch.optim.SGD(model.parameters(), args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, last_epoch=-1)

        print('model make', end='\n')
        best_model = None
        best_acc1 = 0
        start_time = time.time()
        start_iter = 0
        if resume and os.path.isfile(checkpoint_path) :
            print('Resume', end='\r')
            checkpoint = torch.load(checkpoint_path, map_location=device)
            start_iter = checkpoint['iter']
            best_acc1 = checkpoint['best_acc1']
            best_model = checkpoint['best_model']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['scheduler'])

        cur_iter = start_iter
        while cur_iter < args.max_iter :
            # train() returns the updated iteration counter and scheduler (see train_validate).
            cur_iter, lr_scheduler = train(train_loader, model, criterion, optimizer, lr_scheduler, device, cur_iter, output_device=output_device)
            val_loss, val_acc1, val_acc5 = validate(val_loader, model, criterion, device, output_device=output_device)
            train_loss, train_acc1, train_acc5 = validate(train_loader, model, criterion, device, output_device=output_device)

            t = time.time()
            is_best = val_acc1 > best_acc1
            best_acc1 = max(val_acc1, best_acc1)

            if is_best :
                # state_dict() returns references to the live parameter tensors; without a deep
                # copy, best_model would silently follow the latest weights instead of the best.
                best_model = copy.deepcopy(model.state_dict())
                elapsed_min = (t-start_time)/60
                print(
                    'Updated | Iter {}/{} | {}% | {} min | {} min left | Train loss {} | top1 {} | top5 {} | val loss {} | top1 {} | top5 {}'.format(
                        cur_iter, args.max_iter, round(100*(cur_iter+1)/args.max_iter), round(elapsed_min), round(elapsed_min*((args.max_iter-cur_iter-1)/(cur_iter+1))),
                        round(train_loss, 3), round(train_acc1.item(), 3), round(train_acc5.item(), 3),
                        round(val_loss, 3), round(val_acc1.item(), 3), round(val_acc5.item(), 3)
                        ) + ' '*10, end='\r'
                    )

            save_checkpoint(
                {
                    'iter' : cur_iter + 1,
                    'time' : t,
                    'state_dict' : model.state_dict(),
                    'best_model' : best_model,
                    'best_acc1' : best_acc1,
                    'optimizer' : optimizer.state_dict(),
                    'scheduler' : lr_scheduler.state_dict(),
                }, is_best, checkpoint_path
            )

        if best_model is None :
            # No validation pass ever improved on best_acc1 (and none was restored from a
            # checkpoint); fall back to the final weights instead of crashing in load_state_dict.
            best_model = copy.deepcopy(model.state_dict())

        if save :
            torch.save(best_model, '{}.pth.tar'.format(save_path + file_name))
        model.load_state_dict(best_model)
        test_loss, test_acc1, test_acc5 = validate(test_loader, model, criterion, device, output_device=output_device)

        print("{} | {} | {} | Test |  acc1 {} | acc5 {}".format(args.model, trial, args.activation, test_acc1, test_acc5), end = '\n')




            