import argparse
import collections
import math
import time
import os

import numpy as np
import scipy.io as sio
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as Data
import torch.optim as optim
from sklearn import metrics, preprocessing
from sklearn.decomposition import PCA
from sklearn.metrics import confusion_matrix, accuracy_score

from tqdm import tqdm
import re

from torchsummary import summary
import torch_optimizer as optim2
import wandb


import geniter
import record
import Utils
from dataset import *
from  model import *
from train_helper import *
import data_utils as du



# # Setting Params

def make_args_parser():
    '''
    Build the command-line parser for the HSI classification script.

    Returns:
        argparse.ArgumentParser whose parsed namespace exposes the attributes
        dataset, save_data_dir, test_data_path, optimizer, epoch, iter, patch,
        kernel, valid_split, lr, batch_size, epoch_save, sample_size,
        test_prop and num_band.
    '''
    parser = argparse.ArgumentParser(description='Training for HSI')

    # (option strings, add_argument keyword arguments), registered in order so
    # that --help lists the options in this sequence.
    option_specs = [
        (('-d', '--dataset'),
         dict(dest='dataset', default='pavia_centra', help="Name of dataset.")),
        (('-s', '--save_data_dir'),
         dict(dest='save_data_dir', default='./res_data/',
              help="the directory to save results and models.")),
        (('-t', '--test_data_path'),
         dict(dest='test_data_path', default='./res_data/',
              help="the directory to test_data_path")),
        (('-o', '--optimizer'),
         dict(dest='optimizer', default='adam', help="Name of optimizer.")),
        (('-e', '--epoch'),
         dict(type=int, dest='epoch', default=200, help="No of epoch")),
        (('-i', '--iter'),
         dict(type=int, dest='iter', default=3, help="No of iter")),
        (('-p', '--patch'),
         dict(type=int, dest='patch', default=4, help="Length of patch")),
        (('-k', '--kernel'),
         dict(type=int, dest='kernel', default=24, help="Length of kernel")),
        (('-vs', '--valid_split'),
         dict(type=float, dest='valid_split', default=0.9,
              help="Percentage of validation split.")),
        (('--lr',),
         dict(type=float, default=0.001, help="learning rate")),
        (('--batch_size',),
         dict(type=int, default=256, help="batch size")),
        (('--epoch_save',),
         dict(type=int, default=10, help="model save frequency")),
        (('--sample_size',),
         dict(type=int, default=1500, help="sample_size")),
        (('--test_prop',),
         dict(type=float, default=0.01,
              help="test dataset tratified sampling proprotion")),
        (('--num_band',),
         dict(type=int, default=102, help="num_band")),
    ]
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)

    return parser

# Exact name of a saved prediction pickle; raw string so `\d` is a regex class,
# and `\.` so only a literal ".pkl" suffix matches.
_PRED_FILE_PATTERN = re.compile(r'model_pred_scale(?P<scale>\d+)_band(?P<band>\d+)\.pkl')


def get_pred_filepath(model_path, scale = 8, num_band = 102):
    '''
    Find the prediction pickle stored in the same directory as a model checkpoint.

    A file matches only if its whole name is
    ``model_pred_scale<scale>_band<num_band>.pkl`` (integers compared numerically).

    Args:
        model_path: path to a model checkpoint; only its directory is used.
        scale: scale number that must appear in the file name.
        num_band: band count that must appear in the file name.
    Return:
        Path (model directory joined with the file name) of the first match in
        sorted order, or None when no file matches.
    Raises:
        OSError (e.g. FileNotFoundError) if the directory cannot be listed.
    '''
    # A bare file name has an empty dirname; os.listdir('') would fail, so use cwd.
    model_dir = os.path.dirname(model_path) or os.curdir
    # Sorted so the result does not depend on the filesystem's listing order.
    for fname in sorted(os.listdir(model_dir)):
        m = _PRED_FILE_PATTERN.fullmatch(fname)
        if m is None:
            continue
        if int(m.group('scale')) == scale and int(m.group('band')) == num_band:
            return os.path.join(model_dir, fname)
    return None

def transform_data(data, pca, scaler):
    '''
    Apply ``pca.transform`` and then ``scaler.transform`` to every pixel of an image.

    Args:
        data: array of shape (..., C), e.g. (H, W, C). All leading axes are
            treated as pixel positions and flattened to (N, C) before transforming.
        pca: object with a ``transform`` method mapping (N, C) -> (N, K)
            (presumably a fitted sklearn PCA — confirm at call sites).
        scaler: object with a ``transform`` method applied to the PCA output.
    Return:
        data_norm: array with the input's leading shape and a last axis equal
            to the transformed width, e.g. (H, W, K). K == C when the PCA keeps
            every component.
    '''
    leading_shape = data.shape[:-1]
    num_channels = data.shape[-1]

    # (N, C): one row per pixel, as expected by sklearn-style transforms.
    pixels = data.reshape(-1, num_channels)
    pixels_pca = pca.transform(pixels)
    pixels_norm = scaler.transform(pixels_pca)

    # Use the transformed width rather than C so a dimension-reducing PCA
    # does not break the reshape; an explicit width also keeps empty input valid.
    return pixels_norm.reshape(leading_shape + (pixels_norm.shape[-1],))


def get_model_pred(net, test_data, test_labels, pca, scaler, model_save_path,
        patch_length = 5, batch_size = 600, device = "cuda"):
    '''
    Classify every pixel of one image with ``net`` and package the outcome.

    The image is normalized with ``transform_data`` (PCA then scaler). Every
    pixel index is then fed, unshuffled, through ``geniter.get_dataset_loader``
    and scored by ``record.evaluate_accuracy_new`` with a cross-entropy loss.

    Args:
        net: classifier passed to the evaluation routine.
        test_data: shape (H, W, C)
        test_labels: labels for the same pixels; flattened with ``reshape(-1)``.
        pca, scaler: forwarded to ``transform_data``.
        model_save_path: stored unchanged in the returned dict.
        patch_length, batch_size: forwarded to the data loader.
        device: forwarded to the evaluation routine.
    Return:
        dict with keys "pred" and "gt" (both reshaped to the image's (H, W)),
        "acc", "loss" and "model_save_path".
    '''
    normed = transform_data(test_data, pca, scaler)

    # One id per pixel of the whole image.
    pixel_ids = np.arange(np.prod(normed.shape[:2]))

    _, loader = geniter.get_dataset_loader(data_indices = pixel_ids,
                                           whole_data = normed,
                                           patch_length = patch_length,
                                           labels = test_labels.reshape(-1),
                                           batch_size = batch_size,
                                           do_shuffle = False)

    criterion = torch.nn.CrossEntropyLoss()
    acc, loss_value, preds, truth = record.evaluate_accuracy_new(data_iter = loader,
                                                               net = net,
                                                               loss = criterion,
                                                               device = device)

    height, width, _ = normed.shape
    return {
        "pred": preds.reshape(height, width),
        "gt": truth.reshape(height, width),
        "acc": acc,
        'loss': loss_value,
        'model_save_path': model_save_path,
    }



def _fit_to_label_grid(img, h, w):
    '''
    Zero-pad on the bottom/right and crop so that ``img`` is exactly (h, w) spatially.

    Args:
        img: array of shape (B, H, W, C).
        h, w: target spatial size.
    Return:
        Array of shape (B, h, w, C) whose top-left region holds ``img``'s content.
    '''
    _, H, W, _ = img.shape
    # np.pad rejects negative widths, so clamp each axis independently at 0.
    pad_h = max(0, h - H)
    pad_w = max(0, w - W)
    if pad_h or pad_w:
        img = np.pad(img, ((0, 0), (0, pad_h), (0, pad_w), (0, 0)),
                     mode='constant', constant_values=0)
    # No-op unless the prediction is spatially larger than the label grid.
    return img[:, :h, :w, :]


def get_model_pred_batch(net, img_root_dir, num_band, sample_size, test_prop, test_hsi, test_labels,
    pca, scaler, model_save_path,
    model_dict,
    save_dir, scale_list = (2, 3, 4, 8), patch_length = 5, batch_size = 600, device = "cuda"):
    '''
    Classify the ground-truth HSI and every SSIF-RF prediction, pickling each result.

    First runs ``get_model_pred`` on ``test_hsi`` and dumps the result into
    ``save_dir``. Then, for every scale in ``scale_list`` and every SSIF-RF
    model name, it does the following:

    - locates ``model_pred_scale{scale}_band{num_band}.pkl`` next to the
      checkpoint at ``img_root_dir`` joined with ``model_dict[name]``;
    - pads/crops the stored prediction to the label grid;
    - classifies it and dumps the result beside that pickle with an
      ``_a2s2k_pred_cla_{sample_size}_{test_prop}`` suffix.

    Args:
        net: classifier forwarded to ``get_model_pred``.
        img_root_dir: prefix joined with each ``model_dict`` path.
        num_band: selects the prediction file and appears in output names.
        sample_size, test_prop: appear in output file names.
        test_hsi: ground-truth image passed to ``get_model_pred``.
        test_labels: label array. It is reshaped to (B, h, w) using its last
            axis as w (assumes B label maps stacked along axis 0 — confirm with caller).
        pca, scaler, model_save_path, patch_length, batch_size, device:
            forwarded to ``get_model_pred``.
        model_dict: maps model names to checkpoint paths relative to ``img_root_dir``.
        save_dir: directory for the ground-truth result pickle.
        scale_list: iterable of scales to evaluate.
    Raises:
        FileNotFoundError: if a model has no prediction pickle for a scale.
        KeyError: if an SSIF-RF model name is missing from ``model_dict``.
    '''
    # 1. Reference result: classify the ground-truth HSI itself.
    gt_pred = get_model_pred(net,
            test_data = test_hsi,
            test_labels = test_labels,
            pca = pca, scaler = scaler,
            model_save_path = model_save_path,
            patch_length = patch_length,
            batch_size = batch_size, device = device)

    gt_pred_file = os.path.join(save_dir, f"test_gt_a2s2k_pred_cla_band{num_band}_{sample_size}_{test_prop}.pkl")
    du.pickle_dump(gt_pred, gt_pred_file)
    print(f"Save HSI predition to {gt_pred_file}: Acc {gt_pred['acc']}")

    # 2. SSIF-RF variants: response function x sampling type for each band-embedding dim.
    eval_model_list = [
        "SSIF-RF-{}-{}-{}-D".format(resp_f[0].upper(), samp_t[0].upper(), dim)
        for dim in (4, 32, 52, 64)
        for resp_f, samp_t in (('uniform', 'fix'), ('gaussian', 'fix'),
                               ('gaussian', 'gaussian'), ('uniform', 'uniform'))
    ]

    for scale in tqdm(scale_list):
        for model in eval_model_list:
            model_path = os.path.join(img_root_dir, model_dict[model])
            pred_file = get_pred_filepath(model_path = model_path,
                            scale = scale, num_band = num_band)
            if pred_file is None:
                raise FileNotFoundError(
                    f"No model_pred_scale{scale}_band{num_band}.pkl next to "
                    f"{model_path} (model {model})")
            pred_dict = du.pickle_load(pred_file)
            # Stored as (B, C, H, W); move channels last -> (B, H, W, C).
            img = np.transpose(pred_dict['pred'], (0, 2, 3, 1))
            B, _, _, C = img.shape

            _, h, w = test_labels.reshape(B, -1, test_labels.shape[-1]).shape
            # Align with the label grid, then stack the B tiles vertically: (B*h, w, C).
            img = _fit_to_label_grid(img, h, w).reshape(-1, w, C)

            model_pred = get_model_pred(net,
                test_data = img,
                test_labels = test_labels,
                pca = pca, scaler = scaler,
                model_save_path = model_save_path,
                patch_length = patch_length,
                batch_size = batch_size, device = device)

            # Only the extension is rewritten, never a ".pkl" inside the directory part.
            pred_root, pred_ext = os.path.splitext(pred_file)
            model_pred_file = f"{pred_root}_a2s2k_pred_cla_{sample_size}_{test_prop}{pred_ext}"
            du.pickle_dump(model_pred, model_pred_file)
            print(f"Scale: {scale}  model: {model} Save predition to {model_pred_file}, Acc: {model_pred['acc']}")





def main():
    '''
    Script entry point: load data, restore the classifier, and score every prediction.

    Steps:
      1. Build the model-name -> checkpoint-path table.
      2. Parse the CLI, load the dataset, and fit PCA + scaling.
      3. Restore the best S3KAIResNet.
      4. Run ``get_model_pred_batch``.
    '''
    model_dict = {
        "LIIF_C13": "./save/pavia_centra/train_rdn-liif/PAVIA_CENTRA-TSM8-MD-BANDMIN13-MAX13-SAM-liif-rdn-256-mlp-H512_512_512_512-LR0.000100-L1/epoch-best.pth",
        "LIIF_C26": "./save/pavia_centra/train_rdn-liif/PAVIA_CENTRA-TSM8-MD-BANDMIN26-MAX26-SAM-liif-rdn-256-mlp-H512_512_512_512-LR0.000100-L1/epoch-best.pth",
        "LIIF_C51": "./save/pavia_centra/train_rdn-liif/PAVIA_CENTRA-TSM8-MD-BANDMIN51-MAX51-SAM-liif-rdn-256-mlp-H512_512_512_512-LR0.000100-L1/epoch-best.pth",
        "LIIF_C102": "./save/pavia_centra/train_rdn-liif/PAVIA_CENTRA-TSM8-MD-liif-rdn-256-mlp-H512_512_512_512-LR0.000100-L1/epoch-best.pth",
        "SSIF-M_C13": "./save/pavia_centra/train_rdn-liif/PAVIA_CENTRA-TSM8-MD-BANDMIN13-MAX13-SAM-liif-rdn-256-banddec-mlp-H512_512_512-512-bandposenc-band_m_mlp-32-1.00-0.000100-H512-LR0.000100-L1/epoch-best.pth",
        "SSIF-M_C26": "./save/pavia_centra/train_rdn-liif/PAVIA_CENTRA-TSM8-MD-BANDMIN26-MAX26-SAM-liif-rdn-256-banddec-mlp-H512_512_512-512-bandposenc-band_m_mlp-32-1.00-0.000100-H512-LR0.000100-L1/epoch-best.pth",
        "SSIF-M_C51": "./save/pavia_centra/train_rdn-liif/PAVIA_CENTRA-TSM8-MD-BANDMIN51-MAX51-SAM-liif-rdn-256-banddec-mlp-H512_512_512-512-bandposenc-band_m_mlp-32-1.00-0.000100-H512-LR0.000100-L1/epoch-best.pth",
        "SSIF-M_C102": "./save/pavia_centra/train_rdn-liif/PAVIA_CENTRA-TSM8-MD-BANDMIN102-MAX102-SAM-liif-rdn-256-banddec-mlp-H512_512_512-512-bandposenc-band_m_mlp-32-1.00-0.000100-H512-LR0.000100-L1/epoch-best.pth",
        "SSIF-M": "./save/pavia_centra/train_rdn-liif/PAVIA_CENTRA-TSM8-MD-BANDMIN13-MAX102-SAM13-liif-rdn-256-banddec-mlp-H512_512_512-512-bandposenc-band_m_mlp-32-1.00-0.000100-H512-LR0.000100-L1/epoch-best.pth",
        "SSIF-Mc": "./save/pavia_centra/train_rdn-liif/PAVIA_CENTRA-TSM8-MD-BANDMIN13-MAX102-SAM13-liif-rdn-256-banddec-mlp-H512_512_512-512-bandposenc-band_mc1_mlp-32-1.00-0.000100-H512-LR0.000100-L1/epoch-best.pth",
        "SSIF-MC": "./save/pavia_centra/train_rdn-liif/PAVIA_CENTRA-TSM8-MD-BANDMIN13-MAX102-SAM13-liif-rdn-256-banddec-mlp-H512_512_512-512-bandposenc-band_mc_mlp-32-1.00-0.000100-H512-LR0.000100-L1/epoch-best.pth",
        "SSIF-SE":  "./save/pavia_centra/train_rdn-liif/PAVIA_CENTRA-TSM8-MD-BANDMIN13-MAX102-SAM13-liif-rdn-256-banddec-mlp-H512_512_512-512-bandposenc-band_se_mlp-32-1.00-0.000100-H512-LR0.000100-L1/epoch-best.pth",
        "SSIF-SME": "./save/pavia_centra/train_rdn-liif/PAVIA_CENTRA-TSM8-MD-BANDMIN13-MAX102-SAM13-liif-rdn-256-banddec-mlp-H512_512_512-512-bandposenc-band_sme_mlp-32-1.00-0.000100-H512-LR0.000100-L1/epoch-best.pth",
        "SSIF-SRE": "./save/pavia_centra/train_rdn-liif/PAVIA_CENTRA-TSM8-MD-BANDMIN13-MAX102-SAM13-liif-rdn-256-banddec-mlp-H512_512_512-512-bandposenc-band_sre_mlp-32-1.00-0.000100-H512-LR0.000100-L1/epoch-best.pth"
    }

    # SSIF-RF variants: one checkpoint per (band-embedding dim, response fn, sampling type).
    rf_path_template = "./save/pavia_centra/train_rdn-liif_band_rb_mlp/PAVIA_CENTRA-TSM8-MD-BANDMIN13-MAX102-SAM13-liif-rdn-256-banddec-mlp-H512_512_512-512-bandposenc-band_rb_mlp-32-1.00-0.000100-H512-bandnerf-img_band_dot1--identity-{dim:d}-{resp_f:s}-invsum-{samp_t:s}-LR0.000100-L1/epoch-best.pth"
    for dim in (4, 32, 52, 64):
        for resp_f, samp_t in (('uniform', 'fix'), ('gaussian', 'fix'),
                               ('gaussian', 'gaussian'), ('uniform', 'uniform')):
            model_name = "SSIF-RF-{}-{}-{}-D".format(resp_f[0].upper(), samp_t[0].upper(), dim)
            model_dict[model_name] = rf_path_template.format(
                resp_f = resp_f,
                samp_t = samp_t,
                dim = dim
            )

    scale_list = [2, 3, 4, 8]

    args = make_args_parser().parse_args()

    # 1. load data
    op = load_rs_dataset(num_band = args.num_band, dataset = args.dataset)
    train_msi, train_hsi, train_labels, test_msi, test_hsi, test_labels, classes, num_class = op
    print("Load dataset")
    # Single source of truth for the device; falls back to CPU when CUDA is absent.
    args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Apply PCA
    pca, train_hsi_pca, test_hsi_pca = apply_pca(train_data = train_hsi, test_data = test_hsi)
    # Data standardization. Shapes previously noted for this dataset:
    #   train_hsi_norm: (1096, 715, 102)
    #   test_hsi_norm:  (8, 128, 128, 102)
    scaler, train_hsi_norm, test_hsi_norm = scale_data(train_data = train_hsi_pca, test_data = test_hsi_pca)
    print("PCA and scaler data")

    # Stack the test label tiles into one 2-D grid of width w.
    _, _, w, _ = test_hsi_norm.shape
    test_labels = test_labels.reshape(-1, w)

    _, _, num_channels = train_hsi_norm.shape

    # .to(device) instead of .cuda() so CPU-only hosts do not crash here.
    net = S3KAIResNet(
        band = num_channels, classes = num_class, reduction = 2, PARAM_KERNEL_SIZE = args.kernel).to(args.device)
    optimizer = du.get_optim(net, opt = args.optimizer, lr = args.lr)

    # load best model
    save_path = du.get_save_path(args)
    net, optimizer = load_best_model(save_path, net, optimizer)
    print("Load best land use classification model")

    get_model_pred_batch(net,
                        img_root_dir = "../../",
                        num_band = args.num_band,
                        sample_size = args.sample_size,
                        test_prop = args.test_prop,
                        test_hsi = test_hsi.reshape(-1, test_hsi.shape[-2], test_hsi.shape[-1]),
                        test_labels = test_labels,
                        pca = pca,
                        scaler = scaler,
                        model_save_path = save_path,
                        model_dict = model_dict,
                        save_dir = args.save_data_dir,
                        scale_list = scale_list,
                        patch_length = args.patch,
                        batch_size = args.batch_size,
                        device = args.device)


if __name__ == '__main__':
    main()

