import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import json
import os

from scipy.stats import norm, skew, kurtosis
from tqdm import tqdm


def read_json(file_path):
    """
    Load a UTF-8 encoded JSON file and return the decoded Python object.
    """
    with open(file_path, encoding='utf-8') as handle:
        return json.load(handle)


def write_json(file_path, data):
    """
    Serialize ``data`` to ``file_path`` as UTF-8 JSON, overwriting any existing file.

    Non-ASCII characters are written as-is (not ``\\u`` escaped) and the output is
    pretty-printed with a 4-space indent.
    """
    dump_options = {"ensure_ascii": False, "indent": 4}
    with open(file_path, mode="w", encoding='utf-8') as out:
        json.dump(data, out, **dump_options)


def get_data(args):
    """
    Load the dataset selected by ``args.data_type``, generating it first if absent.

    Only ``'e_commerce'`` is currently supported. For that type:
      * ``args.data_dir`` is rewritten in place to
        ``<data_dir>/e_commerce_data_<args.n_samples>`` and created if missing
        (``generate_e_commerce_data`` writes into this rewritten directory);
      * the config is read from ``<args.config_dir>/e_commerce_config.json``;
      * ``data.csv`` in that directory is generated when it does not exist yet,
        then loaded with ``pandas.read_csv``.

    Returns:
        tuple: ``(DataFrame, dict)`` — the loaded CSV and the parsed config.

    Raises:
        ValueError: if ``args.data_type`` is not a known experiment type.
    """
    exp_type = args.data_type
    if exp_type != 'e_commerce':
        raise ValueError(f"Unknown experiment type: {exp_type}")

    # NOTE(review): the directory is named after args.n_samples, while
    # generate_e_commerce_data draws args.n_real_samples rows — confirm intent.
    args.data_dir = os.path.join(args.data_dir, f'{exp_type}_data_{args.n_samples}')
    os.makedirs(args.data_dir, exist_ok=True)
    csv_path = os.path.join(args.data_dir, 'data.csv')

    config_path = os.path.join(args.config_dir, f"{exp_type}_config.json")
    with open(config_path, "r", encoding='utf-8') as fh:
        config = json.load(fh)

    if os.path.exists(csv_path):
        print(f"Data already exists at {csv_path}. Skipping generation.")
    else:
        generate_e_commerce_data(config, args)
        print(f"Data generated and saved to {csv_path}")

    return pd.read_csv(csv_path), config


def generate_e_commerce_data(config, args):
    """
    Generate a synthetic e-commerce dataset from ``config`` and save it as CSV.

    Each row is sampled independently:
      * ``user_age``: Gaussian mixture (any number of components), rejection-sampled
        into [18, 90];
      * ``user_gender`` / ``location_tier``: categorical draws;
      * ``product_category``: categorical, conditioned on (age group, gender);
      * ``price``: per-product Gaussian, rejection-sampled into the open interval
        ``config["price"]["constraints"]``;
      * ``payment_method``: categorical, conditioned on location tier.

    The frame is written to ``<args.data_dir>/data.csv`` (directory created if
    needed), then cross-tabs and per-product price statistics are printed.

    Args:
        config: parsed config dict with keys ``user_age``, ``user_gender``,
            ``location_tier``, ``product_category``, ``price``, ``payment_method``.
        args: namespace providing ``data_dir`` (output directory) and
            ``n_real_samples`` (number of rows to generate).
    """
    data_dir = args.data_dir
    # exist_ok replaces a separate exists() check (and its check-then-create race).
    os.makedirs(data_dir, exist_ok=True)
    # NOTE(review): get_data names the directory after args.n_samples, but the row
    # count comes from args.n_real_samples — confirm the two are meant to differ.
    n_samples = args.n_real_samples

    # Load parameters from the configuration
    age_mix_weights = config["user_age"]["mix_weights"]
    age_means = config["user_age"]["means"]
    age_stds = config["user_age"]["stds"]
    gender_categories = config["user_gender"]["categories"]
    gender_probs = config["user_gender"]["probs"]
    location_categories = config["location_tier"]["categories"]
    location_probs = config["location_tier"]["probs"]
    product_categories = config["product_category"]["categories"]
    product_probabilities = config["product_category"]["probabilities"]
    price_params = config["price"]
    payment_methods = config["payment_method"]["categories"]
    payment_probs = config["payment_method"]["probs"]

    def sample_user_age(n):
        """Draw ``n`` ages from the mixture, rejecting draws outside [18, 90]."""
        # Component count follows the config instead of being hard-coded to 3.
        n_components = len(age_mix_weights)
        samples = []
        while len(samples) < n:
            comp = np.random.choice(n_components, p=age_mix_weights)
            candidate = np.random.normal(age_means[comp], age_stds[comp])
            if 18 <= candidate <= 90:
                samples.append(candidate)
        return np.array(samples)

    def get_age_group(age):
        """Bucket an age into the "young" / "middle" / "old" keys used by the config."""
        if age < 35:
            return "young"
        if age < 55:
            return "middle"
        return "old"

    def sample_product_category(age, gender):
        probs = product_probabilities[get_age_group(age)][gender]
        return np.random.choice(product_categories, p=probs)

    def sample_price(product):
        """Draw a price from the product's Gaussian, strictly inside the constraints."""
        # NOTE: rejection sampling never terminates if the interval has (near-)zero
        # mass under N(mean, std); the config is assumed to keep them compatible.
        mean, std = price_params[product]
        low = price_params["constraints"][0]
        high = price_params["constraints"][1]
        while True:
            candidate = np.random.normal(mean, std)
            if low < candidate < high:
                return candidate

    def sample_payment_method(location):
        probs = payment_probs[location]
        return np.random.choice(payment_methods, p=probs)

    def generate_data(n=1000):
        """Build an ``n``-row DataFrame of independently sampled records."""
        data = []
        for _ in range(n):
            age = sample_user_age(1)[0]
            gender = np.random.choice(gender_categories, p=gender_probs)
            location = np.random.choice(location_categories, p=location_probs)
            product = sample_product_category(age, gender)
            price = sample_price(product)
            payment = sample_payment_method(location)

            data.append({
                "user_age": round(age, 1),
                "user_gender": gender,
                "location_tier": location,
                "product_category": product,
                "price": round(price, 2),
                "payment_method": payment,
            })
        return pd.DataFrame(data)

    # Generate dataset and save it as a CSV file
    df = generate_data(n_samples)
    csv_path = os.path.join(data_dir, "data.csv")
    df.to_csv(csv_path, index=False)

    # Summary diagnostics; reuse get_age_group so buckets match the sampler exactly.
    df["Age_Group"] = df["user_age"].apply(get_age_group)
    crosstab_product_age = pd.crosstab(df["product_category"], df["Age_Group"])
    crosstab_location_payment = pd.crosstab(df["location_tier"], df["payment_method"])
    price_stats = df.groupby("product_category")["price"].agg(["mean", "std"])

    # Print the summary
    print("Crosstab: Product_Category vs Age_Group")
    print(crosstab_product_age, "\n")
    print("Crosstab: Location_Tier vs Payment_Method")
    print(crosstab_location_payment, "\n")
    print("Price statistics by Product_Category")
    print(price_stats, "\n")


if __name__ == "__main__":
    # Script entry: parse CLI arguments, then load (or first generate) the dataset.
    from args_utils import get_args

    args = get_args()
    data, config = get_data(args)
