import os
import pandas as pd
import numpy as np
import torch
import random

from torch.utils.data import TensorDataset, DataLoader

from utils.transformations import transform
from utils.initial_design import load_data
from utils.active_learning import get_query


class Variable:
    """A raw variable together with its transformed values and transformation statistics.

    ``trans``, ``mu`` and ``sigma`` stay ``None`` until :meth:`transform` is called.
    """

    def __init__(self, x, transformation_type):
        self.x = x  # raw (untransformed) values
        self.transformation_type = transformation_type  # forwarded as ``method=`` to transform()
        self.trans = self.mu = self.sigma = None

    def transform(self, mu=None, sigma=None):
        """Transform ``self.x`` and store the result in ``self.trans``.

        Without ``mu`` the statistics are fitted on ``self.x`` and stored on the
        instance. With ``mu`` the given ``mu``/``sigma`` are applied and the stored
        statistics are left untouched.
        """
        if mu is not None:
            self.trans, _unused_mu, _unused_sigma = transform(
                self.x, mu, sigma, method=self.transformation_type
            )
            return
        self.trans, self.mu, self.sigma = transform(self.x, method=self.transformation_type)

    def __call__(self):
        """Return the raw, untransformed values."""
        return self.x


class Dataset:
    """Container for an ``(x, y)`` pair plus slots for their transformed versions.

    The ``*_trans``, ``*_mu`` and ``*_sigma`` attributes are initialised to ``None``;
    nothing in this class fills them in.
    """

    def __init__(self, x, y):
        self.x = x
        self.y = y

        self.x_trans, self.x_mu, self.x_sigma = None, None, None
        self.y_trans, self.y_mu, self.y_sigma = None, None, None

    # A staticmethod so that both ``Dataset.apply_data_transformation(x, var)`` and
    # ``dataset.apply_data_transformation(x, var)`` bind the arguments correctly
    # (the original plain function broke on instance calls).
    @staticmethod
    def apply_data_transformation(x, based_on_var):
        """Transform ``x`` using the statistics fitted on ``based_on_var``.

        Args:
            x: values to transform.
            based_on_var: object exposing ``trans``, ``mu``, ``sigma``,
                ``transformation_type`` and a ``transform()`` method (e.g. a
                ``Variable``). If it has not been transformed yet, its statistics
                are fitted first.

        Returns:
            Whatever the module-level ``transform`` returns for ``x``.
        """
        if based_on_var.trans is None:
            based_on_var.transform()

        return transform(x, based_on_var.mu, based_on_var.sigma, method=based_on_var.transformation_type)


class MyDataLoader:
    """Manage the train/test data of an experiment and the views derived from it.

    ``get_initial_data`` sets ``train``/``test`` (and ``search_space``).
    ``transform`` reads those and sets ``train_trans``/``test_trans``.
    ``make_dataloader`` reads both and sets the ``gp_*_loader*`` attributes.
    ``get_candidate_points`` returns the points that may be queried next.
    """

    def __init__(self, args, oracle):
        self.args = args
        self.oracle = oracle
        self.pool_labeled = None
        self.oracle_labels = None
        # Set by get_initial_data(); initialised here so the attribute always exists.
        self.search_space = None

        self.true_mean = None
        self.true_std = None
        self.candidate_points = None

    def get_initial_data(self):
        """Create the initial training and test sets.

        In the ``'population_based'`` setting, both sets are sampled from the oracle.
        Otherwise the search space and labels are loaded from ``args.path_train_data``
        and split by ``sample_initial_dataset``, which also returns the labelled pool
        indices stored in ``self.pool_labeled``.
        """
        if self.args.al_type == 'population_based':
            self.search_space = self.oracle.search_space()
            train_x, train_y = self.oracle.sample_initial_data(n_samples=self.args.initial_samples,
                                                               space_filling_design=self.args.space_filling_design,
                                                               seed=None)
            test_x, test_y = self.oracle.sample_initial_data(n_samples=self.args.test_samples,
                                                             space_filling_design=self.args.space_filling_design,
                                                             seed=None)
        else:
            self.search_space, self.oracle_labels = load_data(self.args, self.args.path_train_data)
            train_x, train_y, test_x, test_y, self.pool_labeled = sample_initial_dataset(
                self.args,
                self.search_space,
                self.oracle_labels,
                path_test_data=self.args.path_test_data,
            )
        self.train = Dataset(train_x, train_y)
        self.test = Dataset(test_x, test_y)

    def compute_true_mean_and_stddev(self):
        """Evaluate the oracle's true mean and standard deviation on the whole search space.

        Only simulators exposing ``mean``/``stddev`` are supported. ``toy_simulator``
        is hard-coded to ``True``, so the ``NotImplementedError`` branch is currently
        unreachable.
        """
        toy_simulator = True
        if toy_simulator:
            # A simulator where we can access the true mean and stddev
            self.true_mean = torch.Tensor(self.oracle.mean(x=self.search_space))
            self.true_std = torch.Tensor(self.oracle.stddev(x=self.search_space))
        else:
            # If we use Mercury or Flitan
            raise NotImplementedError

    def transform(self):
        """Transform the train/test sets with statistics fitted on the training data only.

        The fitted statistics are stored in ``x_mu``/``x_sigma``/``y_mu``/``y_sigma``
        and the results in ``train_trans``/``test_trans``. With a single output
        (``args.outputs == 1``), the last dimension of the targets is squeezed.
        """
        train_x_trans, self.x_mu, self.x_sigma = transform(self.train.x, method=self.args.transformation_x)
        train_y_trans, self.y_mu, self.y_sigma = transform(self.train.y, method=self.args.transformation_y)
        # The test set reuses the training statistics.
        test_x_trans, _, _ = transform(self.test.x, self.x_mu, self.x_sigma, method=self.args.transformation_x)
        test_y_trans, _, _ = transform(self.test.y, self.y_mu, self.y_sigma, method=self.args.transformation_y)

        if self.args.outputs == 1:
            train_y_trans = train_y_trans.squeeze(-1)
            test_y_trans = test_y_trans.squeeze(-1)

        self.train_trans = Dataset(train_x_trans, train_y_trans)
        self.test_trans = Dataset(test_x_trans, test_y_trans)

    def make_dataloader(self):
        """
        Set up data access for the model, in transformed and original space.

        For the batched model types (``deepgp``, ``deepgp_ngd``, ``dssp``), each
        attribute is a shuffled PyTorch ``DataLoader`` yielding the whole set as one
        batch. Otherwise each attribute is a plain ``(x, y)`` tuple.
        """
        if self.args.model_type in ["deepgp", "deepgp_ngd", "dssp"]:
            def full_batch_loader(x, y):
                # A single batch containing every sample.
                return DataLoader(TensorDataset(x, y), batch_size=len(x), shuffle=True)

            self.gp_train_loader = full_batch_loader(self.train_trans.x, self.train_trans.y)
            self.gp_test_loader = full_batch_loader(self.test_trans.x, self.test_trans.y)

            # Original space
            self.gp_train_loader_ori = full_batch_loader(self.train.x, self.train.y)
            self.gp_test_loader_ori = full_batch_loader(self.test.x, self.test.y)
        else:
            self.gp_train_loader = (self.train_trans.x, self.train_trans.y)
            self.gp_test_loader = (self.test_trans.x, self.test_trans.y)
            # Original space
            self.gp_train_loader_ori = (self.train.x, self.train.y)
            self.gp_test_loader_ori = (self.test.x, self.test.y)

    def get_candidate_points(self, max_candidates=10000):
        """Return (and store in ``self.candidate_points``) the points of interest for querying.

        (Pseudo) population-based setting: the search space itself. If it contains
        more than ``max_candidates`` unique rows, a random subset of that many unique
        rows is used instead. ``true_mean_cp``/``true_std_cp`` are then set to the
        true moments of exactly those rows, which requires
        ``compute_true_mean_and_stddev`` to have been called.

        Pool-based setting: the search space without the already labelled indices
        in ``self.pool_labeled``.
        """
        if self.args.al_type in ['pseudo_population_based', 'population_based']:
            candidate_points = self.search_space
            # ``first_idx[i]`` is the row of search_space where unique row i first
            # occurs. It is needed because np.unique sorts and deduplicates, so
            # positions in ``cp_unique`` do not match positions in true_mean/true_std.
            cp_unique, first_idx = np.unique(candidate_points, axis=0, return_index=True)
            if cp_unique.shape[0] > max_candidates:
                # If it is too big, we take a random subsample.
                subset_indices = np.random.choice(np.arange(cp_unique.shape[0]), size=max_candidates, replace=False)
                candidate_points = cp_unique[subset_indices]

                toy_simulator = True
                if toy_simulator:
                    # A simulator where we can access the true mean and stddev.
                    source_indices = first_idx[subset_indices]
                    self.true_mean_cp = self.true_mean[source_indices]
                    self.true_std_cp = self.true_std[source_indices]
            else:
                self.true_std_cp = self.true_std
                self.true_mean_cp = self.true_mean

        else:  # args.al_type == "pool_based"
            # In the pool-based setting, exclude all labelled points from the search space
            # (the oracle applies the same exclusion).
            mask = np.full(len(self.search_space), True, dtype=bool)
            mask[self.pool_labeled] = False
            candidate_points = self.search_space[mask]

        self.candidate_points = candidate_points

        return candidate_points


def sample_initial_dataset(args, search_space, oracle_labels, path_test_data=None, initial_samples=None):
    """
    Create training and test sets given a search_space and the oracle_labels.

    Args:
        args: experiment configuration. Reads ``initial_samples``, ``repeat_sampling``,
            ``test_samples``, ``seed`` and ``space_filling_design``, and is also passed
            on to ``load_data``/``get_query``.
        search_space: indexable collection of inputs, one row per point.
        oracle_labels: labels indexed like ``search_space``.
        path_test_data: optional path to a separate test data file. When ``None``,
            the test set is taken from ``search_space`` at fixed indices. Those
            indices are added to ``pool_labeled`` so they are kept apart from the
            training data.
        initial_samples: overrides ``args.initial_samples`` when given.

    Returns:
        ``(train_x, train_y, test_x, test_y, pool_labeled)``. The x/y entries are
        ``torch.FloatTensor``s; x tensors are reshaped to ``(n, 1)`` when 1-D.
        ``pool_labeled`` is the index list returned by ``get_query``.

    Raises:
        AssertionError: if more test or training samples are requested than the
            data provide.
    """
    if initial_samples is None:
        initial_samples = args.initial_samples
    # Total number of initial training points: initial_samples * repeat_sampling.
    n_initial = initial_samples * args.repeat_sampling

    print("Querying %d / %d initial data points." % (n_initial, len(search_space)))

    pool_labeled = []

    if path_test_data is None:
        # Static test set at evenly spaced indices between 9 and 1019.
        # NOTE(review): the bounds are hard-coded, so search_space must hold at least
        # 1020 points. Confirm this matches every dataset used without a test file.
        idx_test = np.linspace(9, 1019, args.test_samples, dtype=np.int64)
        test_x = torch.FloatTensor(search_space[idx_test])
        test_y = torch.FloatTensor(oracle_labels[idx_test])

        # Mark the test points as labelled so they are not mixed into the training set.
        pool_labeled.extend(idx_test)
    else:
        # Create test set from separate data file
        test_ss, test_ol = load_data(args, path_test_data)
        assert args.test_samples <= test_ss.shape[0], f"Trying to use {args.test_samples} but there is only " \
                                                      f"{test_ss.shape[0]} samples in the data set."
        # int64 indices: an int16 dtype silently overflows for data sets with > 32767 rows.
        idx_test = np.linspace(0, len(test_ss) - 1, args.test_samples, dtype=np.int64)
        test_x = torch.FloatTensor(test_ss[idx_test])
        test_y = torch.FloatTensor(test_ol[idx_test])

    # Same 2-D convention as train_x below.
    if len(test_x.shape) == 1:
        test_x = test_x.view(-1, 1)

    # NOTE(review): a seed of 0 is falsy and is therefore ignored here; confirm intended.
    if args.seed:
        random.seed(args.seed)

    # Create training data
    assert len(search_space) >= n_initial,\
        f"Trying to sample {n_initial} data points, but there are only" \
        f" {len(search_space)} data points in the data set.\nMake sure that the product of initial samples and" \
        f" repeat sampling is below {len(search_space)}"

    # NOTE(review): sample_initial_inputs is not imported in the visible part of this
    # file; confirm it is available at module level.
    new_points = sample_initial_inputs(initial_samples, search_space, method=args.space_filling_design)
    new_y, pool_labeled = get_query(args, new_points, search_space, oracle_labels, pool_labeled,
                                    k_samples=initial_samples,
                                    beta_sampling=-1,
                                    repeat_sampling=args.repeat_sampling,
                                    seed=False)

    # The new training points are taken as the last n_initial entries of pool_labeled
    # (presumably appended by get_query; confirm). Use an explicit start index:
    # ``[-0:]`` would return the whole list, test indices included.
    start = max(len(pool_labeled) - n_initial, 0)
    new_x = search_space[pool_labeled[start:]]
    train_x = torch.FloatTensor(new_x)
    if len(train_x.shape) == 1:
        train_x = train_x.view(-1, 1)
    train_y = torch.FloatTensor(new_y)

    return train_x, train_y, test_x, test_y, pool_labeled

