import argparse
import collections
import math
import time

import numpy as np
import scipy.io as sio
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn import metrics, preprocessing
from sklearn.decomposition import PCA
from sklearn.metrics import confusion_matrix

from torchsummary import summary
import torch_optimizer as optim2

import data_utils as du


def load_dataset(Dataset, split=0.9):
    """Load a hyperspectral benchmark scene and PCA-transform its spectra.

    Args:
        Dataset: dataset code, one of 'IN', 'UP', 'SV' or 'KSC'.
        split: returned unchanged as VALIDATION_SPLIT; TRAIN_SIZE is
            ceil(TOTAL_SIZE * split).

    Returns:
        (data_hsi, gt_hsi, TOTAL_SIZE, TRAIN_SIZE, VALIDATION_SPLIT) where
        data_hsi is the cube with its last (spectral) axis replaced by K PCA
        components, and gt_hsi is the label array read from the ground-truth
        .mat file.

    Raises:
        ValueError: if Dataset is not one of the supported codes.
    """
    data_path = '../dataset/'
    # code -> (data file, gt file, data key, gt key, K, TOTAL_SIZE).
    # K is the number of PCA components; None means "use every band".
    specs = {
        'IN': ('Indian_pines_corrected.mat', 'Indian_pines_gt.mat',
               'indian_pines_corrected', 'indian_pines_gt', 200, 10249),
        'UP': ('PaviaU.mat', 'PaviaU_gt.mat',
               'paviaU', 'paviaU_gt', 103, 42776),
        'SV': ('Salinas_corrected.mat', 'Salinas_gt.mat',
               'salinas_corrected', 'salinas_gt', 15, 54129),
        'KSC': ('KSC.mat', 'KSC_gt.mat',
                'KSC', 'KSC_gt', None, 5211),
    }
    if Dataset not in specs:
        # Previously an unknown code crashed later with UnboundLocalError.
        raise ValueError(
            f"Unknown dataset {Dataset!r}; expected one of {sorted(specs)}")
    data_file, gt_file, data_key, gt_key, K, TOTAL_SIZE = specs[Dataset]

    data_hsi = sio.loadmat(data_path + data_file)[data_key]
    gt_hsi = sio.loadmat(data_path + gt_file)[gt_key]
    if K is None:
        K = data_hsi.shape[2]
    VALIDATION_SPLIT = split
    TRAIN_SIZE = math.ceil(TOTAL_SIZE * VALIDATION_SPLIT)

    # PCA needs a 2-D (samples, features) matrix: flatten every leading axis
    # into rows, transform, then restore the leading axes with K channels.
    flat = data_hsi.reshape(-1, data_hsi.shape[-1])
    flat = PCA(n_components=K).fit_transform(flat)
    data_hsi = flat.reshape(data_hsi.shape[:-1] + (K,))

    return data_hsi, gt_hsi, TOTAL_SIZE, TRAIN_SIZE, VALIDATION_SPLIT


def apply_pca(train_data, test_data):
    '''
    Fit a PCA on the spectral (last) axis of the training data and project
    the test data with the same fitted model.

    All leading axes are treated as pixel axes: each array is viewed as
    (num_pixels, C) for fitting/transforming and reshaped back afterwards,
    so both outputs keep their input shapes. The PCA keeps all C components.

    Return:
        pca: fitted PCA model
        train_data_pca: same shape as train_data, e.g. (1096, 715, 102)
        test_data_pca: same shape as test_data, e.g. (8, 128, 128, 102)
    '''
    assert train_data.shape[-1] == test_data.shape[-1]
    n_bands = train_data.shape[-1]

    pca = PCA(n_components=n_bands)
    projected_train = pca.fit_transform(train_data.reshape(-1, n_bands))
    projected_test = pca.transform(test_data.reshape(-1, n_bands))
    return (
        pca,
        projected_train.reshape(train_data.shape),
        projected_test.reshape(test_data.shape),
    )


def scale_data(train_data, test_data):
    '''
    Standardise each channel of the last axis using mean/std fitted on the
    training data only, then apply the same transform to the test data.

    Return:
        scaler: StandardScaler fitted on the training pixels
        train_data_norm: same shape as train_data
        test_data_norm: same shape as test_data
    '''
    assert train_data.shape[-1] == test_data.shape[-1]
    n_bands = train_data.shape[-1]

    def as_rows(arr):
        # (..., C) -> (num_pixels, C) so sklearn sees one sample per row
        return arr.reshape(-1, n_bands)

    scaler = preprocessing.StandardScaler().fit(as_rows(train_data))
    train_data_norm = scaler.transform(as_rows(train_data)).reshape(train_data.shape)
    test_data_norm = scaler.transform(as_rows(test_data)).reshape(test_data.shape)
    return scaler, train_data_norm, test_data_norm



def load_rs_dataset(num_band=102, dataset="pavia_centra", num_test=8):
    '''
    Load the pre-split Pavia Centre MSI / HSI / ground-truth arrays.

    Args:
        num_band: number of HSI bands. 102 loads the untagged HSI files; any
            other value loads the files tagged ``_band{num_band}``.
        dataset: dataset name; only "pavia_centra" is supported.
        num_test: number of test tiles, stored as ``..._test_{i}`` files for
            i in range(num_test). Defaults to 8.

    Return:
        train_msi: shape (H, W, c) -> (1096, 715, 4)
        train_hsi: shape (H, W, C) -> (1096, 715, 102)
        train_labels: shape (H, W) -> (1096, 715)

        test_msi: shape (num_test, h, w, c) -> (8, 128, 128, 4)
        test_hsi: shape (num_test, h, w, C) -> (8, 128, 128, 102)
        test_labels: shape (num_test, h, w) -> (8, 128, 128); index 0 of
            the last axis of each stacked test GT tile
        classes: sorted unique values of train_labels, excluding -1
        num_class: len(classes)

    Raises:
        NotImplementedError: for any dataset other than "pavia_centra".
    '''
    if dataset != "pavia_centra":
        raise NotImplementedError

    band_tag = "" if num_band == 102 else f"_band{num_band}"

    print("Load data")
    data_dir = "../../../dataset_preprocess/dataset/Pavia_Centre"
    train_hsi_dir = f"{data_dir}/train/HSI{band_tag}/"
    train_msi_dir = f"{data_dir}/train/MSI/"
    train_gt_dir = f"{data_dir}/train/GT/"

    test_hsi_dir = f"{data_dir}/test/HSI{band_tag}/"
    test_msi_dir = f"{data_dir}/test/MSI/"
    test_gt_dir = f"{data_dir}/test/GT/"

    train_msi = du.load_np_file(train_msi_dir, "pavia_centre-msi_train.npy")
    train_hsi = du.load_np_file(train_hsi_dir, f"pavia_centre-hsi_train{band_tag}.npy")
    train_labels = du.load_np_file(train_gt_dir, "pavia_centre-gt_train.npy")

    def _load_test_tiles(directory, file_name):
        # Load the tiles named file_name(i), i < num_test, stacked on axis 0.
        return np.stack([du.load_np_file(directory, file_name(i))
                         for i in range(num_test)])

    # Keep only index 0 of the last axis of the stacked GT tiles.
    test_labels = _load_test_tiles(
        test_gt_dir, lambda i: f"pavia_centre-gt_test_{i}.npy")[:, :, :, 0]
    test_hsi = _load_test_tiles(
        test_hsi_dir, lambda i: f"pavia_centre-hsi_test_{i}{band_tag}.npy")
    test_msi = _load_test_tiles(
        test_msi_dir, lambda i: f"pavia_centre-msi_test_{i}.npy")

    # np.unique already returns sorted values; drop the -1 label.
    label_values = np.unique(train_labels)
    classes = label_values[label_values != -1]
    num_class = len(classes)

    return train_msi, train_hsi, train_labels, \
        test_msi, test_hsi, test_labels, classes, num_class

def sampling(proportion, ground_truth):
    '''
    Randomly split labelled pixels into train / test index lists, per class.

    For every label c in 1..max(ground_truth), the flat indices of pixels
    labelled c are shuffled. The first max(int((1 - proportion) * n_c), 3)
    go to train and the rest go to test. When proportion == 1, every index
    goes to test. Pixels with labels outside 1..max (e.g. 0) are never
    selected.

    Args:
        proportion: float, e.g. 0.9. This is the fraction of each class kept
            for TEST; the training fraction is 1 - proportion.
        ground_truth: label array. It is flattened in C order, so both
            (H * W) and (H, W) inputs are accepted.

    Return:
        train_indexes, test_indexes: shuffled lists of flat pixel indices.
    '''
    flat = np.asarray(ground_truth).ravel()
    # number of classes: labels are taken to run from 1 to the maximum label
    m = int(flat.max())

    train_indexes = []
    test_indexes = []
    for label in range(1, m + 1):
        # Vectorised lookup; .tolist() keeps Python ints and the same
        # list-shuffle RNG consumption as the original comprehension.
        indexes = np.flatnonzero(flat == label).tolist()
        np.random.shuffle(indexes)
        if proportion != 1:
            nb_train = max(int((1 - proportion) * len(indexes)), 3)
        else:
            nb_train = 0
        train_indexes += indexes[:nb_train]
        test_indexes += indexes[nb_train:]
    np.random.shuffle(train_indexes)
    np.random.shuffle(test_indexes)
    return train_indexes, test_indexes


def select(groundTruth, dataset=None, amount=None):
    '''
    Split labelled pixels into train / test index lists using fixed per-class
    TEST counts.

    For every label c in 1..max(groundTruth), the flat indices of pixels
    labelled c are shuffled. The last amount[c - 1] of them go to test and
    the remainder go to train. If a count is at least the class size, the
    whole class goes to test; a count of 0 keeps the whole class in train.

    Args:
        groundTruth: label array, flattened in C order.
        dataset: 'IN', 'UP' or 'KSC'; selects the built-in per-class test
            counts. When None, the module-level global ``Dataset`` is read,
            as the original implementation did; a NameError is raised if
            that global does not exist.
        amount: optional sequence of per-class test counts. When given, it
            takes precedence over ``dataset``.

    Return:
        train_indices, test_indices: shuffled lists of flat pixel indices.

    Raises:
        ValueError: unknown dataset code, or fewer counts than classes.
    '''
    flat = np.asarray(groundTruth).ravel()
    m = int(flat.max())

    if amount is None:
        if dataset is None:
            dataset = Dataset  # legacy behaviour: module-level global
        preset_amounts = {
            'IN': [
                35, 1011, 581, 167, 344, 515, 19, 327, 12, 683, 1700, 418,
                138, 876, 274, 69
            ],
            'UP': [5297, 14974, 1648, 2424, 1076, 4026, 1046, 2950, 755],
            'KSC': [
                530, 165, 176, 170, 110, 161, 80, 299, 377, 283, 296, 341, 654
            ],
        }
        if dataset not in preset_amounts:
            raise ValueError(
                f"No per-class test counts for dataset {dataset!r}; "
                f"expected one of {sorted(preset_amounts)} or pass amount=")
        amount = preset_amounts[dataset]
    if len(amount) < m:
        raise ValueError(
            f"amount has {len(amount)} entries but labels go up to {m}")

    train_indices = []
    test_indices = []
    for label in range(1, m + 1):
        indices = np.flatnonzero(flat == label).tolist()
        np.random.shuffle(indices)
        nb_test = int(amount[label - 1])
        # Split at an explicit position: slicing with -nb_test breaks for 0
        # (indices[:-0] is empty and would send the whole class to test).
        cut = max(len(indices) - nb_test, 0)
        train_indices += indices[:cut]
        test_indices += indices[cut:]
    np.random.shuffle(train_indices)
    np.random.shuffle(test_indices)
    return train_indices, test_indices