import sys
import pandas as pd
import numpy as np
import os
import warnings
import random
from scipy.stats import pearsonr, spearmanr
from utils import get_random_actions, get_optimal_actions, sigmoid, run_experiment, get_starting_data
from kernels import L1Norm
from models.svr import SVM
from models.kr import KR
from models.gp import GP

warnings.simplefilter(action='ignore', category=FutureWarning)

# Command line arguments (positional):
#   1. MODEL            model to train: svr, kr or gp
#   2. EXPERIMENT_NAME  informative name; output is written to ./data/<EXPERIMENT_NAME>
#   3. N_RUNS           number of runs per corruption level
#   4. N_CORRUPTED      optional comma-separated corruption amounts, e.g. "0, 5, 10"
#                       (default: 0, 5, ..., 45)
if len(sys.argv) < 4:
    sys.exit('usage: ' + sys.argv[0] + ' MODEL EXPERIMENT_NAME N_RUNS ["N_CORRUPTED, ..."]')

# Define parameters for this run
MIN_REWARD = -3
MAX_REWARD = 3
# Possible models: svr, kr, gp
MODEL = sys.argv[1]
# Fail fast: an unknown model would otherwise only surface deep inside the first run.
if MODEL not in ('svr', 'kr', 'gp'):
    sys.exit('MODEL must be one of: svr, kr, gp (got ' + repr(MODEL) + ')')
NUM_POSSIBLE_ACTIONS = 50
EXPERIMENT_NAME = sys.argv[2]  # pick informative name; output will be written to this folder
N_RUNS = int(sys.argv[3])
# Define the amount of corruption to iterate over
n_corrupted = list(range(0, 50, 5))

# Optional command line argument overriding the corruption amounts. Splitting on "," (int()
# ignores surrounding whitespace) accepts both "0,5,10" and "0, 5, 10"; empty items are skipped.
if len(sys.argv) > 4:
    n_corrupted = [int(value) for value in sys.argv[4].split(",") if value.strip()]

# Kernel giving the "distance" between two morality scores; similarities are computed as
# (MAX_REWARD - MIN_REWARD) - kernel_fn(a, b).
kernel_fn = L1Norm  # default kernel function

# List of metrics to track
metrics = ['mean_reward', 'unique_actions_taken', 'non_optimal_actions_taken', 'negative_actions_taken',
           'iterations_to_convergence']

# makedirs also creates ./data when it does not exist yet (os.mkdir would raise).
base_path = './data/' + EXPERIMENT_NAME
os.makedirs(base_path, exist_ok=True)

# Kernel-based similarity between two collections of actions.
def get_similarity_matrix(x, y, actions, scores):
    """Return the similarity between every element of ``x`` and every element of ``y``.

    Each element of ``x`` / ``y`` is an action; it is located in ``actions`` and the score
    at that position in ``scores`` is used. Entry ``[r, c]`` of the result is
    ``(MAX_REWARD - MIN_REWARD) - kernel_fn(score_of(x[r]), score_of(y[c]))``.

    Note: the inputs x, y are the things to compute similarity between, NOT features and
    targets! ``scores`` may be the real or a corrupted set of morality scores.

    Args:
        x: iterable of actions (rows of the result).
        y: iterable of actions (columns of the result); may differ in length from ``x``.
        actions: array of all actions, used to map an action to its score index.
        scores: per-action scores aligned with ``actions``.

    Returns:
        ndarray of shape ``(len(x), len(y))``.

    Raises:
        IndexError: if an element of ``x`` or ``y`` does not occur in ``actions``.
    """
    max_similarity = MAX_REWARD - MIN_REWARD

    # Resolve every action to its index once, instead of once per matrix cell.
    row_indices = [np.where(actions == value)[0][0] for value in x]
    col_indices = [np.where(actions == value)[0][0] for value in y]

    # Size by both inputs: the old (len(x), len(x)) allocation broke whenever len(y) != len(x).
    similarity = np.zeros((len(row_indices), len(col_indices)))
    for row, action1 in enumerate(row_indices):
        for col, action2 in enumerate(col_indices):
            similarity[row][col] = max_similarity - kernel_fn(scores[action1], scores[action2])

    return similarity


def generate_random_environment():
    """Draw a fresh random environment.

    Returns:
        Tuple ``(actions, real_morality_scores, real_optimal_actions)``: the actions and
        scores from ``get_random_actions`` and the optimal actions that
        ``get_optimal_actions`` derives from them.
    """
    all_actions, scores = get_random_actions(NUM_POSSIBLE_ACTIONS, MAX_REWARD, MIN_REWARD)
    best_actions = get_optimal_actions(NUM_POSSIBLE_ACTIONS, all_actions, scores)
    return all_actions, scores, best_actions


def get_morality_corruption_setup(num_corrupted):
    """Build one simulation whose model is given a kernel built from corrupted morality scores.

    A fresh random environment is drawn, then ``num_corrupted`` randomly chosen score entries
    are replaced by a different value. The model (selected by the module-level ``MODEL``)
    uses the corrupted scores. The real scores and optimal actions are returned for evaluation.

    Args:
        num_corrupted: number of score entries to corrupt. Values larger than the number of
            actions corrupt every entry; values <= 0 corrupt nothing.

    Returns:
        Tuple ``(actions, starting_action_data, starting_score_data, real_morality_scores,
        real_optimal_actions, model, corrupted_morality_scores)``.

    Raises:
        ValueError: if ``MODEL`` is not one of 'kr', 'gp', 'svr'.
    """
    actions, real_morality_scores, real_optimal_actions = generate_random_environment()

    # Corrupt a random subset of the scores; the real scores are left untouched.
    corrupted_morality_scores = real_morality_scores.copy()
    for count, i in enumerate(np.random.permutation(len(real_morality_scores))):
        if count >= num_corrupted:
            break
        original_score = real_morality_scores[i]
        # Choose a new random score that is different from the original.
        # NOTE(review): the upper range excludes MAX_REWARD while MIN_REWARD is reachable;
        # confirm this matches the score range produced by get_random_actions.
        candidates = list(range(MIN_REWARD, original_score)) + list(range(original_score + 1, MAX_REWARD))
        corrupted_morality_scores[i] = random.choice(candidates)

    starting_action_data, starting_score_data = get_starting_data(actions)

    # Instantiate the model.
    if MODEL == 'kr':
        model = KR(actions, get_similarity_matrix(actions, actions, actions, corrupted_morality_scores))
    elif MODEL == 'gp':
        model = GP(actions, get_similarity_matrix(actions, actions, actions, corrupted_morality_scores),
                   sigma=1.0)
    elif MODEL == 'svr':
        max_similarity = MAX_REWARD - MIN_REWARD

        def svr_kernel(x, y):
            """Similarity of every element of x to every element of y using the corrupted scores.

            x and y may differ in length (a kernel callable is typically evaluated on new points
            against training points, e.g. by scikit-learn at predict time — confirm for
            models.svr.SVM). Returns an array of shape (len(x), len(y)).
            """
            row_indices = [np.where(actions == value)[0][0] for value in x]
            col_indices = [np.where(actions == value)[0][0] for value in y]
            similarity = np.zeros((len(row_indices), len(col_indices)))
            for row, action1 in enumerate(row_indices):
                for col, action2 in enumerate(col_indices):
                    similarity[row][col] = max_similarity - kernel_fn(
                        corrupted_morality_scores[action1], corrupted_morality_scores[action2])
            return similarity

        model = SVM(actions, svr_kernel, starting_action_data, starting_score_data)
    else:
        # Previously this fell through and failed with UnboundLocalError on `model`.
        raise ValueError("Unknown MODEL " + repr(MODEL) + "; expected one of 'kr', 'gp', 'svr'")

    return actions, starting_action_data, starting_score_data, real_morality_scores, real_optimal_actions, model, \
           corrupted_morality_scores


# Initialize some train and test actions.
train_actions = [22, 41, 28, 8, 20, 10, 44, 1, 40, 11, 48, 21, 0, 6, 31, 19, 9, 23, 5, 49, 2, 33, 46, 13, 15]
test_actions = [29, 47, 39, 42, 16, 38, 3, 7, 17, 12, 45, 25, 37, 35, 24, 18, 43, 36, 26, 30, 27, 34, 4, 14, 32]

# Columns written to run_data.csv in addition to `metrics`.
additional_cols = ['spearman', 'spearman_train_only', 'spearman_train_test', 'pearson', 'nc']


def _kernel_correlations(real_similarity, corrupted_similarity):
    """Correlate a real similarity matrix with its corrupted counterpart.

    Only strictly upper-triangular entries (i < j) are used, because the similarity matrices
    are symmetric; the diagonal is excluded. Row/column positions are compared against
    ``train_actions`` / ``test_actions``, i.e. action values are assumed to equal their
    matrix index (same assumption as before — TODO confirm against get_random_actions).

    Returns:
        dict with keys 'pearson', 'spearman' (all pairs), 'spearman_train_only' (both actions
        in train) and 'spearman_train_test' (one action in train, the other in test).
    """
    train_set = set(train_actions)
    test_set = set(test_actions)
    real_all, corrupted_all = [], []
    real_train_only, corrupted_train_only = [], []
    real_train_test, corrupted_train_test = [], []

    n_points = real_similarity.shape[0]
    # Start at row 0: starting at 1 silently dropped every pair involving action 0.
    for i in range(n_points):
        for j in range(i + 1, n_points):
            real_value = real_similarity[i, j]
            corrupted_value = corrupted_similarity[i, j]
            real_all.append(real_value)
            corrupted_all.append(corrupted_value)

            if i in train_set and j in train_set:
                real_train_only.append(real_value)
                corrupted_train_only.append(corrupted_value)

            # In the upper triangle a train/test pair can appear with either index first.
            if (i in train_set and j in test_set) or (i in test_set and j in train_set):
                real_train_test.append(real_value)
                corrupted_train_test.append(corrupted_value)

    return {
        'pearson': pearsonr(real_all, corrupted_all)[0],
        'spearman': spearmanr(real_all, corrupted_all)[0],
        'spearman_train_only': spearmanr(real_train_only, corrupted_train_only)[0],
        'spearman_train_test': spearmanr(real_train_test, corrupted_train_test)[0],
    }


def _run_phase(phase, nc, run, actions, action_data, score_data, morality_scores,
               optimal_actions, model, correlations):
    """Run one phase of an experiment and append its metrics to <base>/<MODEL>/<phase>/run_data.csv.

    Args:
        phase: 'train' (actions limited to train_actions, run_experiment called with
            test=False) or 'test' (actions limited to test_actions, test=True).
        nc: number of corrupted scores for this run (stored in the 'nc' column).
        run: run index, used in the per-run history filename.
        actions, action_data, score_data, morality_scores, optimal_actions, model:
            passed through to run_experiment.
        correlations: dict from _kernel_correlations, added to the metrics row.
    """
    is_test = phase == 'test'
    path = base_path + "/" + MODEL + "/" + phase + "/"
    history_dir = path + str(nc) + "_corrupted/"
    os.makedirs(history_dir, exist_ok=True)

    single_run_metrics = run_experiment(actions, action_data, score_data, morality_scores,
                                        optimal_actions, model, n_allowed_actions=10,
                                        immoral_threshold=50, max_iterations=1000,
                                        allowed_actions_list=test_actions if is_test else train_actions,
                                        history_filename=history_dir + "run_" + str(run) + ".csv",
                                        spearman=correlations['spearman'],
                                        test=is_test)

    # Add kernel correlations and the amount of corruption to this run's metrics.
    row_values = {**single_run_metrics, **correlations, 'nc': nc}
    # Put each value in a list so the dict parses to a single DataFrame row.
    trial = pd.DataFrame({metric: [value] for metric, value in row_values.items()},
                         columns=metrics + additional_cols)

    # Append one row instead of re-reading and rewriting the whole CSV on every run.
    csv_path = path + 'run_data.csv'
    trial.to_csv(csv_path, mode='a', header=not os.path.exists(csv_path), index=False)


for nc in n_corrupted:
    print('Running morality score corruption with', nc, 'corrupted elements')
    for run in range(N_RUNS):
        actions, action_data, score_data, morality_scores, optimal_actions, model, corrupted_morality_scores = \
            get_morality_corruption_setup(nc)

        # Measure how much the corruption distorts the kernel: corrupted vs real similarity.
        corrupted_kernel = get_similarity_matrix(actions, actions, actions, corrupted_morality_scores)
        real_kernel = get_similarity_matrix(actions, actions, actions, morality_scores)
        correlations = _kernel_correlations(real_kernel, corrupted_kernel)

        # Phase 1: Training. Per the original design, the model is updated at each timestep.
        _run_phase('train', nc, run, actions, action_data, score_data, morality_scores,
                   optimal_actions, model, correlations)

        # Phase 2: Testing. The same (trained) model is only evaluated, not updated.
        _run_phase('test', nc, run, actions, action_data, score_data, morality_scores,
                   optimal_actions, model, correlations)
