import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, ExtraTreesClassifier

from sklearn.cluster import KMeans
from sklearn_extra.cluster import KMedoids


from sklearn.metrics import accuracy_score
from joblib import dump, load
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import re

from arguments.args import get_args
import argparse
import os

# Configuration from the project's get_args(); evaluated once, at import time.
args = get_args()
# max_depth passed to train_decision_tree() for the level-0 models built in main().
max_depth_each_level = 1
# Number of k-medoids clusters that Q vectors are grouped into (the class labels).
num_clusters = 4
# Number of refinement iterations run by main() after the level-0 models.
max_depth_RGMDT = args.max_depth_RGMDT


def cluster_with_cosine_distance(data, n_clusters):
    """Group ``data`` into ``n_clusters`` clusters with k-medoids under cosine distance.

    Medoid initialisation is seeded with ``random_state=0``.

    Returns:
        ``(labels, model)``: the per-sample cluster labels (``model.labels_``) and the
        fitted ``KMedoids`` instance.
    """
    medoids = KMedoids(n_clusters=n_clusters, metric='cosine', random_state=0)
    medoids.fit(data)
    return medoids.labels_, medoids

def save_model(model, filename):
    """Persist ``model`` to the path ``filename`` using joblib's ``dump``."""
    dump(value=model, filename=filename)


def load_model(filename):
    """Return the object previously stored at ``filename`` with joblib's ``dump``."""
    return load(filename=filename)


def evaluate_model(model, X_test, y_test):
    """Score ``model`` on ``(X_test, y_test)``.

    Returns:
        ``(accuracy, report)``: ``accuracy_score`` of the model's predictions and the
        text ``classification_report`` (``zero_division=0`` reports 0 rather than
        warning when a precision/recall denominator is zero).
    """
    y_pred = model.predict(X_test)
    return (
        accuracy_score(y_test, y_pred),
        classification_report(y_test, y_pred, zero_division=0),
    )


def load_data(file_path):
    """Read the CSV at ``file_path`` and return its data rows as a NumPy array.

    The header line is consumed by ``pd.read_csv`` as column names, so it is not part
    of the returned array.
    """
    frame = pd.read_csv(file_path)
    return np.array(frame)


def convert_label_to_dict(table, num_actions=4):
    """Map each row's key to the 1-based index of its largest action value.

    Args:
        table: iterable of rows laid out as ``[key, v_1, ..., v_num_actions, ...]``;
            any columns after the action values are ignored.
        num_actions: number of action-value columns following the key. The default
            of 4 matches the previously hard-coded ``row[1:5]`` slice.

    Returns:
        dict mapping ``row[0]`` to ``argmax(row[1:1 + num_actions]) + 1``. If a key
        appears more than once, the last row wins.
    """
    return {row[0]: np.argmax(row[1:1 + num_actions]) + 1 for row in table}


def convert_string_obs_to_array(str_array):
    """Parse every number embedded in an observation string into a float array.

    Any bracket style and comma or whitespace separators are accepted, e.g.
    ``"(0.5, -1.25)"``, ``"[0.5 -1.25]"`` or ``"[1e-3, 2]"``. Signs and scientific
    notation are preserved. Digits glued to a word (such as the ``64`` in
    ``dtype=float64``) are ignored. Non-string input is converted with ``str`` first.

    Returns:
        1-D ``float64`` NumPy array; empty when no number is found.
    """
    # The lookbehind rejects digits that continue an identifier or another number.
    # Stripping every non-digit character instead would turn -1.5 into 1.5 and
    # 1e-3 into 13, and would merge space-separated values.
    number_pattern = r'(?<![\w.])[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?'
    tokens = re.findall(number_pattern, str(str_array))
    return np.array([float(token) for token in tokens], dtype=float)




def convert_string_q_vector_to_value(col):
    """Collect every number found in the items of ``col`` into one flat list.

    Numbers from all items are concatenated in order, e.g. ``["[0.5]", "[-1.25]"]``
    gives ``[0.5, -1.25]``. Negative values, integers and scientific notation are
    parsed (``"[-1.2e-05]"`` gives ``-1.2e-05``). Digits glued to a word (such as the
    ``64`` in ``dtype=torch.float64``) are ignored. Non-string items are converted
    with ``str`` first.

    Returns:
        list of floats; empty when ``col`` is empty or contains no numbers.
    """
    # A bare ``\d+\.\d+`` would drop minus signs, read 1.2e-05 as 1.2 and skip
    # integers. That distorts the cosine distances computed on these vectors
    # downstream.
    number_pattern = re.compile(r'(?<![\w.])[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?')
    return [float(token) for item in col for token in number_pattern.findall(str(item))]



def convert_state(str_state):
    """Parse every number embedded in a state string into a list of floats.

    Signs, integers and scientific notation are kept, e.g. ``"(0.5, -1.25)"`` gives
    ``[0.5, -1.25]`` and ``"[1, 2e-3]"`` gives ``[1.0, 0.002]``. Digits glued to a word
    (such as the ``64`` in ``float64``) are ignored. Non-string input is converted
    with ``str`` first.

    Returns:
        list of floats; ``[]`` when no number is present.
    """
    # The lookbehind rejects digits that continue an identifier or another number.
    # A bare ``\d+\.\d+`` would turn -1.5 into 1.5 and skip integers entirely.
    number_pattern = r'(?<![\w.])[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?'
    return [float(token) for token in re.findall(number_pattern, str(str_state))]


def preprocess_data_with_RL_policy(agent2_qtable_central_labeled, agent2_qtable_independent):
    """Build (observation features, cluster labels) from the independent RL policy.

    For every distinct observation in column 1 of the central table, the action given
    by ``agent2_qtable_independent[observation]`` selects the central rows whose
    column 2 equals that action. The numbers parsed from those rows' column 5 form the
    observation's Q vector. Q vectors are zero-padded to a common length and clustered
    with cosine-metric k-medoids into ``num_clusters`` groups.

    Despite the ``agent2`` parameter names, ``main`` uses this for every agent.

    Args:
        agent2_qtable_central_labeled: 2-D array. Column 1 holds the observation
            string, column 2 the action and column 5 the Q-vector string.
        agent2_qtable_independent: mapping from observation string to action
            (``main`` passes the dict built by ``convert_label_to_dict``).

    Returns:
        ``(features, labels)``: an array with one parsed observation per distinct
        observation (in order of first appearance), and the k-medoids label of each.

    Raises:
        ValueError: if the central table yields no observations.
    """
    central = agent2_qtable_central_labeled
    obs_column = central[:, 1]
    action_column = central[:, 2]

    # dict.fromkeys de-duplicates while keeping first-appearance order. Iterating a
    # set of strings depends on the hash seed, which made the row order (and with it
    # the k-medoids initialisation and the downstream splits) change between runs.
    unique_obs = list(dict.fromkeys(obs_column))

    features = []
    q_vectors = []
    for obs in unique_obs:
        policy_action = agent2_qtable_independent[obs]
        matching_rows = central[(obs_column == obs) & (action_column == policy_action)]
        features.append(convert_string_obs_to_array(obs))
        q_vectors.append(convert_string_q_vector_to_value(matching_rows[:, 5]))

    # Ragged Q vectors become NaN-padded columns; treat the missing entries as 0.
    q_frame = pd.DataFrame(q_vectors).fillna(0.0)
    if q_frame.empty:
        raise ValueError("The input data frame is empty.")

    clustering = KMedoids(n_clusters=num_clusters, metric='cosine', random_state=0)
    clustering.fit(q_frame)
    return np.array(features), clustering.labels_


def predict_update_data_with_DT_policy(agent2_qtable_central_labeled, DT_agent_2_level_0):
    """Relabel observations using the actions predicted by a fitted model.

    For every distinct observation in column 1 of the central table, the observation
    is parsed with ``convert_state`` and passed to ``DT_agent_2_level_0.predict``. The
    central rows whose column 2 equals that prediction supply the numbers in their
    column 5 as the observation's Q vector. Q vectors are zero-padded to a common
    length and clustered with cosine-metric k-medoids into ``num_clusters`` groups.

    Despite the ``agent2``/``agent_2`` parameter names, ``main`` uses this for every
    agent.

    Returns:
        ``(features, labels)``: an array of parsed observations (in order of first
        appearance) and the k-medoids label of each.

    Raises:
        ValueError: if the central table yields no observations.
    """
    central = agent2_qtable_central_labeled
    obs_column = central[:, 1]
    action_column = central[:, 2]

    # Deterministic, first-appearance order. Set iteration order for strings varies
    # with the hash seed between runs.
    unique_obs = list(dict.fromkeys(obs_column))

    features = []
    q_vectors = []
    for obs in unique_obs:
        state = convert_state(obs)  # parsed once, used as model input and as features
        predicted_action = DT_agent_2_level_0.predict(np.array(state).reshape(1, -1))[0]
        matching_rows = central[(obs_column == obs) & (action_column == predicted_action)]
        features.append(state)
        q_vectors.append(convert_string_q_vector_to_value(matching_rows[:, 5]))

    # Ragged Q vectors become NaN-padded columns; treat the missing entries as 0.
    q_frame = pd.DataFrame(q_vectors).fillna(0.0)
    if q_frame.empty:
        raise ValueError("The input data frame is empty.")

    clustering = KMedoids(n_clusters=num_clusters, metric='cosine', random_state=0)
    clustering.fit(q_frame)
    return np.array(features), clustering.labels_


def train_decision_tree(X, y, depth):
    """Fit and return a classifier on ``(X, y)`` whose trees have ``max_depth=depth``.

    NOTE(review): despite the name, this builds a ``GradientBoostingClassifier`` (a
    boosted ensemble of depth-limited trees), not a single ``DecisionTreeClassifier``.
    Confirm this is intended.
    """
    classifier = GradientBoostingClassifier(max_depth=depth)
    return classifier.fit(X, y)

# NOTE(review): convert_state is also defined earlier in this module. Because this
# definition comes later, it is the one bound at import time. Consider keeping only one.
def convert_state(str_state):
    """Parse every number embedded in a state string into a list of floats.

    Signs, integers and scientific notation are kept, e.g. ``"(0.5, -1.25)"`` gives
    ``[0.5, -1.25]`` and ``"[1, 2e-3]"`` gives ``[1.0, 0.002]``. Digits glued to a word
    (such as the ``64`` in ``float64``) are ignored. Non-string input is converted
    with ``str`` first.

    Returns:
        list of floats; ``[]`` when no number is present.
    """
    # The lookbehind rejects digits that continue an identifier or another number.
    # A bare ``\d+\.\d+`` would turn -1.5 into 1.5 and skip integers entirely.
    number_pattern = r'(?<![\w.])[-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?'
    return [float(token) for token in re.findall(number_pattern, str(str_state))]
# Expert action tables, one per agent, read at import time. main() converts them
# with process_agent_table() into the (state, action) data used to fit and score
# the tree baselines.
a1_file = './outputs/data/ExpertRL_1_Table_3_agent_maze_6_6_run1_10000_6.csv'
agent1_atable = load_data(a1_file)

a2_file = './outputs/data/ExpertRL_2_Table_3_agent_maze_6_6_run1_10000_6.csv'
agent2_atable = load_data(a2_file)

a3_file = './outputs/data/ExpertRL_3_Table_3_agent_maze_6_6_run1_10000_6.csv'
agent3_atable = load_data(a3_file)

def max_action(agent1_atable):
    """Return, for each row, the index of its largest entry after column 0.

    Column 0 (the state string in ``process_agent_table``) is skipped, so index ``k``
    refers to column ``k + 1``. Each row must be a NumPy array, since
    ``.argmax()`` is called on its slice.

    NOTE(review): these indices are 0-based, whereas ``convert_label_to_dict`` adds 1
    to its argmax. Confirm the two label conventions are meant to differ.

    Returns:
        1-D NumPy array with one index per row (empty for an empty table).
    """
    return np.array([row[1:].argmax() for row in agent1_atable])


def process_agent_table(agent_table):
    """Turn an expert action table into ``(X_action, y_action)`` arrays.

    ``X_action`` stacks the numbers parsed by ``convert_state`` from column 0 of each
    row. ``y_action`` is ``max_action(agent_table)``, the per-row argmax over the
    remaining columns.
    """
    parsed_states = [convert_state(entry[0]) for entry in agent_table]
    return np.array(parsed_states), max_action(agent_table)





def _evaluate_and_report(iteration, models, action_data, accuracy_history):
    """Score each agent's model on its expert action data, then record and print it.

    ``models``, ``action_data`` (``(X, y)`` pairs) and ``accuracy_history`` (one list
    per agent) are aligned by position, with position 0 being agent 1.
    """
    results = [evaluate_model(model, X, y) for model, (X, y) in zip(models, action_data)]
    for history, (accuracy, _report) in zip(accuracy_history, results):
        history.append(accuracy)

    summary = ', '.join(
        f'DT_agent_{agent_id}: {accuracy}'
        for agent_id, (accuracy, _report) in enumerate(results, start=1)
    )
    print(f'Iteration {iteration} - Accuracy {summary}')
    for agent_id, (_accuracy, report) in enumerate(results, start=1):
        print(f'Classification Report DT_agent_{agent_id}:\n{report}')


def main():
    """Build the RGMDT models for the three agents and plot their accuracy per iteration.

    Level 0 labels each agent's observations by clustering Q vectors (see
    ``preprocess_data_with_RL_policy``) and fits models of depth
    ``max_depth_each_level``. Each of the ``max_depth_RGMDT`` following iterations fits
    models of depth ``iteration + 1``. Every model is saved under
    ``outputs/Step2_RGMDTModels/MaxDepth_<n>/`` and scored on the expert action tables.
    The accuracy curves are saved to
    ``outputs/Step2_RGMDT_Prediction_Acc_Figs/MaxDepth_<n>/DT_agent_1_2_3_accuracy.pdf``.
    """
    agent_ids = (1, 2, 3)
    expert_tables = (agent1_atable, agent2_atable, agent3_atable)

    # Independent per-agent RL Q tables, reduced to observation -> argmax action.
    independent_policies = [
        convert_label_to_dict(load_data(f"./outputs/data/Q{agent_id}Table_3_agent_maze_6_6_run1_10000_6.csv"))
        for agent_id in agent_ids
    ]
    central_qtables = [
        load_data(f"../Get_Q_Vectors/baseline1_central_agent_{agent_id}_q_1000_6_test.csv")
        for agent_id in agent_ids
    ]
    accuracy_history = [[] for _ in agent_ids]

    # Level 0: observations labelled by k-medoids clusters of the Q vectors selected
    # by each agent's independent RL policy.
    # NOTE(review): these models are fit on all of the clustered data (no held-out
    # split). They are then scored against max_action() indices from the expert
    # tables, which compares cluster ids with action indices. Confirm these label
    # spaces are meant to coincide.
    level0_data = [
        preprocess_data_with_RL_policy(central, policy)
        for central, policy in zip(central_qtables, independent_policies)
    ]
    level0_models = [train_decision_tree(X, y, max_depth_each_level) for X, y in level0_data]

    action_data = [process_agent_table(table) for table in expert_tables]
    _evaluate_and_report(0, level0_models, action_data, accuracy_history)

    model_dir = os.path.join('outputs', 'Step2_RGMDTModels', f'MaxDepth_{args.max_depth_RGMDT}')
    os.makedirs(model_dir, exist_ok=True)
    for agent_id, model in zip(agent_ids, level0_models):
        dump(model, os.path.join(model_dir, f'DT_agent_{agent_id}_level_0.joblib'))

    for iteration in range(max_depth_RGMDT):
        depth = iteration + 1
        # Relabel observations using the level-0 models' predicted actions.
        # NOTE(review): this result is currently unused. The models below are fit and
        # scored on the same expert action data, so the reported accuracy is training
        # accuracy, and level0_models is never updated. Confirm whether the deeper
        # models should be fit on this relabelled data instead.
        _relabelled = [
            predict_update_data_with_DT_policy(central, model)
            for central, model in zip(central_qtables, level0_models)
        ]

        models = [train_decision_tree(X, y, depth) for X, y in action_data]
        for agent_id, model in zip(agent_ids, models):
            dump(model, os.path.join(model_dir, f'DT_agent_{agent_id}_level_{depth}.joblib'))

        _evaluate_and_report(depth, models, action_data, accuracy_history)

    plt.figure(figsize=(12, 9))
    for agent_id, history in zip(agent_ids, accuracy_history):
        plt.plot(history, label=f'Agent {agent_id} Accuracy')
    plt.xlabel('Iteration')
    plt.ylabel('Accuracy')
    plt.title('Accuracy of Decision Trees Over Iterations')
    plt.legend()

    fig_dir = os.path.join('outputs', 'Step2_RGMDT_Prediction_Acc_Figs', f'MaxDepth_{args.max_depth_RGMDT}')
    os.makedirs(fig_dir, exist_ok=True)
    plt.savefig(os.path.join(fig_dir, 'DT_agent_1_2_3_accuracy.pdf'), format='pdf')
    plt.show()


# Run the pipeline only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
