import sys, pathlib
sys.path.append(str(pathlib.Path(__file__).parent.parent))
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from collections import defaultdict
import numpy as np
from sklearn.model_selection import train_test_split
from tools import feature_list 
import time

t = time.time()
# Load the data
# Dataset names seen so far (presumably the supported choices):
# ['LSCRW', 'DataCo','GlobalStore','OAS']
dataset = 'GlobalStore'
train_data = pd.read_csv(f'./datasets/{dataset}/processed_{dataset}_train.csv')
test_data = pd.read_csv(f'./datasets/{dataset}/processed_{dataset}_test.csv')

# Encode categorical variables
# Work on copies so the raw frames stay available untouched.
encoded_train_data = train_data.copy()
encoded_test_data = test_data.copy()
label_encoders = {}

for column in encoded_train_data.select_dtypes(include=['object']).columns:
    # Fit on the categories of BOTH splits: LabelEncoder.transform raises
    # ValueError for a label it was not fitted on, so a category occurring
    # only in the test split would otherwise abort the whole run. When the
    # test split introduces no new category the learned classes (and hence
    # the integer codes) are exactly what fitting on train alone produces.
    encoder = LabelEncoder()
    encoder.fit(
        pd.concat(
            [encoded_train_data[column], encoded_test_data[column]],
            ignore_index=True,
        )
    )
    encoded_train_data[column] = encoder.transform(encoded_train_data[column])
    encoded_test_data[column] = encoder.transform(encoded_test_data[column])
    label_encoders[column] = encoder

# Function to calculate conditional probabilities
def calculate_conditional_probs(data, target, conditions):
    """Estimate P(target | conditions) from relative frequencies in ``data``.

    Args:
        data: DataFrame containing the conditioning columns and ``target``.
        target: Name of the column whose distribution is estimated.
        conditions: Iterable of column names to condition on (a list or a
            pandas Index both work).

    Returns:
        A ``defaultdict(dict)`` mapping every observed combination of
        condition values -- always a tuple, ordered like ``conditions`` --
        to ``{target_value: relative_frequency}``. With no conditioning
        columns the single key ``()`` holds the marginal distribution of
        ``target`` (omitted when ``data`` has no usable rows).
    """
    condition_columns = list(conditions)
    conditional_probs = defaultdict(dict)

    if not condition_columns:
        # groupby() rejects an empty key list; conditioning on nothing is
        # simply the marginal distribution.
        marginal = dict(data[target].value_counts(normalize=True).items())
        if marginal:
            conditional_probs[()] = marginal
        return conditional_probs

    for group_keys, group_data in data.groupby(condition_columns):
        # Older pandas versions yield a bare scalar when grouping by a
        # one-element list; lookups are done with tuples, so normalise.
        if not isinstance(group_keys, tuple):
            group_keys = (group_keys,)
        target_counts = group_data[target].value_counts(normalize=True)
        conditional_probs[group_keys] = dict(target_counts.items())
    return conditional_probs

# Calculate conditional probabilities for each target
# The three labels are chained: the second table is additionally conditioned
# on the first label, the third on the first two.
_label_columns = feature_list.label[f'{dataset}']
relevant_features = encoded_train_data.drop(columns=_label_columns).columns
_feature_columns = relevant_features.tolist()

late_risk_probs = calculate_conditional_probs(
    encoded_train_data,
    _label_columns[0],
    relevant_features,
)
day_for_shipping_probs = calculate_conditional_probs(
    encoded_train_data,
    _label_columns[1],
    _feature_columns + [_label_columns[0]],
)
on_time_probs = calculate_conditional_probs(
    encoded_train_data,
    _label_columns[2],
    _feature_columns + _label_columns[:2],
)

# Simulation functions
def simulate_late_risk(conditions):
    """Sample a value for the first label given a row's feature values.

    Args:
        conditions: Mapping from each name in ``relevant_features`` to the
            row's value for that feature.

    Returns:
        A value drawn from the distribution stored in ``late_risk_probs``
        for this feature combination; if the combination has no entry, a
        uniformly chosen value among those the first label takes in the
        training data.
    """
    lookup_key = tuple(conditions[name] for name in relevant_features)
    distribution = late_risk_probs.get(lookup_key, {})
    if not distribution:
        # Unseen feature combination: fall back to any observed label value.
        observed = encoded_train_data[feature_list.label[f'{dataset}'][0]].unique()
        return np.random.choice(observed)
    outcomes = list(distribution.keys())
    weights = list(distribution.values())
    return np.random.choice(outcomes, p=weights)

def simulate_day_for_shipping(conditions, late_risk_value):
    """Sample a value for the second label given features and the first label.

    Args:
        conditions: Mapping from each name in ``relevant_features`` to the
            row's value for that feature.
        late_risk_value: Value of the first label to condition on.

    Returns:
        A value drawn from the distribution stored in
        ``day_for_shipping_probs`` for this combination; if the combination
        has no entry, a uniformly chosen value among those the second label
        takes in the training data.
    """
    feature_part = tuple(conditions[name] for name in relevant_features)
    lookup_key = feature_part + (late_risk_value,)
    distribution = day_for_shipping_probs.get(lookup_key, {})
    if not distribution:
        # Unseen combination: fall back to any observed label value.
        observed = encoded_train_data[feature_list.label[f'{dataset}'][1]].unique()
        return np.random.choice(observed)
    outcomes = list(distribution.keys())
    weights = list(distribution.values())
    return np.random.choice(outcomes, p=weights)

def simulate_on_time(conditions, late_risk_value, day_for_shipping_value):
    """Sample a value for the third label given features and both earlier labels.

    Args:
        conditions: Mapping from each name in ``relevant_features`` to the
            row's value for that feature.
        late_risk_value: Value of the first label to condition on.
        day_for_shipping_value: Value of the second label to condition on.

    Returns:
        A value drawn from the distribution stored in ``on_time_probs`` for
        this combination; if the combination has no entry, a uniformly chosen
        value among those the third label takes in the training data.
    """
    feature_part = tuple(conditions[name] for name in relevant_features)
    lookup_key = feature_part + (late_risk_value, day_for_shipping_value)
    distribution = on_time_probs.get(lookup_key, {})
    if not distribution:
        # Unseen combination: fall back to any observed label value.
        observed = encoded_train_data[feature_list.label[f'{dataset}'][2]].unique()
        return np.random.choice(observed)
    outcomes = list(distribution.keys())
    weights = list(distribution.values())
    return np.random.choice(outcomes, p=weights)

# Simulate for each row in the test set
_label_names = feature_list.label[f'{dataset}']
str_1 = _label_names[0]
str_2 = _label_names[1]
str_3 = _label_names[2]

simulated_results = []

for _, test_row in encoded_test_data.iterrows():
    conditions = test_row[relevant_features].to_dict()

    # Each draw feeds the next one, so the order of these calls matters.
    simulated_late_risk = simulate_late_risk(conditions)
    simulated_day_for_shipping = simulate_day_for_shipping(conditions, simulated_late_risk)
    simulated_on_time = simulate_on_time(
        conditions, simulated_late_risk, simulated_day_for_shipping
    )

    record = {}
    record[f'simulated_{str_1}'] = simulated_late_risk
    record[f'simulated_{str_2}'] = simulated_day_for_shipping
    record[f'simulated_{str_3}'] = simulated_on_time
    simulated_results.append(record)

# Convert simulated results to a DataFrame (one row per test row, one column per label)
simulated_results_df = pd.DataFrame(simulated_results)

# Evaluate the accuracy
# One slot per target, in the order the labels are listed for this dataset.
correct_preds = [0, 0, 0]
total_samples = [0, 0, 0]
label_value_counts = {0: {}, 1: {}, 2: {}}

for i, target in enumerate(feature_list.label[f'{dataset}']):
    true_values = encoded_test_data[target]
    predicted_values = simulated_results_df[f'simulated_{target}']

    matches = true_values == predicted_values
    correct_preds[i] += matches.sum()
    total_samples[i] += len(true_values)

    # Tally how often each value was simulated for this target.
    tally = label_value_counts[i]
    for value in predicted_values.values:
        tally[value] = tally.get(value, 0) + 1

# Calculate accuracies
accuracies = [hits / count for hits, count in zip(correct_preds, total_samples)]

print(f'===========mk distribution=============')
for simulated_counts in label_value_counts.values():
    print(f'{simulated_counts}')

# Print simulated results and accuracies
print(simulated_results_df.head())
print("Accuracies:", accuracies)
print("All Accuracies:", sum(accuracies) / 3)

# Wall-clock seconds elapsed since the timer was started at the top of the script.
print(time.time()-t)