import json
import pandas as pd
import numpy as np

from method.model_profile import ModelProfile
from method.dataset_profile import DatasetProfile
from utils.config import MODEL_LIST, TASK_INFO, ALL_DATASETS


class MatrixManager:
    def __init__(self):
        with open("./data/BPMF_LVLM/result_summary.json", 'r') as f:
            result_summary = json.load(f)

        result_summary = pd.DataFrame(result_summary).T
        # result_summary = result_summary[[col for col in result_summary.columns if "SEED_2" in col]]
        self.result_summary = result_summary

    def get_metric_info(self):
        metric_info = ["acc" if "Whole dataset" not in col else "BART" 
                       for col in self.result_summary.columns]
        return metric_info

    # Define a function for splitting train/test data.
    def split_train_test(self, data, percent_test=0.1):
        """Split the data into train/test sets.
        :param int percent_test: Percentage of data to use for testing. Default 10.
        """
        n, m = data.shape  # # users, # movies
        N = n * m  # # cells in matrix

        # Prepare train/test ndarrays.
        train = data.copy()
        test = np.ones(data.shape) * np.nan

        # Draw random sample of training data to use for testing.
        tosample = np.where(~np.isnan(data))
        train, test, test_size = self.mask_randomly(tosample, train, test, mask_ratio=percent_test)
        train_size = np.sum(~np.isnan(data)) - test_size  # and remainder for training

        # Verify everything worked properly
        assert train_size == N - np.isnan(train).sum()
        assert test_size == N - np.isnan(test).sum()

        # Return train set and test set
        return train, test

    def load_data_for_pmf(self, percent_test=0.2):
        # Train-test split
        train, test = self.split_train_test(self.result_summary.values.copy(), percent_test=percent_test)

        # Normalization
        train, test, mu, sigma = self.normalize_data(train, test)

        self.train = train
        self.test = test
        self.mu = mu
        self.sigma = sigma

        return train, test, mu, sigma
    
    def load_data_for_pmf_acc_only(self, percent_test=0.2):
        data = self.result_summary.values.copy()
        metric_info = self.get_metric_info()
        acc_subset = [i for i in range(len(metric_info)) if metric_info[i] == "acc"]
        data = data[:, acc_subset]

        # Train-test split
        train, test = self.split_train_test(data, percent_test=percent_test)

        # Normalization
        train, test, mu, sigma = self.normalize_data(train, test, separate_acc_bart_normalization=False)

        self.train = train
        self.test = test
        self.mu = mu
        self.sigma = sigma

        return train, test, mu, sigma

    def normalize_data(self, train, test, separate_acc_bart_normalization=True):
        mu, sigma = np.nanmean(train, axis=0, keepdims=True), np.nanstd(train, axis=0, keepdims=True)

        # Few samples to estimate
        num_samples = (~np.isnan(train)).sum(axis=0, keepdims=True)
        mu[num_samples < 5] = np.nan
        sigma[num_samples < 5] = np.nan
        
        if separate_acc_bart_normalization:
            metric_info = self.get_metric_info()
            acc_subset = [i for i in range(len(metric_info)) if metric_info[i] == "acc"]
            bart_subset = [i for i in range(len(metric_info)) if metric_info[i] == "BART"]

            # Use global mean and std to fill in missing values
            acc_mean, acc_std = np.nanmean(train[:, acc_subset]), np.nanstd(train[:, acc_subset])
            bart_mean, bart_std = np.nanmean(train[:, bart_subset]), np.nanstd(train[:, bart_subset])
            global_mean = np.array([[acc_mean if metric_info[i] == "acc" else bart_mean for i in range(len(metric_info))]])
            global_std = np.array([[acc_std if metric_info[i] == "acc" else bart_std for i in range(len(metric_info))]])
        else:
            global_mean = np.nanmean(train, axis=(0,1), keepdims=True)
            global_mean = np.tile(global_mean, (1, train.shape[1]))
            global_std = np.nanstd(train, axis=(0,1), keepdims=True)
            global_std = np.tile(global_std, (1, train.shape[1]))
                              
        mu[np.isnan(mu)] = global_mean[np.isnan(mu)]
        sigma[np.isnan(sigma)] = global_std[np.isnan(sigma)]

        assert np.isnan(mu).sum() + np.isnan(sigma).sum() == 0

        train = (train - mu) / sigma
        test = (test - mu) / sigma
        return train, test, mu, sigma
    
    def mask_randomly(self, tosample, train, test, mask_ratio=0.1):
        # Draw random sample of training data to use for testing
        idx_pairs = list(zip(tosample[0], tosample[1]))  # tuples of row/col index pairs

        masked_size = int(len(idx_pairs) * mask_ratio) # normal mask
        indices = np.arange(len(idx_pairs))  # indices of index pairs
        sample = np.random.choice(indices, replace=False, size=masked_size)

        # Transfer random sample from train set to test set.
        for idx in sample:
            idx_pair = idx_pairs[idx]
            test[idx_pair] = train[idx_pair]  # transfer to test set
            train[idx_pair] = np.nan  # remove from train set

        return train, test, masked_size


    def load_data_unbalanced(self, highly_masked=0.1, highly_masked_ratio=0.9, normal_masked_ratio=0.2):
        data = self.result_summary.values.copy()

        # Train-test split
        n, m = data.shape  # # users, # movies
        N = n * m  # # cells in matrix

        # Prepare train/test ndarrays.
        train = data.copy()
        test = np.ones(data.shape) * np.nan

        # Choose a subset of columns and rows to highly mask
        basic_data = data.copy()
        highly_masked_cols = np.random.choice(np.arange(m), int(m * highly_masked), replace=False)
        highly_masked_rows = np.random.choice(np.arange(n), int(n * highly_masked), replace=False)
        basic_data[:, highly_masked_cols] = np.nan
        basic_data[highly_masked_rows, :] = np.nan

        # Normally masked
        train, test, normally_masked_size = self.mask_randomly(np.where(~np.isnan(basic_data)), train, test, mask_ratio=normal_masked_ratio)
        # Highly masked
        train, test, highly_masked_size = self.mask_randomly(np.where(np.isnan(basic_data)), train, test, mask_ratio=highly_masked_ratio)
        
        # Verify everything worked properly
        assert normally_masked_size + highly_masked_size == np.isnan(train).sum() - np.isnan(data).sum()
        assert normally_masked_size + highly_masked_size == N - np.isnan(test).sum()

        # Normalization
        train, test, mu, sigma = self.normalize_data(train, test)

        self.train = train
        self.test = test
        self.mu = mu
        self.sigma = sigma

        return train, test, mu, sigma, highly_masked_cols, highly_masked_rows

    def get_model_profiles(self, model_profile_cont):
        """Stack the profile of every model in ``MODEL_LIST`` into one array.

        A ``ModelProfile`` is built from ``model_profile_cont`` and queried
        once per entry of ``MODEL_LIST`` using the entry's
        ``'store_model_path'`` value as key. The stacked array is also stored
        as ``self.model_profile_mat`` and its shape is printed.

        :param model_profile_cont: Passed unchanged to ``ModelProfile``.
        :return: ``np.array`` of the collected profiles, in ``MODEL_LIST`` order.
        """
        profiler = ModelProfile(model_profile_cont)

        profiles = [
            profiler.get_profile(entry['store_model_path'])
            for entry in MODEL_LIST
        ]
        model_profile_mat = np.array(profiles)

        print("Load model profile: ", model_profile_mat.shape)

        self.model_profile_mat = model_profile_mat
        return model_profile_mat
    
    def get_dataset_profiles(self, dataset_profile_cont):
        """Stack the profile of every dataset in ``ALL_DATASETS`` into one array.

        A ``DatasetProfile`` is built from ``dataset_profile_cont`` and queried
        once per entry of ``ALL_DATASETS``. The stacked array is also stored
        as ``self.dataset_profile_mat`` and its shape is printed.

        :param dataset_profile_cont: Passed unchanged to ``DatasetProfile``.
        :return: ``np.array`` of the collected profiles, in ``ALL_DATASETS`` order.
        """
        profiler = DatasetProfile(dataset_profile_cont)

        profiles = [profiler.get_profile(name) for name in ALL_DATASETS]
        dataset_profile_mat = np.array(profiles)

        print("Load dataset profile: ", dataset_profile_mat.shape)

        self.dataset_profile_mat = dataset_profile_mat
        return dataset_profile_mat

    def split_train_test_3d(self, data, percent_test=0.1):
        n, m, k = data.shape  # # users, # movies, # scores
        N = n * m * k  # cells in matrix

        # Prepare train/test ndarrays.
        train = data.copy()
        test = np.ones_like(data) * np.nan

        # Draw random sample of training data to use for testing.
        tosample = np.where(~np.isnan(data))
        # Draw random sample of training data to use for testing
        idx_pairs = list(zip(tosample[0], tosample[1], tosample[2]))

        test_size = int(len(idx_pairs) * percent_test) # normal mask
        indices = np.arange(len(idx_pairs))  # indices of index pairs
        sample = np.random.choice(indices, replace=False, size=test_size)

        # Transfer random sample from train set to test set.
        for idx in sample:
            idx_pair = idx_pairs[idx]
            test[idx_pair] = train[idx_pair]  # transfer to test set
            train[idx_pair] = np.nan  # remove from train set

        train_size = np.sum(~np.isnan(data)) - test_size  # and remainder for training

        # Verify everything worked properly
        assert train_size == N - np.isnan(train).sum()
        assert test_size == N - np.isnan(test).sum()

        # Return train set and test set
        return train, test
    
    def load_data_for_ptf(self, percent_test=0.2, subset=None):
        """Build the 3-D result tensor, split it into train/test and standardize it.

        The tensor comes from ``get_all_result_mat(subset)`` and is split by
        ``split_train_test_3d``. Mean and standard deviation are computed over
        the first axis of the training tensor, ignoring NaN. Entries backed by
        fewer than 5 observed training values fall back to the statistics
        taken over the first two axes, and entries whose standard deviation is
        exactly zero are scaled by 1.0 so the division cannot produce NaN or
        inf. The results are also stored on the instance as ``train``,
        ``test``, ``mu``, ``sigma``.

        :param float percent_test: Fraction of observed cells held out for test.
        :param subset: Passed unchanged to ``get_all_result_mat``.
        :return: ``(train, test, mu, sigma)``.
        """
        all_result_mat = self.get_all_result_mat(subset)

        # Train-test split
        train, test = self.split_train_test_3d(all_result_mat, percent_test=percent_test)

        # Normalization statistics over the first axis.
        mu = np.nanmean(train, axis=0, keepdims=True)
        sigma = np.nanstd(train, axis=0, keepdims=True)

        # Too few samples to estimate reliably: mark as missing so the global
        # fallback below takes over.
        num_samples = (~np.isnan(train)).sum(axis=0, keepdims=True)
        too_few = num_samples < 5
        mu[too_few] = np.nan
        sigma[too_few] = np.nan

        # Use global mean and std (over the first two axes) to fill in missing values
        global_mean = np.nanmean(train, axis=(0, 1), keepdims=True)
        global_mean = np.tile(global_mean, (1, train.shape[1], 1))
        missing_mu = np.isnan(mu)
        mu[missing_mu] = global_mean[missing_mu]

        global_std = np.nanstd(train, axis=(0, 1), keepdims=True)
        global_std = np.tile(global_std, (1, train.shape[1], 1))
        missing_sigma = np.isnan(sigma)
        sigma[missing_sigma] = global_std[missing_sigma]

        # A constant slice has zero std; dividing by it would turn training
        # values into NaN (0/0) and test values into +/-inf. Use a unit scale
        # instead, which leaves such entries mean-centred.
        sigma[sigma == 0] = 1.0

        train = (train - mu) / sigma
        test = (test - mu) / sigma

        self.train = train
        self.test = test
        self.mu = mu
        self.sigma = sigma

        return train, test, mu, sigma

    def get_all_result_mat(self, subset=None):
        """Load every metric of every model/dataset pair into a 3-D array.

        Reads ``./data/BPMF_LVLM/all_result_summary.json`` and fills an array
        of shape ``(len(MODEL_LIST), len(ALL_DATASETS), 6)``. The last axis
        holds, in order, the ``acc``, ``prec``, ``rec``, ``f1``, ``bart`` and
        ``bert`` entries of each result; a metric that is absent becomes NaN.

        :param subset: Optional index applied to the last axis of the array;
            ``None`` returns all six metrics.
        :return: The (optionally sliced) result array.
        """
        with open("./data/BPMF_LVLM/all_result_summary.json", 'r') as handle:
            summary = json.load(handle)

        # Position along the last axis follows the order of this tuple.
        metric_keys = ("acc", "prec", "rec", "f1", "bart", "bert")

        all_result_mat = np.empty((len(MODEL_LIST), len(ALL_DATASETS), 6))
        for i, model_info in enumerate(MODEL_LIST):
            per_dataset = summary[model_info['store_model_path']]
            for j, dataset in enumerate(ALL_DATASETS):
                perf = per_dataset[dataset]
                for slot, key in enumerate(metric_keys):
                    all_result_mat[i, j, slot] = perf.get(key, np.nan)

        if subset is None:
            return all_result_mat
        return all_result_mat[:, :, subset]
    
