import argparse
import collections
import math
import time
import os

import numpy as np
import scipy.io as sio
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as Data
import torch.optim as optim
from sklearn import metrics, preprocessing
from sklearn.decomposition import PCA
from sklearn.metrics import confusion_matrix, accuracy_score


from torchsummary import summary
import torch_optimizer as optim2
import wandb


import geniter
import record
import Utils
from dataset import *
from  model import *
from train_helper import *
import data_utils as du


# Setting Params

def make_args_parser():
    """Build the command-line parser for the HSI training script.

    The options are declared in a table (flags + keyword spec) and registered
    in order, so the generated ``--help`` output lists them in the same
    sequence as they appear below.

    Returns:
        argparse.ArgumentParser: parser exposing dataset/output selection,
        optimizer and schedule settings, patch/kernel geometry, and the
        sampling parameters used when drawing train/test pixels.
    """
    option_table = (
        # Dataset and output location.
        (('-d', '--dataset'),
         dict(dest='dataset', default='pavia_centra', help="Name of dataset.")),
        (('-s', '--save_data_dir'),
         dict(dest='save_data_dir', default='./res_data/',
              help="the directory to save results and models.")),
        # Optimisation schedule.
        (('-o', '--optimizer'),
         dict(dest='optimizer', default='adam', help="Name of optimizer.")),
        (('-e', '--epoch'),
         dict(type=int, dest='epoch', default=200, help="No of epoch")),
        (('-i', '--iter'),
         dict(type=int, dest='iter', default=3, help="No of iter")),
        # Patch / kernel geometry.
        (('-p', '--patch'),
         dict(type=int, dest='patch', default=4, help="Length of patch")),
        (('-k', '--kernel'),
         dict(type=int, dest='kernel', default=24, help="Length of kernel")),
        (('-vs', '--valid_split'),
         dict(type=float, dest='valid_split', default=0.9,
              help="Percentage of validation split.")),
        # Long-only options.
        (('--lr',),
         dict(type=float, default=0.001, help="learning rate")),
        (('--batch_size',),
         dict(type=int, default=256, help="batch size")),
        (('--epoch_save',),
         dict(type=int, default=10, help="model save frequency")),
        (('--sample_size',),
         dict(type=int, default=1500, help="sample_size")),
        (('--test_prop',),
         dict(type=float, default=0.01,
              help="test dataset tratified sampling proprotion")),
        (('--num_band',),
         dict(type=int, default=102, help="num_band")),
    )

    cli = argparse.ArgumentParser(description='Training for HSI')
    for flags, spec in option_table:
        cli.add_argument(*flags, **spec)
    return cli


 def main(args):
    # 1. load data
    op = load_rs_dataset(num_band = args.num_band, dataset = args.dataset)
    train_msi, train_hsi, train_labels, test_msi, test_hsi, test_labels, classes, num_class = op

    args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Apple PCA
    pca, train_hsi_pca, test_hsi_pca = apply_pca(train_data = train_hsi, test_data = test_hsi)
    # Data Standardization
    '''
    train_hsi_norm: shape (1096, 715, 102)
    test_hsi_norm:  shape (8, 128, 128, 102)
    '''
    scaler, train_hsi_norm, test_hsi_norm = scale_data(train_data = train_hsi_pca, test_data = test_hsi_pca)

    B, h, w, C = test_hsi_norm.shape
    test_hsi_norm = test_hsi_norm.reshape(-1, w, C)
    test_labels = test_labels.reshape(-1, w)


    H, W, C = train_hsi_norm.shape

    img_rows = 2 * args.patch + 1
    img_cols = 2 * args.patch + 1



    net = S3KAIResNet(
        band = C, classes = num_class, reduction = 2, PARAM_KERNEL_SIZE = args.kernel).cuda()

    # summary(net, input_data=(1, img_rows, img_cols, C), verbose=1)


    optimizer = du.get_optim(net, opt = args.optimizer, lr = args.lr)
    time_1 = int(time.time())
    seeds = [1331, 1332, 1333, 1334, 1335, 1336, 1337, 1338, 1339, 1340, 1341]
    np.random.seed(seeds[0])




    print("Sample training points")
    train_label2idx = du.get_label2idx(labels = train_labels, 
                                    save_dir = args.save_data_dir, 
                                    sample_size = args.sample_size, data_flag = "train")
    train_indices = du.get_idxs_from_label2idx(train_label2idx)
    print("Sample testing points")
    test_label2idx = du.get_label2idx(labels = test_labels, 
                                    save_dir = args.save_data_dir, 
                                    sample_size = int(args.sample_size/4), data_flag = "test")
    test_sample_indices_1 = du.get_idxs_from_label2idx(test_label2idx)


    test_label2strafiedidx = du.get_stratified_sample_idx(labels = test_labels, save_dir = args.save_data_dir, 
                              prop = args.test_prop, 
                              data_flag = "test", 
                              classes = classes)
    test_sample_indices_2 = du.get_idxs_from_label2idx(test_label2strafiedidx)
    test_sample_indices = np.array(list(test_sample_indices_1) + list(test_sample_indices_2))

    # test_indices: all test image pixel ID's
    test_indices = np.arange(np.prod(test_hsi_norm.shape[:2]))


    train_dataset, train_loader = geniter.get_dataset_loader(data_indices = train_indices, 
                                                     whole_data = train_hsi_norm, 
                                                     patch_length = args.patch, 
                                                     labels = train_labels.reshape(-1), 
                                                     batch_size = args.batch_size, 
                                                     do_shuffle = True)

    test_sample_dataset, test_sample_loader = geniter.get_dataset_loader(data_indices = test_sample_indices, 
                                                     whole_data = test_hsi_norm, 
                                                     patch_length = args.patch, 
                                                     labels = test_labels.reshape(-1), 
                                                     batch_size = args.batch_size, 
                                                     do_shuffle = False)


    # The whole test image
    test_dataset, test_loader = geniter.get_dataset_loader(data_indices = test_indices, 
                                                     whole_data = test_hsi_norm, 
                                                     patch_length = args.patch, 
                                                     labels = test_labels.reshape(-1), 
                                                     batch_size = args.batch_size, 
                                                     do_shuffle = False)

    save_path = du.get_save_path(args)

    if not os.path.exists(save_path):
        os.mkdir(save_path)
    logger = du.setup_logging(log_file = os.path.join(save_path, "log.txt"), 
                                    console=True, filemode='a')

    wandb.init(project=os.path.basename(save_path), entity="specdecode")
    wandb.config = args


    # TRAIN MODEL
    tic1 = time.time()
    train_new(args = args,
                net = net,
                train_iter = train_loader,
                valida_iter = test_sample_loader,
                optimizer = optimizer,
                device = args.device,
                epochs = args.epoch,
                logger = logger,
                save_path = save_path,
                early_stopping=True,
                early_num=20)
    toc1 = time.time()

    # evaluate
    net, optimizer = load_best_model(save_path, net, optimizer)

    test_acc, test_loss, test_preds, test_gt = record.evaluate_accuracy_new(data_iter = test_loader, 
                             net = net, 
                             loss = torch.nn.CrossEntropyLoss(),
                             device = args.device)

    logger.info(f"Full test evalue: {test_acc}")



if __name__ == '__main__':
    # Script entry point: parse the command line and run the training pipeline.
    parser = make_args_parser()
    args = parser.parse_args()

    # The original called the undefined name `mean`, which raised NameError.
    main(args)

