import argparse
import os
import math
from functools import partial

import yaml
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
from tqdm import tqdm
import numpy as np
import re
import matplotlib.pyplot as plt
import pandas as pd

from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

import datasets
import models
import utils





def get_pred_filepath(model_path, scale = 8, num_band = 102):
    """Locate the cached prediction file stored next to a model checkpoint.

    Searches the directory that contains ``model_path`` for a file named
    exactly ``model_pred_scale{scale}_band{num_band}.pkl``. Leading zeros in
    the numbers are tolerated because they are compared as integers.

    Args:
        model_path: Path to a model checkpoint; only its directory is used.
        scale: Scale value encoded in the file name.
        num_band: Band count encoded in the file name.

    Returns:
        The full path of the matching file, or ``None`` if none matches.
    """
    # A bare file name has dirname "", which os.listdir rejects.
    model_dir = os.path.dirname(model_path) or "."
    # Raw string so "\d" is a regex class rather than an invalid string escape.
    # Escaped "." plus fullmatch, so only the exact "<...>.pkl" name qualifies
    # (e.g. "..._band102.pkl.pkl" is rejected).
    pattern = re.compile(r"model_pred_scale(?P<scale>\d+)_band(?P<band>\d+)\.pkl")
    # Sort so the result does not depend on arbitrary directory-listing order.
    for fname in sorted(os.listdir(model_dir)):
        m = pattern.fullmatch(fname)
        if m is None:
            continue
        if int(m.group("scale")) == scale and int(m.group("band")) == num_band:
            return os.path.join(model_dir, fname)
    return None


def get_sample_idx_per_class(train_labels_, sample_size = 2000, ignore_label = -1):
    """Randomly pick ``sample_size`` pixel indices for every labelled class.

    Args:
        train_labels_: Label array. It is normally already flattened by the
            caller; indices are always flat, so N-D input also works.
        sample_size: Number of indices drawn, without replacement, per class.
        ignore_label: Label value that is skipped entirely, both for sampling
            and for the class-size check (default ``-1``, as before).

    Returns:
        Dict mapping each class label (numpy scalar) to an array of
        ``sample_size`` flat indices. Classes appear in sorted label order.

    Raises:
        ValueError: If some non-ignored class has fewer than ``sample_size``
            pixels.
    """
    train_labels_ = np.asarray(train_labels_)
    un_ids, un_cnts = np.unique(train_labels_, return_counts = True)
    # Apply the size check only to classes that will actually be sampled.
    # Previously a tiny count of the ignored label could raise spuriously.
    keep = un_ids != ignore_label
    un_ids, un_cnts = un_ids[keep], un_cnts[keep]
    if un_cnts.size and sample_size > un_cnts.min():
        raise ValueError(
            f"The sample size ({sample_size}) is larger than the minimum number "
            f"of items per class ({int(un_cnts.min())})")
    label2idx = {}
    for uid in un_ids:
        # flatnonzero yields flat indices even when the labels are not 1-D.
        id_list = np.flatnonzero(train_labels_ == uid)
        label2idx[uid] = np.random.choice(id_list, size=sample_size, replace=False)
    return label2idx

def train_rf(sample_size, num_band, label2idx, train_labels, train_hsi):
    """Fit a random forest on the pre-sampled training pixels.

    Args:
        sample_size: Not used in the body; kept for call-site compatibility.
        num_band: Not used in the body; kept for call-site compatibility.
        label2idx: Mapping from label to an array of flat pixel indices.
        train_labels: Label map; flattened before indexing.
        train_hsi: 3-D image whose last axis holds the per-pixel features.

    Returns:
        The fitted ``RandomForestClassifier``.
    """
    flat_labels = train_labels.reshape(-1)
    _, _, n_channels = train_hsi.shape
    pixels = train_hsi.reshape(-1, n_channels)

    # Gather each class's sampled pixels in dict order, then stack into one
    # feature matrix and one target vector.
    index_groups = list(label2idx.values())
    features = np.concatenate([pixels[idx] for idx in index_groups], axis = 0)
    targets = np.concatenate([flat_labels[idx] for idx in index_groups], axis = 0)

    forest = RandomForestClassifier(
        n_estimators = 300,
        criterion = 'gini',
        max_features = "sqrt",
        random_state = 0,
    )
    forest.fit(features, targets)
    return forest

def rf_pred_class(clf, img):
    """Predict a class for every pixel of a batch of images.

    Args:
        clf: Fitted classifier exposing ``predict``.
        img: Channels-first batch of shape (B, C, H, W).

    Returns:
        Array of shape (B, H, W) with the per-pixel predictions.
    """
    n_img, n_chan, height, width = img.shape
    channels_last = np.transpose(img, (0, 2, 3, 1))  # -> (B, H, W, C)
    class_maps = [
        clf.predict(channels_last[b].reshape(-1, n_chan)).reshape(height, width)
        for b in tqdm(range(n_img))
    ]
    # Copy the per-image maps into a single (B, H, W) array.
    return np.asarray(class_maps)

def get_rf_pred(clf, msi_clf, sample_size, num_band, test_hsi, test_msi, 
    model_dict, test_labels, 
    save_dir, scale_list = None): 
    """Run the RF classifiers on test imagery and SR outputs; pickle the results.

    Files written:
      * ``{save_dir}/test_gt_rf_pred_cla_band{num_band}_{sample_size}.pkl``:
        ``clf`` predictions on the real test HSI.
      * ``{save_dir}/test_msi_rf_pred_cla_{sample_size}.pkl``:
        ``msi_clf`` predictions on the test MSI.
      * For each model and scale, ``<pred file>_rf_pred_cla_{sample_size}.pkl``
        next to that model's prediction pickle.

    Args:
        clf: Classifier trained on HSI pixels.
        msi_clf: Classifier trained on MSI pixels.
        sample_size: Per-class training sample size; used in file names.
        num_band: Band count; selects the LIIF model key and the file names.
        test_hsi, test_msi: Channels-last (B, H, W, C) test stacks; they are
            transposed to channels-first before prediction.
        model_dict: Maps model key to checkpoint path. Must contain
            ``LIIF_C{num_band}``, ``SSIF-SE``, ``SSIF-SME`` and ``SSIF-SRE``.
        test_labels: Not used here; kept for call-site compatibility.
        save_dir: Output directory for the HSI/MSI prediction pickles.
        scale_list: Scales to process; defaults to ``(2, 3, 4, 8)``.

    Raises:
        FileNotFoundError: If a model directory lacks the prediction pickle
            for a requested scale/band combination.
    """
    # None sentinel instead of a shared mutable list default.
    if scale_list is None:
        scale_list = (2, 3, 4, 8)
    model_names = [f'LIIF_C{num_band}', 'SSIF-SE', 'SSIF-SME', 'SSIF-SRE']

    # save HSI prediction
    gt_pred_clas_mat = rf_pred_class(clf, img = np.transpose(test_hsi, (0, 3, 1, 2)))
    rf_gt_file = os.path.join(save_dir, f"test_gt_rf_pred_cla_band{num_band}_{sample_size}.pkl")
    utils.pickle_dump(gt_pred_clas_mat, rf_gt_file)
    print(f"Save HSI predition to {rf_gt_file}")

    # save MSI prediction
    msi_pred_clas_mat = rf_pred_class(msi_clf, img = np.transpose(test_msi, (0, 3, 1, 2)))
    rf_msi_file = os.path.join(save_dir, f"test_msi_rf_pred_cla_{sample_size}.pkl")
    utils.pickle_dump(msi_pred_clas_mat, rf_msi_file)
    print(f"Save MSI predition to {rf_msi_file}")

    for scale in tqdm(scale_list):
        for model in model_names:
            pred_file = get_pred_filepath(model_path = model_dict[model], scale = scale, num_band = num_band)
            if pred_file is None:
                # Fail loudly here instead of passing None into pickle_load.
                raise FileNotFoundError(
                    f"No model_pred_scale{scale}_band{num_band}.pkl next to {model_dict[model]} ({model})")
            pred_dict = utils.pickle_load(pred_file)

            # rf_pred_class expects channels-first (B, C, H, W) input.
            pred_clas_mat = rf_pred_class(clf, img = pred_dict['pred'])

            rf_pred_file = pred_file.replace(".pkl", f"_rf_pred_cla_{sample_size}.pkl")
            utils.pickle_dump(pred_clas_mat, rf_pred_file)
            print(f'Scale: {scale}  model: {model} Save predition to {rf_pred_file}')

def get_rf_eval_df(sample_size, num_band, model_dict, test_labels, 
    save_dir, scale_list = None, ignore_label = -1):
    """Score the pickled RF predictions against the test labels and save a CSV.

    Loads the files written by ``get_rf_pred``. Computes pixel accuracy for
    each SR model/scale, for the real HSI ("GT" row) and for the MSI ("MSI"
    row). Writes the table to ``{save_dir}/rf_band{num_band}_{sample_size}.csv``.

    Args:
        sample_size: Per-class training sample size; used in file names.
        num_band: Band count; selects the LIIF model key and the file names.
        model_dict: Maps model key to checkpoint path (same keys as in
            ``get_rf_pred``).
        test_labels: Label stack indexed as (B, H, W).
        save_dir: Directory holding the HSI/MSI prediction pickles and the CSV.
        scale_list: Scales to evaluate; defaults to ``(2, 3, 4, 8)``.
        ignore_label: Pixels whose true label equals this value are excluded
            from the accuracy. The classifiers never see these pixels during
            training, so counting them only adds guaranteed errors. Pass
            ``None`` to score every pixel, which was the old behaviour.

    Returns:
        DataFrame with columns
        ``model, scale, num_band, model_path, pred_file, acc``.

    Raises:
        FileNotFoundError: If a model's prediction pickle cannot be located.
    """
    # None sentinel instead of a shared mutable list default.
    if scale_list is None:
        scale_list = (2, 3, 4, 8)
    columns = ['model', 'scale', 'num_band', 'model_path', 'pred_file', 'acc']
    model_names = [f'LIIF_C{num_band}', 'SSIF-SE', 'SSIF-SME', 'SSIF-SRE']

    def _accuracy(pred, labels):
        """Pixel accuracy of ``pred`` against ``labels``, skipping ``ignore_label``."""
        y_true = np.asarray(labels).reshape(-1)
        y_pred = np.asarray(pred).reshape(-1)
        if ignore_label is not None:
            keep = y_true != ignore_label
            y_true, y_pred = y_true[keep], y_pred[keep]
        # Ground truth is y_true; the original call had the arguments swapped.
        return accuracy_score(y_true = y_true, y_pred = y_pred)

    rows = []
    for scale in scale_list:
        for model in model_names:
            pred_file = get_pred_filepath(model_path = model_dict[model], scale = scale, num_band = num_band)
            if pred_file is None:
                raise FileNotFoundError(
                    f"No model_pred_scale{scale}_band{num_band}.pkl next to {model_dict[model]} ({model})")
            rf_pred_file = pred_file.replace(".pkl", f"_rf_pred_cla_{sample_size}.pkl")
            pred_clas_mat = utils.pickle_load(rf_pred_file)
            _, H, W = pred_clas_mat.shape
            # The SR prediction may be spatially smaller than the labels
            # (presumably cropped to a multiple of the scale), so compare
            # against the top-left H x W window.
            acc = _accuracy(pred_clas_mat, test_labels[:, :H, :W])
            rows.append([model, scale, num_band, model_dict[model], pred_file, acc])

    # compute HSI eval
    rf_gt_file = os.path.join(save_dir, f"test_gt_rf_pred_cla_band{num_band}_{sample_size}.pkl")
    gt_pred_clas_mat = utils.pickle_load(rf_gt_file)
    rows.append(["GT", None, num_band, None, rf_gt_file, _accuracy(gt_pred_clas_mat, test_labels)])

    # compute MSI eval
    rf_msi_file = os.path.join(save_dir, f"test_msi_rf_pred_cla_{sample_size}.pkl")
    msi_pred_clas_mat = utils.pickle_load(rf_msi_file)
    rows.append(["MSI", None, num_band, None, rf_msi_file, _accuracy(msi_pred_clas_mat, test_labels)])

    # Build the frame in one go. DataFrame.append was removed in pandas 2.0.
    # dtype=object keeps e.g. scale as "2" rather than "2.0" next to None in the CSV.
    svm_df = pd.DataFrame(rows, columns = columns, dtype = object)
    svm_df.to_csv(os.path.join(save_dir, f"rf_band{num_band}_{sample_size}.csv"))

    return svm_df

def make_args_parser():
    """Build the command-line parser for this RF evaluation script.

    Options (both integers):
        --sample_size: training pixels drawn per class (default 1500).
        --num_band: number of HSI bands to load (default 102).
    """
    cli = argparse.ArgumentParser()
    int_options = (
        ("--sample_size", 1500, "sample_size"),
        ("--num_band", 102, "num_band"),
    )
    for flag, default_value, help_text in int_options:
        cli.add_argument(flag, type=int, default=default_value, help=help_text)
    return cli

if __name__ == '__main__':
    # Checkpoint paths; get_pred_filepath looks for prediction pickles in each
    # checkpoint's directory.
    model_dict = {
        "LIIF_C13": "./save/pavia_centra/train_rdn-liif/PAVIA_CENTRA-TSM8-MD-BANDMIN13-MAX13-SAM-liif-rdn-256-mlp-H512_512_512_512-LR0.000100-L1/epoch-best.pth",
        "LIIF_C26": "./save/pavia_centra/train_rdn-liif/PAVIA_CENTRA-TSM8-MD-BANDMIN26-MAX26-SAM-liif-rdn-256-mlp-H512_512_512_512-LR0.000100-L1/epoch-best.pth",
        "LIIF_C51": "./save/pavia_centra/train_rdn-liif/PAVIA_CENTRA-TSM8-MD-BANDMIN51-MAX51-SAM-liif-rdn-256-mlp-H512_512_512_512-LR0.000100-L1/epoch-best.pth",
        "LIIF_C102": "./save/pavia_centra/train_rdn-liif/PAVIA_CENTRA-TSM8-MD-liif-rdn-256-mlp-H512_512_512_512-LR0.000100-L1/epoch-best.pth",
        "SSIF-SE":  "./save/pavia_centra/train_rdn-liif/PAVIA_CENTRA-TSM8-MD-BANDMIN13-MAX102-SAM13-liif-rdn-256-banddec-mlp-H512_512_512-512-bandposenc-band_se_mlp-32-1.00-0.000100-H512-LR0.000100-L1/epoch-best.pth",
        "SSIF-SME": "./save/pavia_centra/train_rdn-liif/PAVIA_CENTRA-TSM8-MD-BANDMIN13-MAX102-SAM13-liif-rdn-256-banddec-mlp-H512_512_512-512-bandposenc-band_sme_mlp-32-1.00-0.000100-H512-LR0.000100-L1/epoch-best.pth",
        "SSIF-SRE": "./save/pavia_centra/train_rdn-liif/PAVIA_CENTRA-TSM8-MD-BANDMIN13-MAX102-SAM13-liif-rdn-256-banddec-mlp-H512_512_512-512-bandposenc-band_sre_mlp-32-1.00-0.000100-H512-LR0.000100-L1/epoch-best.pth"
    }

    save_dir = "./save/pavia_centra/train_rdn-liif/RF/"
    # makedirs also creates missing parent directories and tolerates an
    # existing directory (os.mkdir failed on both).
    os.makedirs(save_dir, exist_ok = True)

    parser = make_args_parser()
    args = parser.parse_args()

    sample_size = args.sample_size
    num_band = args.num_band
    # Validate up front: an unsupported band count used to surface only as a
    # KeyError after the (slow) RF training had finished.
    if f"LIIF_C{num_band}" not in model_dict:
        parser.error(f"no LIIF checkpoint configured for --num_band {num_band}")

    # The full 102-band files have no suffix; band-subsampled files carry "_band{N}".
    band_tag = "" if num_band == 102 else f"_band{num_band}"

    print("Load data")
    data_dir = "../dataset_preprocess/dataset/Pavia_Centre"
    train_hsi_dir = f"{data_dir}/train/HSI{band_tag}/"
    train_msi_dir = f"{data_dir}/train/MSI/"
    train_gt_dir = f"{data_dir}/train/GT/"

    test_hsi_dir = f"{data_dir}/test/HSI{band_tag}/"
    test_msi_dir = f"{data_dir}/test/MSI/"
    test_gt_dir = f"{data_dir}/test/GT/"

    train_msi = utils.load_np_file(train_msi_dir, "pavia_centre-msi_train.npy")
    train_hsi = utils.load_np_file(train_hsi_dir, f"pavia_centre-hsi_train{band_tag}.npy")
    train_labels = utils.load_np_file(train_gt_dir, "pavia_centre-gt_train.npy")

    # The test split is stored as separate tiles; stack them along a new
    # leading batch axis.
    num_test_tiles = 8
    test_labels = np.stack([
        utils.load_np_file(test_gt_dir, f"pavia_centre-gt_test_{i}.npy")
        for i in range(num_test_tiles)])
    test_hsi = np.stack([
        utils.load_np_file(test_hsi_dir, f"pavia_centre-hsi_test_{i}{band_tag}.npy")
        for i in range(num_test_tiles)])
    test_msi = np.stack([
        utils.load_np_file(test_msi_dir, f"pavia_centre-msi_test_{i}.npy")
        for i in range(num_test_tiles)])

    print("Sample training points")
    # Cache the sampled indices so HSI and MSI runs (and reruns) share them.
    label2idx_path = f"{save_dir}/train_label2sampleidx_{sample_size}.pkl"
    if not os.path.exists(label2idx_path):
        label2idx = get_sample_idx_per_class(train_labels.reshape(-1), sample_size = sample_size)
        utils.pickle_dump(label2idx, label2idx_path)
    else:
        label2idx = utils.pickle_load(label2idx_path)

    print("Train RF")
    clf = train_rf(sample_size, num_band, label2idx, train_labels, train_hsi)

    print("Train MSI RF")
    msi_clf = train_rf(sample_size, num_band, label2idx, train_labels, train_msi)

    scale_list = [2, 3, 4, 8]

    print("Compute RF Prediction")
    get_rf_pred(clf, msi_clf, sample_size, num_band, test_hsi, test_msi,
        model_dict, test_labels,
        save_dir, scale_list = scale_list)

    print("Save RF Eval")
    rf_df = get_rf_eval_df(sample_size, num_band, model_dict, test_labels,
        save_dir, scale_list = scale_list)