#!/usr/bin/env python
# coding: utf-8

# # Imports

import argparse
import collections
import math
import os
import time

import numpy as np
import scipy.io as sio
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch_optimizer as optim2
from sklearn import metrics, preprocessing
from sklearn.decomposition import PCA
from sklearn.metrics import confusion_matrix
from torchsummary import summary

import geniter
import record
import Utils
from dataset import *
from model import *
from train_helper import *


# # Setting Params

def make_args_parser():
    """Build the command-line parser for the HSI training script.

    Returns:
        argparse.ArgumentParser: parser exposing the options
        ``-d/--dataset``, ``-o/--optimizer``, ``-e/--epoch``, ``-i/--iter``,
        ``-p/--patch``, ``-k/--kernel`` and ``-vs/--valid_split``.
    """
    # One (flags, keyword-arguments) entry per option, in registration order.
    option_specs = (
        (('-d', '--dataset'),
         dict(dest='dataset', default='IN', help="Name of dataset.")),
        (('-o', '--optimizer'),
         dict(dest='optimizer', default='adam', help="Name of optimizer.")),
        (('-e', '--epoch'),
         dict(type=int, dest='epoch', default=200, help="No of epoch")),
        (('-i', '--iter'),
         dict(type=int, dest='iter', default=3, help="No of iter")),
        (('-p', '--patch'),
         dict(type=int, dest='patch', default=4, help="Length of patch")),
        (('-k', '--kernel'),
         dict(type=int, dest='kernel', default=24, help="Length of kernel")),
        (('-vs', '--valid_split'),
         dict(type=float, dest='valid_split', default=0.9,
              help="Percentage of validation split.")),
    )

    cli = argparse.ArgumentParser(description='Training for HSI')
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)
    return cli

# Parse the command line once and snapshot every option into a module-level
# constant that the rest of the script reads.
parser = make_args_parser()
args = parser.parse_args()

PARAM_DATASET, PARAM_OPTIM = args.dataset, args.optimizer  # UP,IN,SV, KSC
PARAM_EPOCH, PARAM_ITER = args.epoch, args.iter
PATCH_SIZE, PARAM_KERNEL_SIZE = args.patch, args.kernel
PARAM_VAL = args.valid_split

# # Data Loading
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# for Monte Carlo runs: one NumPy seed per training run (indexed by run number)
seeds = [1331, 1332, 1333, 1334, 1335, 1336, 1337, 1338, 1339, 1340, 1341]
ensemble = 1

# Upper-cased dataset key (UP, IN, SV, KSC) handed to the loader and used in
# the output file names. (A module-level `global` statement has no effect, so
# none is needed here.)
dataset = PARAM_DATASET  # input('Please input the name of Dataset(IN, UP, SV, KSC):')
Dataset = dataset.upper()



# # Pytorch Data Loader Creation

# Load the hyperspectral cube, its ground-truth map and the split sizes.
data_hsi, gt_hsi, TOTAL_SIZE, TRAIN_SIZE, VALIDATION_SPLIT = load_dataset(
    Dataset, PARAM_VAL)
print(data_hsi.shape)
image_x, image_y, BAND = data_hsi.shape
# data: shape (H * W, C)
data = data_hsi.reshape(image_x * image_y, BAND)
# gt: shape (H * W)
gt = gt_hsi.reshape(-1)
# Highest label value, taken as the number of classes. Reduce with NumPy and
# convert to a plain int so array shapes and the model constructor receive a
# Python integer rather than a NumPy scalar of the label dtype.
CLASSES_NUM = int(gt.max())
print('The class numbers of the HSI data is:', CLASSES_NUM)

print('-----Importing Setting Parameters-----')
ITER = PARAM_ITER
PATCH_LENGTH = PATCH_SIZE
# NOTE: only `lr` is read later in this script; `num_epochs` and `batch_size`
# are not referenced again here (the epoch count comes from PARAM_EPOCH).
lr, num_epochs, batch_size = 0.001, 200, 32
loss = torch.nn.CrossEntropyLoss()

H, W, C = data_hsi.shape

# A patch spans PATCH_LENGTH pixels on each side of its centre pixel.
img_rows = 2 * PATCH_LENGTH + 1
img_cols = 2 * PATCH_LENGTH + 1
img_channels = data_hsi.shape[2]
INPUT_DIMENSION = C
ALL_SIZE = H * W
VAL_SIZE = int(TRAIN_SIZE)
TEST_SIZE = TOTAL_SIZE - TRAIN_SIZE

# Per-run metric accumulators, filled by the training loop.
KAPPA = []
OA = []
AA = []
TRAINING_TIME = []
TESTING_TIME = []
ELEMENT_ACC = np.zeros((ITER, CLASSES_NUM))

# Standardize the flattened (H * W, C) matrix, then restore the cube layout.
data = preprocessing.scale(data)
data_ = data.reshape(H, W, C)
whole_data = data_
# do zero padding
# padded_data: shape (H + 2p, W + 2p, C)
# np.pad is the public entry point; the np.lib.pad alias is not available in
# the NumPy 2.x `numpy.lib` namespace.
padded_data = np.pad(
    whole_data, ((PATCH_LENGTH, PATCH_LENGTH), (PATCH_LENGTH, PATCH_LENGTH),
                 (0, 0)),
    'constant',
    constant_values=0)




# Throw-away instance used only to print the architecture summary. Place it on
# the selected `device` instead of calling .cuda() so the script also starts on
# CPU-only machines.
model = S3KAIResNet(
    BAND, CLASSES_NUM, 2, PARAM_KERNEL_SIZE=PARAM_KERNEL_SIZE).to(device)

summary(model, input_data=(1, img_rows, img_cols, BAND), verbose=1)




# # Training

# Every run draws its own entry from `seeds`; check up front so a too-large
# --iter fails immediately instead of with an IndexError after several runs
# have already been trained.
if ITER > len(seeds):
    raise ValueError(
        'Requested %d runs but only %d seeds are defined.' %
        (ITER, len(seeds)))

# A checkpoint is written at the end of every run; make sure the target
# directory exists so a finished run is not lost to a missing folder.
os.makedirs('./models', exist_ok=True)

for index_iter in range(ITER):
    print('iter:', index_iter)
    # Fresh, untrained network for this run.
    # NOTE(review): `net` is never moved to `device` in this script;
    # presumably `train` does that - confirm.
    net = S3KAIResNet(
        band=BAND,
        classes=CLASSES_NUM,
        reduction=2,
        PARAM_KERNEL_SIZE=PARAM_KERNEL_SIZE)

    # Exactly one optimizer must be selected; an unknown name is an error
    # rather than an undefined (or stale) `optimizer`.
    if PARAM_OPTIM == 'diffgrad':
        optimizer = optim2.DiffGrad(
            net.parameters(),
            lr=lr,
            betas=(0.9, 0.999),
            eps=1e-8,
            weight_decay=0)  # weight_decay=0.0001)
    elif PARAM_OPTIM == 'adam':
        # Use the shared `lr` so the rate written into the checkpoint name is
        # the one actually applied.
        optimizer = optim.Adam(
            net.parameters(),
            lr=lr,
            betas=(0.9, 0.999),
            eps=1e-8,
            weight_decay=0)
    else:
        raise ValueError(
            "Unsupported optimizer %r; expected 'adam' or 'diffgrad'." %
            (PARAM_OPTIM, ))

    # Seeds NumPy only, and only after `net` has been constructed.
    np.random.seed(seeds[index_iter])
    # train_indices: a list of pixel id for training pixel
    # test_indices:  a list of pixel id for testing  pixel
    train_indices, test_indices = sampling(VALIDATION_SPLIT, gt)
    _, total_indices = sampling(1, gt)

    TRAIN_SIZE = len(train_indices)
    print('Train size: ', TRAIN_SIZE)
    TEST_SIZE = TOTAL_SIZE - TRAIN_SIZE
    print('Test size: ', TEST_SIZE)
    VAL_SIZE = int(TRAIN_SIZE)
    print('Validation size: ', VAL_SIZE)

    print('-----Selecting Small Pieces from the Original Cube Data-----')
    # The literal 16 is presumably the batch size (the module-level
    # `batch_size` is not used here) - confirm against geniter.generate_iter.
    train_iter, valida_iter, test_iter, all_iter = geniter.generate_iter(
        TRAIN_SIZE, train_indices, TEST_SIZE, test_indices, TOTAL_SIZE,
        total_indices, VAL_SIZE, whole_data, PATCH_LENGTH, padded_data,
        INPUT_DIMENSION, 16, gt)

    tic1 = time.time()
    train(
        net,
        train_iter,
        valida_iter,
        loss,
        optimizer,
        device,
        epochs=PARAM_EPOCH)
    toc1 = time.time()

    # Inference on the test iterator: switch to eval mode once and run a single
    # forward pass per batch, so the measured testing time is not doubled.
    net.eval()
    pred_test = []
    tic2 = time.time()
    with torch.no_grad():
        for X, _ in test_iter:
            y_hat = net(X.to(device))
            pred_test.extend(y_hat.argmax(dim=1).cpu().numpy())
    toc2 = time.time()

    # Ground-truth labels are shifted down by one to line up with the
    # predicted class indices; the last VAL_SIZE test pixels are left out of
    # the comparison.
    gt_test = gt[test_indices] - 1
    gt_eval = gt_test[:-VAL_SIZE]

    # NOTE(review): sklearn documents these as (y_true, y_pred); predictions
    # are passed first here, so the confusion-matrix rows index the predicted
    # class. Confirm that record.aa_and_each_accuracy expects this orientation.
    overall_acc = metrics.accuracy_score(pred_test, gt_eval)
    # Local name chosen so the imported sklearn `confusion_matrix` function is
    # not shadowed.
    conf_mat = metrics.confusion_matrix(pred_test, gt_eval)
    each_acc, average_acc = record.aa_and_each_accuracy(conf_mat)
    kappa = metrics.cohen_kappa_score(pred_test, gt_eval)

    torch.save(
        net.state_dict(), "./models/S3KAIResNetpatch_" + str(img_rows) + '_' +
        Dataset + '_split_' + str(VALIDATION_SPLIT) + '_lr_' + str(lr) +
        PARAM_OPTIM + '_kernel_' + str(PARAM_KERNEL_SIZE) + str(
            round(overall_acc, 3)) + '.pt')
    KAPPA.append(kappa)
    OA.append(overall_acc)
    AA.append(average_acc)
    TRAINING_TIME.append(toc1 - tic1)
    TESTING_TIME.append(toc2 - tic2)
    ELEMENT_ACC[index_iter, :] = each_acc

# # Map, Records
print("--------" + " Training Finished-----------")

# Both outputs are written below fixed folders; create them if missing so the
# results of a completed run are not lost to a missing directory.
os.makedirs('./report', exist_ok=True)
os.makedirs('./classification_maps', exist_ok=True)

# Shared file-name stem that encodes the run configuration.
run_tag = ('S3KAIResNetpatch:' + str(img_rows) + '_' + Dataset + 'split' +
           str(VALIDATION_SPLIT) + 'lr' + str(lr) + PARAM_OPTIM + '_kernel_' +
           str(PARAM_KERNEL_SIZE))

# Text report built from the per-run metrics and timings.
record.record_output(
    OA, AA, KAPPA, ELEMENT_ACC, TRAINING_TIME, TESTING_TIME,
    './report/' + run_tag + '.txt')

# `all_iter`, `net` and `total_indices` are the values left over from the last
# training run of the loop above.
Utils.generate_png(
    all_iter, net, gt_hsi, Dataset, device, total_indices,
    './classification_maps/' + run_tag)
