import argparse
import os
import numpy as np
import pandas as pd
import random

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data

from collections import OrderedDict as OD
from LayerAct import LA_HardSiLU, LA_SiLU
import data_augmentation 
from train_validate import validate, validate_10crop

from ResNet import resnet18, resnet50, resnet101 
from ResNet_small import resnet20, resnet32, resnet44


def resnet_set(name):
    """Return the ResNet constructor registered under *name*.

    ``resnet18``/``resnet50``/``resnet101`` come from the ``ResNet`` module and
    ``resnet20``/``resnet32``/``resnet44`` from ``ResNet_small``. The
    constructor itself is returned (not called).

    Raises:
        ValueError: if *name* is not one of the supported architectures.
            Previously an unknown name silently returned ``None``, which only
            failed later with an opaque "NoneType is not callable" error.
    """
    if name == 'resnet18':
        return resnet18
    if name == 'resnet50':
        return resnet50
    if name == 'resnet101':
        return resnet101
    if name == 'resnet20':
        return resnet20
    if name == 'resnet32':
        return resnet32
    if name == 'resnet44':
        return resnet44
    raise ValueError(
        'Unknown model {!r}; expected one of resnet18, resnet50, resnet101, '
        'resnet20, resnet32, resnet44'.format(name)
    )

def activation_set(name):
    """Return the activation class/constructor registered under *name*.

    Supported names: ``relu``, ``leakyrelu``, ``prelu``, ``mish``, ``silu``,
    ``hardsilu`` (mapped to ``nn.Hardswish``), ``la_silu`` and
    ``la_hardsilu`` (the LayerAct variants). The class is returned
    uninstantiated so the model can build one instance per layer.

    Raises:
        ValueError: if *name* is not a supported activation. Previously an
            unknown name silently returned ``None``.
    """
    if name == 'relu':
        return nn.ReLU
    if name == 'leakyrelu':
        return nn.LeakyReLU
    if name == 'prelu':
        return nn.PReLU
    if name == 'mish':
        return nn.Mish
    if name == 'silu':
        return nn.SiLU
    if name == 'hardsilu':
        return nn.Hardswish
    if name == 'la_silu':
        return LA_SiLU
    if name == 'la_hardsilu':
        return LA_HardSiLU
    raise ValueError(
        'Unknown activation {!r}; expected one of relu, leakyrelu, prelu, mish, '
        'silu, hardsilu, la_silu, la_hardsilu'.format(name)
    )

def model_loader(model_name, activation, activation_params, rs, out_num):
    """Build the ResNet architecture named *model_name*.

    ``activation`` and ``activation_params`` are handed to the constructor
    positionally; ``rs`` is passed as the ``rs`` keyword and ``out_num`` as
    ``num_classes``.
    """
    build = resnet_set(model_name)
    return build(activation, activation_params, rs=rs, num_classes=out_num)

def folder_check(path, data_name, model_name):
    """Ensure the directory ``<path>/<data_name>/<model_name>/`` exists and return it.

    The returned string always ends with a path separator, so callers can
    append a file name directly (``folder_check(...) + 'name.pth.tar'``).

    Compared with a ``listdir``-then-``makedirs`` check, this works when *path*
    itself does not exist yet or lacks a trailing separator. It also does not
    fail if the directory appears between the check and the creation.
    """
    # The trailing '' component makes os.path.join end the result with a separator.
    target = os.path.join(path, data_name, model_name, '')
    os.makedirs(target, exist_ok=True)
    return target
    
# Per-trial RNG seeds: trial k (1-based) uses random_seed[k - 1], i.e. 11, 22, ..., 220.
random_seed = list(range(11, 11 * 20 + 1, 11))

#######################################################################################################

def _str2bool(value):
    """Parse a boolean command-line value such as ``True``/``false``/``1``/``no``.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is ``True`` for
    every non-empty string (including ``'False'``), so such flags could never
    be switched off from the command line.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


def _parse_int_pair(text):
    """Parse ``"a,b"`` into ``(int(a), int(b))``; extra comma-separated fields are ignored."""
    parts = text.split(',')
    return int(parts[0]), int(parts[1])


def _parse_noise_params(noise, raw_param1, raw_param2):
    """Convert the raw ``--noise_param1/2`` strings into the values passed to the loaders.

    ``'gaussian'`` yields two floats and ``'blur'`` yields two ``(int, int)``
    pairs. Any other noise setting yields ``(0, 0)``.
    """
    if noise == 'gaussian':
        return float(raw_param1), float(raw_param2)
    if noise == 'blur':
        return _parse_int_pair(raw_param1), _parse_int_pair(raw_param2)
    return 0, 0


def _strip_module_prefix(state_dict):
    """Return a copy of *state_dict* with leading ``module.`` prefixes removed from keys.

    Checkpoints saved from an ``nn.DataParallel``-wrapped model store
    parameters as ``module.<name>``. Only the leading prefix (possibly
    repeated) is stripped, so parameter names that merely contain
    ``module.`` further in are left intact.
    """
    prefix = 'module.'
    stripped = OD()
    for key, value in state_dict.items():
        while key.startswith(prefix):
            key = key[len(prefix):]
        stripped[key] = value
    return stripped


def _load_test_data(args, param1, param2, rs):
    """Build the data loaders for ``args.data`` and return ``(test_loader, num_classes)``.

    Only the test loader is used by this script; the train/validation loaders
    returned by ``data_augmentation`` are discarded.

    Raises:
        ValueError: if ``args.data`` is not a supported dataset.
    """
    common = (args.data_path, args.noise, param1, param2, args.batch_size, args.num_workers, rs)
    if args.data == 'CIFAR10':
        _, _, test_loader = data_augmentation.load_CIFAR10(*common)
        return test_loader, 10
    if args.data == 'CIFAR100':
        _, _, test_loader = data_augmentation.load_CIFAR100(*common)
        return test_loader, 100
    if args.data == 'ImageNet':
        # NOTE(review): this previously called load_CIFAR100 (a copy-paste slip:
        # wrong dataset, and the extra crop argument does not match the
        # 7-argument CIFAR calls). Assumes data_augmentation exposes
        # load_ImageNet taking the crop mode last -- confirm against that module.
        _, _, test_loader = data_augmentation.load_ImageNet(*common, args.crop)
        return test_loader, 1000
    raise ValueError('Dataset should be "CIFAR10", "CIFAR100", or "ImageNet"')


def main():
    """Evaluate saved checkpoints for every requested activation and trial.

    For each activation in ``--activations`` and each trial in
    ``[--start_trial, --end_trial]`` this function:

    1. re-seeds ``random``, NumPy and torch with ``random_seed[trial - 1]``;
    2. builds the (optionally noisy) test loader;
    3. loads ``<model_path>/<data>/<model>/<activation>_<trial>.pth.tar``;
    4. prints the top-1 / top-5 test accuracy.
    """
    parser = argparse.ArgumentParser(
        description='Evaluate trained ResNet checkpoints on a (possibly noisy) test set.'
    )

    parser.add_argument('--data', '-d', default='CIFAR10', choices=['CIFAR10', 'CIFAR100', 'ImageNet'])
    parser.add_argument('--model', '-m', default='resnet20')
    parser.add_argument('--activations', '-a', default='relu,leakyrelu,prelu,mish,silu,hardsilu,la_silu,la_hardsilu')
    parser.add_argument('--noise', '-n', default='None')
    parser.add_argument('--noise_param1', '-np1', default='')
    parser.add_argument('--noise_param2', '-np2', default='')

    parser.add_argument('--device', default=0, type=int)
    parser.add_argument('--crop', default='center')
    parser.add_argument('--start_trial', default=1, type=int)
    parser.add_argument('--end_trial', default=5, type=int)

    # Without explicit types, user-supplied values would stay strings.
    parser.add_argument('--alpha', default=1e-5, type=float)
    parser.add_argument('--batch_size', '-bs', default=128, type=int)
    parser.add_argument('--num_workers', '-nw', default=16, type=int)

    parser.add_argument('--data_path', '-dp', default='')
    parser.add_argument('--model_path', default='trained_models/')
    parser.add_argument('--save_path', default='result/')

    parser.add_argument('--resume', default=True, type=_str2bool)
    parser.add_argument('--duplicate', default=True, type=_str2bool)
    parser.add_argument('--save', default=True, type=_str2bool)

    args = parser.parse_args()

    # random_seed is indexed by trial - 1: trial 0 would silently wrap to the
    # last seed and trials past the list would raise IndexError mid-run.
    if args.start_trial < 1 or args.end_trial > len(random_seed):
        parser.error('--start_trial/--end_trial must lie within 1..{}'.format(len(random_seed)))

    activation_list = [name.strip() for name in args.activations.split(',') if name.strip()]

    if torch.cuda.is_available():
        device = torch.device('cuda:{}'.format(args.device))
    else:
        device = torch.device('cpu')
    model_path = folder_check(args.model_path, args.data, args.model)
    # Only called to make sure the result directory exists; nothing is written there here.
    folder_check(args.save_path, args.data, args.model)

    param1, param2 = _parse_noise_params(args.noise, args.noise_param1, args.noise_param2)

    # Deterministic cuDNN kernels so repeated evaluations give identical numbers.
    cudnn.deterministic = True
    cudnn.benchmark = False

    for activation_name in activation_list:
        activation = activation_set(activation_name)
        # parameter alpha of LayerAct functions for stable training
        activation_params = {'alpha': args.alpha} if 'la_' in activation_name else {}

        for trial in range(args.start_trial, args.end_trial + 1):
            rs = random_seed[trial - 1]
            random.seed(rs)
            np.random.seed(rs)
            torch.manual_seed(rs)

            test_loader, out_num = _load_test_data(args, param1, param2, rs)

            model = model_loader(args.model, activation, activation_params, rs, out_num)
            model.to(device)
            criterion = nn.CrossEntropyLoss().to(device)

            file_name = '{}_{}'.format(activation_name, trial)
            trained = torch.load(model_path + file_name + '.pth.tar', map_location=device)
            try:
                model.load_state_dict(trained)
            except RuntimeError:
                # Key mismatch: retry assuming the checkpoint came from a
                # DataParallel-wrapped model with 'module.'-prefixed keys.
                model.load_state_dict(_strip_module_prefix(trained))

            if args.crop == '10crop':
                _, test_acc1, test_acc5 = validate_10crop(test_loader, model, criterion, device)
            else:
                _, test_acc1, test_acc5 = validate(test_loader, model, criterion, device)
            print("{} | {} | {} | Test |  acc1 {} | acc5 {}".format(args.model, trial, activation_name, test_acc1, test_acc5))


if __name__ == '__main__':
    main()