import numpy as np
import pandas as pd
import re
import os


from sklearn.datasets import load_iris
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.cluster import KMeans
import joblib
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, ExtraTreesClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn import tree
import pickle


from sklearn.tree import DecisionTreeClassifier
import numpy as np


from arguments.args import get_args
import argparse



import random

# Run configuration comes from the project's argument helper.
args = get_args()

# Seed Python's and NumPy's global random generators from the configured
# seed so that repeated runs with the same arguments are reproducible.
random_seed = args.random_seed
random.seed(random_seed)
np.random.seed(random_seed)

# Depth limit passed as ``max_depth`` to every baseline model built below.
max_depth = args.max_depth_baseline




# Action trajectories used to train the decision-tree baselines:
# one CSV table per agent.
a1_file = './outputs/data/Action1Table_3_agent_maze_6_6_run1_500_3.csv'
a2_file = './outputs/data/Action2Table_3_agent_maze_6_6_run1_500_3.csv'
a3_file = './outputs/data/Action3Table_3_agent_maze_6_6_run1_500_3.csv'

# Each table is read with pandas and held as a NumPy array of rows.
agent1_atable = np.array(pd.read_csv(a1_file))
agent2_atable = np.array(pd.read_csv(a2_file))
agent3_atable = np.array(pd.read_csv(a3_file))







def convert_state(str_state):
    """Extract the decimal numbers embedded in a state string.

    Only unsigned tokens of the form ``digits.digits`` are recognised; any
    sign, bare integer or other text in ``str_state`` is ignored.

    Args:
        str_state: Text containing the state values.

    Returns:
        A list of floats in the order they appear in ``str_state``
        (empty if there are none).
    """
    return [float(match.group()) for match in re.finditer(r'\d+\.\d+', str_state)]

def max_action(agent1_atable):
    """Return, for every row, the index of its largest action frequency.

    The first column of each row is skipped; the remaining columns are
    treated as per-action frequencies, and the returned index is relative
    to those remaining columns (so ``0`` refers to column 1 of the row).

    Args:
        agent1_atable: Array of rows, each row itself a NumPy array.

    Returns:
        A NumPy array with one arg-max index per row.
    """
    return np.array([record[1:].argmax() for record in agent1_atable])


'''# For agent 1
state_list_action_agent_1=[]
for i in range(0, len(agent1_atable)):
    row = agent1_atable[i]
    state = convert_state(row[0])
    state_list_action_agent_1.append(state)

X_action_agent_1 = np.array(state_list_action_agent_1)
y_action_agent_1 = max_action(agent1_atable)


# For agent 2
state_list_action_agent_2=[]
for i in range(0, len(agent1_atable)):
    row = agent1_atable[i]
    state = convert_state(row[0])
    state_list_action_agent_2.append(state)

X_action_agent_2 = np.array(state_list_action_agent_2)
y_action_agent_2 = max_action(agent2_atable)'''

def process_agent_data(agent_table):
    """Turn one agent's action table into a feature matrix and label vector.

    Args:
        agent_table: Array of rows whose first entry is a state string and
            whose remaining entries are per-action frequencies.

    Returns:
        A tuple ``(X_action, y_action)``: the states parsed by
        ``convert_state`` stacked into an array, and the most frequent
        action index of each row as computed by ``max_action``.
    """
    features = np.array([convert_state(record[0]) for record in agent_table])
    labels = max_action(agent_table)
    return features, labels

# Build the (features, labels) pair of every agent, in agent order.
(
    (X_action_agent_1, y_action_agent_1),
    (X_action_agent_2, y_action_agent_2),
    (X_action_agent_3, y_action_agent_3),
) = map(process_agent_data, (agent1_atable, agent2_atable, agent3_atable))


# Hold out 20% of each agent's samples for testing; the split uses a fixed
# random_state so it is the same on every run.
(
    X_train_action_agent_1,
    X_test_action_agent_1,
    y_train_action_agent_1,
    y_test_action_agent_1,
) = train_test_split(
    X_action_agent_1, y_action_agent_1, test_size=0.2, random_state=42
)
(
    X_train_action_agent_2,
    X_test_action_agent_2,
    y_train_action_agent_2,
    y_test_action_agent_2,
) = train_test_split(
    X_action_agent_2, y_action_agent_2, test_size=0.2, random_state=42
)
(
    X_train_action_agent_3,
    X_test_action_agent_3,
    y_train_action_agent_3,
    y_test_action_agent_3,
) = train_test_split(
    X_action_agent_3, y_action_agent_3, test_size=0.2, random_state=42
)



# NOTE: none of the estimators below receives an explicit random_state, so
# the order of the fit() calls is kept fixed (model family by model family,
# agents 1 -> 3 within each family). fit() returns the estimator itself,
# which allows construction and training to be chained.

# CART decision trees
tree_traditional_agent_1 = DecisionTreeClassifier(max_depth=max_depth).fit(
    X_train_action_agent_1, y_train_action_agent_1
)
tree_traditional_agent_2 = DecisionTreeClassifier(max_depth=max_depth).fit(
    X_train_action_agent_2, y_train_action_agent_2
)
tree_traditional_agent_3 = DecisionTreeClassifier(max_depth=max_depth).fit(
    X_train_action_agent_3, y_train_action_agent_3
)

# Random forests
random_forest_agent_1 = RandomForestClassifier(max_depth=max_depth).fit(
    X_train_action_agent_1, y_train_action_agent_1
)
random_forest_agent_2 = RandomForestClassifier(max_depth=max_depth).fit(
    X_train_action_agent_2, y_train_action_agent_2
)
random_forest_agent_3 = RandomForestClassifier(max_depth=max_depth).fit(
    X_train_action_agent_3, y_train_action_agent_3
)

# Gradient-boosted decision trees
gbdt_agent_1 = GradientBoostingClassifier(max_depth=max_depth).fit(
    X_train_action_agent_1, y_train_action_agent_1
)
gbdt_agent_2 = GradientBoostingClassifier(max_depth=max_depth).fit(
    X_train_action_agent_2, y_train_action_agent_2
)
gbdt_agent_3 = GradientBoostingClassifier(max_depth=max_depth).fit(
    X_train_action_agent_3, y_train_action_agent_3
)

# Extra trees
extra_trees_agent_1 = ExtraTreesClassifier(max_depth=max_depth).fit(
    X_train_action_agent_1, y_train_action_agent_1
)
extra_trees_agent_2 = ExtraTreesClassifier(max_depth=max_depth).fit(
    X_train_action_agent_2, y_train_action_agent_2
)
extra_trees_agent_3 = ExtraTreesClassifier(max_depth=max_depth).fit(
    X_train_action_agent_3, y_train_action_agent_3
)



# Evaluate every baseline on the held-out test split of its own agent and
# print the resulting accuracy. Each agent's section follows the same
# pattern: predict with the four models, score them, report the scores.

# For agent 1
# Predictions
predictions_traditional_agent_1 = tree_traditional_agent_1.predict(X_test_action_agent_1)
predictions_rf_agent_1 = random_forest_agent_1.predict(X_test_action_agent_1)
predictions_gbdt_agent_1 = gbdt_agent_1.predict(X_test_action_agent_1)
predictions_et_agent_1 = extra_trees_agent_1.predict(X_test_action_agent_1)

# Evaluate accuracy
accuracy_traditional_agent_1 = accuracy_score(y_test_action_agent_1, predictions_traditional_agent_1)
accuracy_rf_agent_1 = accuracy_score(y_test_action_agent_1, predictions_rf_agent_1)
accuracy_gbdt_agent_1 = accuracy_score(y_test_action_agent_1, predictions_gbdt_agent_1)
accuracy_et_agent_1 = accuracy_score(y_test_action_agent_1, predictions_et_agent_1)

print("Accuracy of Traditional Decision Tree for Agent 1:", accuracy_traditional_agent_1)
print("Accuracy of Random Forest for Agent 1:", accuracy_rf_agent_1)
print("Accuracy of Gradient Boosting Decision Trees for Agent 1:", accuracy_gbdt_agent_1)
print("Accuracy of Extra Trees for Agent 1:", accuracy_et_agent_1)


# For agent 2
# Predictions
predictions_traditional_agent_2 = tree_traditional_agent_2.predict(X_test_action_agent_2)
predictions_rf_agent_2 = random_forest_agent_2.predict(X_test_action_agent_2)
predictions_gbdt_agent_2 = gbdt_agent_2.predict(X_test_action_agent_2)
predictions_et_agent_2 = extra_trees_agent_2.predict(X_test_action_agent_2)

# Evaluate accuracy
accuracy_traditional_agent_2 = accuracy_score(y_test_action_agent_2, predictions_traditional_agent_2)
accuracy_rf_agent_2 = accuracy_score(y_test_action_agent_2, predictions_rf_agent_2)
accuracy_gbdt_agent_2 = accuracy_score(y_test_action_agent_2, predictions_gbdt_agent_2)
accuracy_et_agent_2 = accuracy_score(y_test_action_agent_2, predictions_et_agent_2)

print("Accuracy of Traditional Decision Tree for Agent 2:", accuracy_traditional_agent_2)
print("Accuracy of Random Forest for Agent 2:", accuracy_rf_agent_2)
print("Accuracy of Gradient Boosting Decision Trees for Agent 2:", accuracy_gbdt_agent_2)
print("Accuracy of Extra Trees for Agent 2:", accuracy_et_agent_2)

# For agent 3
# Predictions
predictions_traditional_agent_3 = tree_traditional_agent_3.predict(X_test_action_agent_3)
predictions_rf_agent_3 = random_forest_agent_3.predict(X_test_action_agent_3)
predictions_gbdt_agent_3 = gbdt_agent_3.predict(X_test_action_agent_3)
predictions_et_agent_3 = extra_trees_agent_3.predict(X_test_action_agent_3)

# Evaluate accuracy
accuracy_traditional_agent_3 = accuracy_score(y_test_action_agent_3, predictions_traditional_agent_3)
accuracy_rf_agent_3 = accuracy_score(y_test_action_agent_3, predictions_rf_agent_3)
accuracy_gbdt_agent_3 = accuracy_score(y_test_action_agent_3, predictions_gbdt_agent_3)
accuracy_et_agent_3 = accuracy_score(y_test_action_agent_3, predictions_et_agent_3)

# These lines report agent 3's scores, so they are labelled "Agent 3"
# (they were previously printed under the "Agent 2" label).
print("Accuracy of Traditional Decision Tree for Agent 3:", accuracy_traditional_agent_3)
print("Accuracy of Random Forest for Agent 3:", accuracy_rf_agent_3)
print("Accuracy of Gradient Boosting Decision Trees for Agent 3:", accuracy_gbdt_agent_3)
print("Accuracy of Extra Trees for Agent 3:", accuracy_et_agent_3)







from joblib import dump, load

# Models are written to outputs/Step2_BaselineDTModels/<max depth>/,
# which is created on demand.
model_dir = os.path.join('outputs', 'Step2_BaselineDTModels', str(args.max_depth_baseline))
os.makedirs(model_dir, exist_ok=True)

# File stem -> fitted estimator, listed agent by agent so the files are
# written in the order agent 1, agent 2, agent 3.
baseline_models = {
    # Agent 1
    'tree_traditional_agent_1': tree_traditional_agent_1,
    'random_forest_agent_1': random_forest_agent_1,
    'gbdt_agent_1': gbdt_agent_1,
    'extra_trees_agent_1': extra_trees_agent_1,
    # Agent 2
    'tree_traditional_agent_2': tree_traditional_agent_2,
    'random_forest_agent_2': random_forest_agent_2,
    'gbdt_agent_2': gbdt_agent_2,
    'extra_trees_agent_2': extra_trees_agent_2,
    # Agent 3
    'tree_traditional_agent_3': tree_traditional_agent_3,
    'random_forest_agent_3': random_forest_agent_3,
    'gbdt_agent_3': gbdt_agent_3,
    'extra_trees_agent_3': extra_trees_agent_3,
}

# Each model is stored as <stem>_MaxDepth_<max_depth>.joblib. The f-string
# already interpolates max_depth, so no additional str.format() call is
# needed to build the filename.
for model_name, fitted_model in baseline_models.items():
    dump(fitted_model, os.path.join(model_dir, f'{model_name}_MaxDepth_{max_depth}.joblib'))









