import glob
import warnings

import numpy as np
import pandas as pd
from sklearn.metrics import f1_score, accuracy_score

warnings.filterwarnings("ignore")

from lightwavesl1_functions import _generate_first_phase_kernels, _apply_kernels_feat_only

features_number = 4
from lightwaves_utils import ScalePerChannel, anova_feature_selection, mrmr_feature_selection, ScalePerChannelTrain
from sklearn.linear_model import RidgeClassifierCV
from sympy.utilities.iterables import multiset_permutations
from mpi4py import MPI


def transform(X, matrix, feat_mask, candidate_kernels, dilations):
    """Compute the selected features of ``X`` for an explicit kernel matrix.

    Parameters
    ----------
    X : input series, passed unchanged to ``_apply_kernels_feat_only``.
    matrix : integer rows whose first three columns are
        (channel, kernel index, dilation index); decoded by ``ckd_to_kernels``.
    feat_mask : boolean mask applied to the axes after the first of the raw
        feature array. NOTE(review): this script builds it with shape
        (n_kernels, features_number); confirm for other callers.
    candidate_kernels, dilations : lookup tables indexed by ``matrix`` columns.

    Returns
    -------
    The raw features restricted to ``feat_mask`` (first axis untouched).
    """
    kernel_params = ckd_to_kernels(matrix, candidate_kernels, dilations)
    raw_features = _apply_kernels_feat_only(X, kernel_params)
    selected = raw_features[:, feat_mask]
    return selected


def ckd_to_kernels(ckd, candidate_kernels, candidate_dilations):
    """Expand (channel, kernel, dilation) rows into a flat kernel-parameter tuple.

    Parameters
    ----------
    ckd : 2-D integer array; column 0 is the channel, column 1 indexes
        ``candidate_kernels`` and column 2 indexes ``candidate_dilations``.
    candidate_kernels : array of kernel weight rows.
    candidate_dilations : array of dilation exponents (dilation = 2 ** value).

    Returns
    -------
    tuple
        ``(weights, lengths, biases, dilations, paddings,
        num_channel_indices, channel_indices)`` where every kernel reads a
        single channel, has a zero bias, and is padded by
        ``(length - 1) * dilation // 2``. NOTE(review): the tuple order is
        presumably what ``_apply_kernels_feat_only`` expects - confirm there.
    """
    row_count = ckd.shape[0]
    kernel_ids = ckd[:, 1]

    flat_weights = candidate_kernels[kernel_ids].flatten().astype(np.float32)
    kernel_lengths = np.array([len(candidate_kernels[k]) for k in kernel_ids], dtype=np.int32)
    dilation_values = 2 ** candidate_dilations[ckd[:, 2]].flatten().astype(np.int32)
    padding_values = ((kernel_lengths - 1) * dilation_values) // 2

    # One channel per kernel, so every kernel gets a count of 1 and no bias.
    channel_counts = np.ones(row_count, dtype=np.int32)
    zero_biases = np.zeros_like(channel_counts, dtype=np.float32)

    return (
        flat_weights,
        kernel_lengths,
        zero_biases,
        dilation_values,
        padding_values,
        channel_counts,
        ckd[:, 0],
    )


def get_ckd_matrix_with_features(fidx, num_channels, n_candidate_kernels, n_dilations, n_features):
    """Decode flat feature indices into unique (channel, kernel, dilation, feature) rows.

    Each entry of ``fidx`` is treated as a C-order flat index into a grid of
    shape ``(num_channels, n_candidate_kernels, n_dilations, n_features)``.
    Duplicate rows are removed and the result is sorted lexicographically
    (``np.unique`` with ``axis=0``), then cast to int32.
    """
    grid_shape = (num_channels, n_candidate_kernels, n_dilations, n_features)
    coordinates = np.transpose(np.asarray(np.unravel_index(fidx, grid_shape)))
    unique_rows = np.unique(coordinates, axis=0)
    return unique_rows.astype(np.int32)


def get_fixed_candidate_kernels():
    """Return every distinct arrangement of three 2s and six -1s as float32 rows.

    The rows come from ``multiset_permutations`` in the order it yields them,
    giving a ``(84, 9)`` array (9 choose 3 = 84 arrangements).
    """
    base_pattern = [2, 2, 2] + [-1] * 6
    kernel_rows = [np.asarray(perm) for perm in multiset_permutations(base_pattern)]
    return np.array(kernel_rows, dtype=np.float32)


dev_data = [0, 2, 6, 7, 8, 9, 10, 11, 12, 13, 17, 19, 20, 26, 27]  # NOTE(review): not referenced in the visible code
MAX_DILATION = 32  # dilation exponents 0..log2(MAX_DILATION) are used (see `dilations` below)
dir_prefix = "."

FIRST_PHASE_NUM_FEAT = 500  # number of features requested from the final mrmr_feature_selection
FIRST_PHASE_PRE_FINAL_FEAT_NUM = 3 * FIRST_PHASE_NUM_FEAT  # features kept by the distributed ANOVA stage
N_BOOTSTRAP_SPLITS = 4  # forwarded to anova_feature_selection
SAMPLE_SIZE = 1500  # cap on training samples used during feature selection
N_KERNELS = 84  # equals C(9, 3), the row count of get_fixed_candidate_kernels(); not referenced here
VERBOSE = 0


def _read_metadata(path):
    """Read a comma-separated file whose rows may have different lengths.

    Every non-blank line is split on ',' into string cells. Shorter rows are
    padded with None, so the result has integer column labels 0..k-1, the same
    layout as ``Series.str.split(',', expand=True)`` on one-line-per-row data.
    This replaces ``pd.read_csv(..., sep='\\n')``, which current pandas rejects
    with a ValueError.
    """
    with open(path) as fh:
        rows = [line.rstrip("\r\n").split(",") for line in fh if line.strip()]
    return pd.DataFrame(rows)


metadata = _read_metadata(F"{dir_prefix}/Datasets/DataDimensions.csv")

orig_comm = MPI.COMM_WORLD
orig_rank = orig_comm.Get_rank()
orig_n_nodes = orig_comm.Get_size()

# Dilation exponents; ckd_to_kernels turns them into dilations of 2 ** exponent.
dilations = np.arange(0, np.log2(MAX_DILATION) + 1).astype(np.int32)
n_dilations = dilations.size

if orig_rank == 0:
    print(_generate_first_phase_kernels.__module__)

def _subset_comm(parent_comm, member_ranks):
    """Return a communicator built from ``member_ranks`` (ranks of ``parent_comm``).

    The intermediate ``Group`` handles are freed once the communicator exists.
    The returned handle is only used on member ranks in this script.
    NOTE(review): per the MPI standard, non-members receive MPI.COMM_NULL.
    """
    parent_group = parent_comm.Get_group()
    sub_group = parent_group.Incl(list(member_ranks))
    try:
        return parent_comm.Create_group(sub_group)
    finally:
        sub_group.Free()
        parent_group.Free()


def _free_if_derived(sub_comm, parent_comm):
    """Free ``sub_comm`` if it was derived from ``parent_comm`` and is not null.

    Communicators created per dataset/seed were previously never released,
    which may exhaust the MPI library's pool of communicator context ids.
    """
    if sub_comm is not parent_comm and sub_comm != MPI.COMM_NULL:
        sub_comm.Free()


def _kernels_and_feature_mask(ckdf_rows):
    """Split (channel, kernel, dilation, feature) rows into kernels + feature mask.

    Returns ``(kernel_rows, mask)``.
    - ``kernel_rows`` holds the unique (channel, kernel, dilation) triples in
      lexicographic order.
    - ``mask[i, f]`` is True when feature ``f`` of kernel ``i`` occurs in
      ``ckdf_rows``.

    ``np.unique`` and ``groupby`` both sort the keys, so group ``i`` lines up
    with ``kernel_rows[i]``.
    """
    kernel_rows = np.unique(ckdf_rows[:, :-1], axis=0)
    mask = np.zeros((kernel_rows.shape[0], features_number), dtype=bool)
    feats_per_kernel = list(pd.DataFrame(ckdf_rows).groupby([0, 1, 2])[3].apply(list))
    for i in range(mask.shape[0]):
        mask[i, feats_per_kernel[i]] = True
    return kernel_rows, mask


# Deterministic and dataset-independent, so it is built once.
candidate_kernels = get_fixed_candidate_kernels()
n_candidate_kernels = len(candidate_kernels)

for fileidx, filename in enumerate(sorted(glob.glob(F"{dir_prefix}/Datasets/*.npz"))):

    # NOTE(review): skipping by position assumes InsectWingbeat stays the 16th
    # file in sorted order; re-check if the Datasets directory changes.
    if fileidx == 15:  # InsectWingbeat
        continue

    orig_comm.Barrier()

    dataset = filename.split("/")[-1].split(".")[0]

    # The archive, the channel split and the channel sub-communicator depend
    # only on the dataset, so they are set up once for all seeds. The context
    # manager closes the .npz file when the dataset is finished.
    with np.load(filename) as data:
        total_num_channels = data['train_x'].shape[1]

        if total_num_channels < orig_n_nodes:
            if orig_rank == 0 and VERBOSE:
                print("Number of channels is smaller than number of nodes, reducing COMM to subset.")
            comm = _subset_comm(orig_comm, range(total_num_channels))
        else:
            comm = orig_comm

        # Contiguous, increasing channel blocks per rank; ranks >= total_num_channels get none.
        channel_distribution = np.array_split(np.arange(total_num_channels), orig_n_nodes)
        my_channels = channel_distribution[orig_rank]

        for seed in range(30):
            orig_comm.Barrier()
            np.random.seed(seed)

            if orig_rank >= total_num_channels:
                # Ranks without channels only take part in the barriers.
                continue

            rank = comm.Get_rank()
            n_nodes = comm.Get_size()

            train_x, train_y = data['train_x'][:, my_channels, :], data['train_y']
            train_samples, num_channels, _ = train_x.shape
            normalized = str(metadata[metadata[0] == dataset][6].item()).strip() == 'true'

            if not normalized:
                train_x = ScalePerChannelTrain(train_x)
            train_x = train_x.astype(np.float32)

            if rank == 0 and VERBOSE:
                print(fileidx, dataset, seed)
                print(candidate_kernels.shape[0] * n_dilations * total_num_channels)
            first_phase_kernels = _generate_first_phase_kernels(num_channels, candidate_kernels, dilations, seed)

            # Re-seeding right before the draw makes every rank pick the same samples.
            if train_samples > SAMPLE_SIZE:
                np.random.seed(seed)
                sample_idces = np.random.choice(train_samples, size=SAMPLE_SIZE, replace=False)
                train_samples = SAMPLE_SIZE
            else:
                sample_idces = slice(None)

            transform_features = _apply_kernels_feat_only(train_x[sample_idces, ...], first_phase_kernels)

            sel_feat_idces, sel_feat_scores = anova_feature_selection(
                transform_features.reshape((transform_features.shape[0], -1)), train_y[sample_idces],
                FIRST_PHASE_PRE_FINAL_FEAT_NUM,
                N_BOOTSTRAP_SPLITS,
                seed)

            # Step 1: gather each rank's number of locally selected features at the root.
            feat_count = np.array([sel_feat_idces.size], dtype='int')
            feat_count_recvbuf = np.empty(n_nodes, dtype='int') if rank == 0 else None
            comm.Gather(feat_count, feat_count_recvbuf, root=0)

            # Step 2: gather every local score at the root. The receive side is
            # declared MPI.FLOAT, so the send side is cast to float32 to keep the
            # MPI type signatures of sender and receiver identical.
            score_sendbuf = np.asarray(sel_feat_scores, dtype=np.float32).ravel()
            feat_scores_recvbuf = counts = displ = None
            if rank == 0:
                counts = feat_count_recvbuf
                displ = np.hstack((0, counts)).cumsum()[:-1]
                feat_scores_recvbuf = np.empty(counts.sum(), dtype=np.float32)
            comm.Gatherv(score_sendbuf, [feat_scores_recvbuf, counts, displ, MPI.FLOAT], root=0)

            if rank == 0:
                # Keep the best FIRST_PHASE_PRE_FINAL_FEAT_NUM scores over all ranks
                # and count how many of them each rank contributed.
                score_src_idces = np.repeat(np.arange(n_nodes), feat_count_recvbuf)
                top_idces = np.argsort(feat_scores_recvbuf)[::-1][:FIRST_PHASE_PRE_FINAL_FEAT_NUM]
                top_score_src_count = np.bincount(score_src_idces[top_idces], minlength=n_nodes).astype(np.int32)
            else:
                top_score_src_count = np.empty(n_nodes, dtype=np.int32)

            comm.Bcast(top_score_src_count, root=0)

            # Each rank keeps its own best top_score_src_count[rank] features, sorted by index.
            sel_feat_idces = np.sort(sel_feat_idces[np.argsort(sel_feat_scores)[::-1]][:top_score_src_count[rank]])

            if (top_score_src_count == 0).any():
                if orig_rank == 0 and VERBOSE:
                    print("Some nodes have 0 CKD selected, reducing COMM to subset.")
                new_comm = _subset_comm(comm, np.flatnonzero(top_score_src_count).tolist())
            else:
                new_comm = comm

            if top_score_src_count[rank] > 0:
                rank = new_comm.Get_rank()
                n_nodes = new_comm.Get_size()

                ckdf = get_ckd_matrix_with_features(sel_feat_idces, num_channels, n_candidate_kernels, n_dilations,
                                                    features_number)
                # Translate local channel positions into global channel indices.
                ckdf[:, 0] = my_channels[ckdf[:, 0]]
                n_cols = ckdf.shape[1]  # (channel, kernel, dilation, feature)

                # Step 3: gather the surviving rows at the root. Ranks in new_comm
                # follow the order of the non-zero entries of top_score_src_count,
                # so the per-rank counts below are listed in that same order.
                feat_sendbuf = ckdf.flatten()
                ckdf_recvbuf = counts = displ = None
                if rank == 0:
                    counts = top_score_src_count[top_score_src_count != 0] * n_cols
                    displ = np.hstack((0, counts)).cumsum()[:-1]
                    ckdf_recvbuf = np.empty(counts.sum(), dtype=np.int32)
                new_comm.Gatherv(feat_sendbuf, [ckdf_recvbuf, counts, displ, MPI.INT], root=0)

                if rank == 0:
                    ckdf_recvbuf = ckdf_recvbuf.reshape((-1, n_cols))
                    test_y = data['test_y']
                    if not normalized:
                        full_train_x, full_test_x = ScalePerChannel(data['train_x'], data['test_x'])
                    else:
                        full_train_x, full_test_x = data['train_x'], data['test_x']

                    # Second stage: recompute the gathered features on the full data and
                    # run a global mRMR selection. The rows of ckdf_recvbuf are globally
                    # sorted, because ranks own increasing channel blocks and each sent
                    # np.unique-sorted rows. NOTE(review): the indices returned by mRMR
                    # map back to those rows only if the masked feature columns come out
                    # kernel-major - confirm against _apply_kernels_feat_only.
                    unique_ckdf_recvbuf, feat_mask = _kernels_and_feature_mask(ckdf_recvbuf)
                    cand_kernels = ckd_to_kernels(unique_ckdf_recvbuf, candidate_kernels, dilations)
                    cand_feats = _apply_kernels_feat_only(full_train_x[sample_idces, :, :], cand_kernels)[:, feat_mask]

                    global_sel_feats_p2_idces, _, _ = mrmr_feature_selection(
                        cand_feats, train_y[sample_idces], FIRST_PHASE_NUM_FEAT)

                    ckdf_recvbuf = ckdf_recvbuf[global_sel_feats_p2_idces, :]
                    kernel_matrix_final, feat_mask = _kernels_and_feature_mask(ckdf_recvbuf)

                    train_tr = transform(full_train_x, kernel_matrix_final, feat_mask, candidate_kernels, dilations)
                    test_tr = transform(full_test_x, kernel_matrix_final, feat_mask, candidate_kernels, dilations)

                    # NOTE(review): `normalize` was removed from RidgeClassifierCV in
                    # scikit-learn 1.2, so this line requires scikit-learn < 1.2.
                    final_classifier = RidgeClassifierCV(alphas=np.logspace(-3, 3, 10), normalize=True)
                    final_classifier.fit(train_tr, train_y)
                    y_hat = final_classifier.predict(test_tr)
                    f1_sc = np.round(f1_score(test_y, y_hat, average='weighted'), 3)
                    acc = np.round(accuracy_score(test_y, y_hat), 3)
                    print(seed, dataset, acc, f1_sc, FIRST_PHASE_PRE_FINAL_FEAT_NUM,
                          FIRST_PHASE_NUM_FEAT)

            _free_if_derived(new_comm, comm)

        _free_if_derived(comm, orig_comm)
