import pandas as pd
import warnings
import numpy as np
import os
import copy
import scipy.stats

import torch
import torch.nn as nn
from sklearn.neighbors import NearestNeighbors
import time
import scipy.stats
from torch.optim import Adam

class MixRL(object):
    """MixRL class.
    
    Attributes:
        x_train: Features for training.
        y_train: Labels for training.
        x_valid: Features for validation.
        y_valid: Labels for validation.
        knn_option: A list that contains options of knn.
        reg_model: An initialized regression model.
        Mixup_value_net: Mixup value network trained by REINFORCE algorithm.
        batch_size_MixRL: A batch size for training the Mixup value network.
        batch_size_reg_model: A batch size for training the regression model.
        epoch_MixRL: The number of episodes for training the Mixup value network.
        epoch_reg_model: The number of epochs for training the regression model.
        lr_reg_model: A learning rate for the regression model.
        lambd: A mixing weight lambda for Mixup.
        reward_scale_factor: A factor for reward scaling.
        device_option: An option whether to use cpu(0) or gpu(1).
        early_stop_flag: A flag indicating the use of early-stopping or not.
        early_stop_patience: A constant used for early stopping patience if use early stopping.
    """
    
    def __init__(self, x_train, y_train, x_valid, y_valid, knn_option,
                 reg_model, Mixup_value_net, batch_size_MixRL, batch_size_reg_model,
                 epoch_MixRL, epoch_reg_model, lr_reg_model, lambd, reward_scale_factor, device_option, early_stop_flag, early_stop_patience):
        """Initialize MixRL class.

        Args:
            x_train: 2-D array of training features, one row per example.
            y_train: 2-D array of training labels, one row per example.
            x_valid: Validation features.
            y_valid: Validation labels.
            knn_option: List of neighbor counts k the value network chooses from.
            reg_model: An initialized regression model; copies of it are trained.
            Mixup_value_net: Mixup value network h_theta.
            batch_size_MixRL: Batch size for training the Mixup value network.
            batch_size_reg_model: Batch size for training the regression model.
            epoch_MixRL: Number of episodes for training the Mixup value network.
            epoch_reg_model: Number of epochs for training the regression model.
            lr_reg_model: Adam learning rate for the regression model.
            lambd: Mixing weight lambda used for Mixup.
            reward_scale_factor: Multiplier applied to the REINFORCE reward.
            device_option: Truthy to use CUDA, falsy to use the CPU.
            early_stop_flag: Whether the regression model uses early stopping.
            early_stop_patience: Patience in epochs, used only with early stopping.
        """
        self.x_train = x_train
        self.y_train = y_train
        self.x_valid = x_valid
        self.y_valid = y_valid
        self.data_dim = np.shape(x_train)[1]
        self.label_dim = np.shape(y_train)[1]
        self.knn_option = knn_option
        # Every knn option appears twice: once with feature (x) distance and
        # once with label (y) distance.
        self.knn_dim = len(self.knn_option) * 2
        self.reg_model = reg_model
        self.Mixup_value_net = Mixup_value_net
        self.batch_size_MixRL = batch_size_MixRL
        self.batch_size_reg_model = batch_size_reg_model
        self.epoch_MixRL = epoch_MixRL
        self.epoch_reg_model = epoch_reg_model
        self.lr_reg_model = lr_reg_model
        self.lambd = lambd
        self.reward_scale_factor = reward_scale_factor
        self.device = torch.device('cuda' if device_option else 'cpu')

        self.es_flag = early_stop_flag
        self.es_p = early_stop_patience if self.es_flag else 0

        # Per-column min-max normalization. A constant column has zero range;
        # dividing by 1 there maps it to 0 instead of producing NaN/inf.
        x_min = np.min(self.x_train, axis=0)
        x_range = np.max(self.x_train, axis=0) - x_min
        y_min = np.min(self.y_train, axis=0)
        y_range = np.max(self.y_train, axis=0) - y_min
        self.x_train_norm = (self.x_train - x_min) / np.where(x_range == 0, 1, x_range)
        self.y_train_norm = (self.y_train - y_min) / np.where(y_range == 0, 1, y_range)
        # NOTE(review): the value-network inputs built in train_MixRL and
        # get_h_values combine raw x_train with y_train_norm, and the regression
        # model trains on raw x_train / y_train -- confirm that x_train_norm is
        # intentionally unused there.

        # One extra neighbor because a training point queried against its own
        # training set is normally returned as its own nearest neighbor.
        n_neighbors = max(self.knn_option) + 1
        self.knn_model_X = NearestNeighbors(n_neighbors=n_neighbors)
        self.knn_model_X.fit(self.x_train)
        self.knn_model_Y = NearestNeighbors(n_neighbors=n_neighbors)
        self.knn_model_Y.fit(self.y_train)
        
    def train_reg_model(self, reg_model, mixup_flag, mixup_idx):
        """Train a regression model with Adam on an MSE loss.

        Args:
            reg_model: A regression model to be trained (updated in place).
            mixup_flag: 0 to train only on the existing training data, or 1 to
              train on the existing training data together with mixed data.
            mixup_idx: Pairs of indices of data to mix determined by MixRL, as an
              (m, 2) array-like; may be empty and is ignored when mixup_flag is 0.
              For example, if [1, 2] is in mixup_idx, (x1, y1) and (x2, y2) are
              mixed with weight ``self.lambd``.

        Returns:
            The trained regression model. With early stopping enabled, this is a
            deep copy of the model from the epoch with the lowest validation MSE
            (or ``reg_model`` itself if no epoch produced a comparable MSE).

        Note:
            Every epoch runs ``len(x_train) // batch_size_reg_model`` mini-batches
            in both modes. With Mixup, only that many rows of the shuffled
            (mixed + original) pair list are used per epoch, and a trailing
            partial batch is always dropped.
        """
        batch_size = self.batch_size_reg_model
        optimizer = Adam(reg_model.parameters(), lr=self.lr_reg_model)
        mse_loss = nn.MSELoss()
        x_train = self.x_train
        y_train = self.y_train
        early_stop_flag = self.es_flag
        early_stop_patience = self.es_p
        early_stop_count = 0
        best_mse = float('inf')
        best_reg_model = None

        def fit_batch(x_batch, y_batch):
            """Run one optimizer step on a NumPy mini-batch."""
            x_input = torch.tensor(x_batch).float().to(self.device)
            y_input = torch.tensor(y_batch).float().to(self.device)
            loss = mse_loss(reg_model(x_input), y_input)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        iteration = len(x_train) // batch_size
        if mixup_flag:
            # Pair every example with itself so that the original data is also
            # trained on (mixing a point with itself reproduces it).
            identity_idx = np.repeat(np.arange(len(x_train))[:, None], 2, axis=1)
            # Normalize to an (m, 2) int array so an empty list stacks cleanly.
            mixup_idx = np.asarray(mixup_idx, dtype=int).reshape(-1, 2)
            mixup_idx = np.vstack((mixup_idx, identity_idx))

        x_valid_torch = torch.tensor(self.x_valid).float().to(self.device)

        for _ in range(self.epoch_reg_model):
            reg_model.train()
            if mixup_flag == 0:
                shuffle_idx = np.random.permutation(np.arange(len(x_train)))
                x_train_input = x_train[shuffle_idx]
                y_train_input = y_train[shuffle_idx]
                for idx in range(iteration):
                    batch = slice(idx * batch_size, (idx + 1) * batch_size)
                    fit_batch(x_train_input[batch, :], y_train_input[batch, :])
            else:
                shuffle_mixup_idx = np.random.permutation(mixup_idx)
                for idx in range(iteration):
                    pairs = shuffle_mixup_idx[idx * batch_size:(idx + 1) * batch_size]
                    idx_1, idx_2 = pairs[:, 0], pairs[:, 1]
                    mixup_X = x_train[idx_1, :] * self.lambd + x_train[idx_2, :] * (1 - self.lambd)
                    mixup_Y = y_train[idx_1, :] * self.lambd + y_train[idx_2, :] * (1 - self.lambd)
                    fit_batch(mixup_X, mixup_Y)

            # Leave the model in eval mode after every epoch, as callers evaluate
            # the returned model directly.
            reg_model.eval()
            if not early_stop_flag:
                # The validation MSE is only needed for early stopping.
                continue
            with torch.no_grad():
                y_valid_pred = reg_model(x_valid_torch).cpu().numpy()
            mse = np.square(np.subtract(y_valid_pred, self.y_valid)).mean()
            if mse <= best_mse:
                early_stop_count = 0
                best_mse = mse
                best_reg_model = copy.deepcopy(reg_model)
            else:
                early_stop_count += 1
            if early_stop_count >= early_stop_patience:
                break

        if early_stop_flag and best_reg_model is not None:
            return best_reg_model
        return reg_model
        
    def cal_measures(self, model, x_test, y_test):
        """Compute R squared and MSE (mean squared error) of a regression model.

        R squared is the squared Pearson correlation between predictions and
        targets, computed per label column and averaged over the columns.
        SciPy yields NaN for a column whose predictions or targets are constant.

        Args:
            model: The regression model; it is called on a float tensor placed
              on ``self.device`` and used in its current train/eval mode.
            x_test: 2-D array of features of the dataset to evaluate.
            y_test: 2-D array of labels of the dataset to evaluate.

        Returns:
            A tuple ``(r_s, mse)``: the mean squared-Pearson R squared and the
            MSE over all entries (not the RMSE).
        """
        x_test_torch = torch.tensor(x_test).float().to(self.device)
        with torch.no_grad():
            y_test_pred = model(x_test_torch).cpu().numpy()
        r_s = np.mean([
            scipy.stats.pearsonr(y_test_pred[:, j], y_test[:, j])[0] ** 2
            for j in range(self.label_dim)
        ])
        mse = np.square(np.subtract(y_test_pred, y_test)).mean()
        return r_s, mse
        
    def mixup_knn(self, data_idx_list, k_idx_list):
        """Build Mixup pairs between examples and their k nearest neighbors.

        Each entry of ``k_idx_list`` encodes a distance type and a k. Indices
        ``0 .. knn_dim//2 - 1`` select ``knn_option[k_idx]`` with feature (x)
        distance. Indices ``knn_dim//2 .. knn_dim - 1`` select
        ``knn_option[k_idx - knn_dim//2]`` with label (y) distance.

        Args:
            data_idx_list: Indices of the training examples to mix.
            k_idx_list: The encoded knn option for each entry of data_idx_list.

        Example:
            Suppose the list of knn options is [1, 2, ..., 32], so knn_dim is 64.
            Mixing example 10 with its 7 nearest neighbors by feature distance
            and example 2 with its 8 nearest neighbors by label distance means
            data_idx_list = [10, 2] and k_idx_list = [6, 39].

        Returns:
            An integer array of shape (m, 2) of index pairs to mix, or ``[]``
            when there is nothing to mix. For example, if [1, 2] is in the
            result, (x1, y1) and (x2, y2) will be mixed when training the
            regression model.
        """
        half = self.knn_dim // 2
        mixup_idx = []
        for data_idx, k_idx in zip(data_idx_list, k_idx_list):
            if k_idx // half == 0:
                knn_model, query = self.knn_model_X, self.x_train[data_idx, :]
            else:
                knn_model, query = self.knn_model_Y, self.y_train[data_idx, :]
            k = self.knn_option[int(k_idx % half)]
            neigh_list = knn_model.kneighbors(query.reshape(1, -1), n_neighbors=k + 1)[1].reshape(-1)
            # The query point is normally its own nearest neighbor. Drop it by
            # index rather than by position, because duplicated rows can push it
            # out of first place, then keep at most k genuine neighbors.
            neighbors = [n for n in neigh_list if n != data_idx][:k]
            mixup_idx.extend((data_idx, n) for n in neighbors)

        if not mixup_idx:
            return []
        return np.array(mixup_idx)
    
    def get_valid_perf(self, mixup_idx):
        """Score Mixup-augmented training on the validation set.

        A deep copy of ``self.reg_model`` is trained on the existing training
        data plus the mixed pairs, so the template model is never modified.

        Args:
            mixup_idx: Pairs of indices of data to mix.

        Returns:
            A tuple ``(r_s, mse)`` of validation R squared and MSE.
        """
        fresh_model = copy.deepcopy(self.reg_model)
        fitted_model = self.train_reg_model(fresh_model, 1, mixup_idx)
        return self.cal_measures(fitted_model, self.x_valid, self.y_valid)
    
    def get_no_aug_perf(self):
        """Score training without augmentation on the validation set.

        A deep copy of ``self.reg_model`` is trained on the existing training
        data only (no Mixup), leaving the template model untouched.

        Returns:
            A tuple ``(r_s, mse)`` of validation R squared and MSE.
        """
        baseline_model = self.train_reg_model(copy.deepcopy(self.reg_model), 0, [])
        return self.cal_measures(baseline_model, self.x_valid, self.y_valid)
        
    def to_onehot(self, integers, dim):
        """Convert integers to one-hot row vectors.

        Args:
            integers: A sequence of integer indices in ``[0, dim)``.
            dim: Dimension of each one-hot vector.

        Returns:
            A float array of shape ``(len(integers), dim)`` whose row i has a 1
            at column ``integers[i]`` and 0 elsewhere. An empty input gives an
            array of shape ``(0, dim)``.
        """
        indices = np.asarray(integers, dtype=int).reshape(-1)
        # Fill all rows in one vectorized assignment instead of repeated vstack.
        onehot = np.zeros((len(indices), dim))
        onehot[np.arange(len(indices)), indices] = 1
        return onehot

    def train_MixRL(self):
        """Train the Mixup value network h_theta with the REINFORCE algorithm.

        Every episode:
          1. assigns each training example one encoded knn option, cycling per
             example through all ``knn_dim`` options without repetition;
          2. feeds (x, normalized y, one-hot knn option) to the value network to
             get a selection probability for each example;
          3. keeps each example with its probability (Bernoulli sampling) and
             accumulates the log-probability of all sampling decisions;
          4. trains a fresh copy of the regression model on the training data
             plus the Mixup pairs of the kept examples; and
          5. passes ``(reward, log_prob, prob_mean)`` to the value network's
             ``put_data`` and then calls its ``train_net``.

        The reward is ``(Loss - Base) * reward_scale_factor``. Here ``Loss`` is
        the validation MSE minus that of a model trained without augmentation,
        and ``Base`` is an exponential moving average of past ``Loss`` values.
        Episodes in which no example is kept are skipped entirely.

        Note:
            One knn option is drawn per row of ``counter`` and the resulting
            one-hot matrix is concatenated column-wise with ``x_train``, so
            ``batch_size_MixRL`` must equal ``len(x_train)``.
        """
        W = 20  # Moving-average window: Base <- (W-1)/W * Base + Loss/W
        Base = 0
        batch_size = self.batch_size_MixRL
        _, no_aug_mse = self.get_no_aug_perf()
        epsilon = 1e-8  # Keeps log() finite when a probability is exactly 0 or 1
        # counter[i, k] marks knn option k as already used for example i in the
        # current cycle; a row is reset once every option has been used.
        counter = np.zeros((batch_size, self.knn_dim))
        # The (x, normalized y) part of the value-network input is identical in
        # every episode, so build it once.
        static_input = np.hstack((self.x_train, self.y_train_norm))
        for _ in range(self.epoch_MixRL):

            # Sample one not-yet-used knn option per example
            k_idx_list = []
            for i in range(batch_size):
                sampled_idx = np.random.choice(np.where(counter[i] == 0)[0], 1)[0]
                k_idx_list.append(sampled_idx)
                counter[i, sampled_idx] += 1
                if counter[i, :].sum() == self.knn_dim:
                    counter[i, :] = 0
            k_idx_list = np.array(k_idx_list)
            k_onehot = self.to_onehot(k_idx_list, self.knn_dim)
            MixRL_input = np.hstack((static_input, k_onehot))
            MixRL_input_tensor = torch.tensor(MixRL_input).float().to(self.device)

            # Selection probability h_theta(x, y, k) from the Mixup value network
            selection_prob = self.Mixup_value_net(MixRL_input_tensor)

            # Bernoulli-sample each example and accumulate the log-probability of
            # the sampled decisions for the REINFORCE update.
            log_prob = 0
            sampling_list = []
            prob_mean = torch.mean(selection_prob, dim=0)
            for i in range(len(selection_prob)):
                prob = selection_prob[i, :]
                if torch.rand(1) < prob.item():
                    log_prob += torch.log(prob + epsilon)
                    sampling_list.append(i)
                else:
                    log_prob += torch.log(1 - prob + epsilon)
            if not sampling_list:
                continue
            sampling_list = np.array(sampling_list)

            # Mixup pairs between each kept example and its k nearest neighbors
            mixup_idx = self.mixup_knn(sampling_list, k_idx_list[sampling_list])

            # Validation MSE relative to training without augmentation; the
            # subtraction is meant to improve the stability of learning.
            _, Loss = self.get_valid_perf(mixup_idx)
            Loss = Loss - no_aug_mse
            reward = (Loss - Base) * self.reward_scale_factor
            self.Mixup_value_net.put_data((reward, log_prob, prob_mean))

            # Update Mixup value network
            self.Mixup_value_net.train_net()

            # Update the moving-average baseline
            Base = (W - 1) / W * Base + Loss / W

    def get_h_values(self):
        """Compute h_theta(x, y, k) for every (training example, knn option) pair.

        Returns:
            A 1-D array with the value network's outputs flattened in row order.
            Input row ``i * knn_dim + k`` pairs training example ``i`` with
            encoded knn option ``k``, which is the layout that
            ``mixup_by_threshold`` decodes. The array has
            ``len(x_train) * knn_dim`` entries, assuming the network emits one
            value per input row.
        """
        n_examples = len(self.x_train)
        # (x, normalized y) per example, each row repeated once per knn option.
        features = np.hstack((self.x_train, self.y_train_norm))
        features = np.repeat(features, self.knn_dim, axis=0)
        # The matching one-hot knn options: an identity block per example.
        k_onehot = np.tile(np.eye(self.knn_dim), (n_examples, 1))
        inputs = np.hstack((features, k_onehot))

        inputs = torch.tensor(inputs).float().to(self.device)
        with torch.no_grad():
            h_values = self.Mixup_value_net(inputs)
        return h_values.cpu().numpy().reshape(-1)
        
    def mixup_by_threshold(self, sorted_value_idx_list, threshold):
        """Build Mixup pairs from the top-ranked (x, y, k) triples according to h_theta.

        The first ``threshold`` entries of ``sorted_value_idx_list`` are
        examined. Each entry is a flat index laid out as
        ``data_idx * knn_dim + k_idx``. For every example, only the first
        feature-distance option and the first label-distance option
        encountered are kept; later options of the same distance type for that
        example are skipped.

        Args:
            sorted_value_idx_list: Flat (x, y, k) indices sorted by h value in descending order.
            threshold: The number of top-ranked (x, y, k) entries to examine.

        Example:
              If threshold = 3 and
              (x1, y1, 1nn) = 0.6
              (x2, y2, 4nn) = 0.58
              (x2, y2, 1nn) = 0.52
              --------------- threshold
              (x3, y3, 2nn) = 0.5

              and all three options use the same distance type, then (x1, y1)
              is mixed with the 1nn of (x1, y1), and (x2, y2) with the 4nn of
              (x2, y2).

        Returns:
            Pairs of indices of data to mix, as produced by ``mixup_knn``.
        """
        half = self.knn_dim // 2
        n_examples = len(self.x_train)
        # One "already has an option" flag array per distance type (0: x, 1: y).
        taken = {
            0: np.zeros(n_examples, dtype=bool),
            1: np.zeros(n_examples, dtype=bool),
        }
        chosen_data = []
        chosen_k = []
        for rank in range(threshold):
            data_idx, k_idx = divmod(sorted_value_idx_list[rank], self.knn_dim)
            kind = 0 if k_idx // half == 0 else 1
            if taken[kind][data_idx]:
                continue
            taken[kind][data_idx] = True
            chosen_data.append(data_idx)
            chosen_k.append(k_idx)

        return self.mixup_knn(np.array(chosen_data), np.array(chosen_k))
        
    def find_best_threshold(self, sorted_value_idx_list, init_threshold, search_step, num_iter=3):
        """Grid-search the threshold that minimizes the validation MSE.

        The thresholds tried are ``init_threshold``,
        ``init_threshold + search_step``, and so on up to
        ``len(x_train) * knn_dim``. For each threshold, the Mixup pairs are
        built once with ``mixup_by_threshold``. A fresh copy of the regression
        model is then trained and evaluated ``num_iter`` times, and the mean
        validation MSE is the threshold's score. On ties, the threshold
        evaluated first wins.

        Args:
            sorted_value_idx_list: A list that contains the index of h values sorted in descending order.
            init_threshold: An initial value of the threshold.
            search_step: A step size for searching.
            num_iter: Number of independent training runs averaged per
              threshold (default 3).

        Returns:
            The threshold with the lowest mean validation MSE. Returns
            ``len(x_train) * knn_dim`` if no threshold was evaluated or every
            score was NaN.
        """
        full_size = len(self.x_train) * self.knn_dim
        best_threshold = full_size
        # Start from +inf so a valid score is accepted even when it exceeds any
        # fixed sentinel value.
        best_mse = float('inf')
        for threshold in range(init_threshold, full_size + 1, search_step):
            mixup_idx = self.mixup_by_threshold(sorted_value_idx_list, threshold)
            mean_mse = np.mean([self.get_valid_perf(mixup_idx)[1] for _ in range(num_iter)])
            if mean_mse < best_mse:
                best_mse = mean_mse
                best_threshold = threshold
        return best_threshold