import argparse
import ast
import json
import os
from concurrent.futures import ProcessPoolExecutor, as_completed
from itertools import product

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from src.Cmdp import Cmdp
from src.WC_OPS import Algorithm
from src.OPTCmdp import OptcmdpAlgorithm
from src.OPTPDCmdp import OptPrimalDualAlgorithm
from src.GreedyAlgorithm import GreedyAlgorithm
from utils.generate_CMDP_json import generate_cmdp_json
from utils.AdversarialDataGenerator import AdversarialDataGenerator
from utils.functions import StochasticSampler


def make_algorithm(cmdp, T, m, adv_data=None):
    """Construct the WC-OPS learner (``src.WC_OPS.Algorithm``).

    Shares the ``(cmdp, T, m, adv_data)`` factory signature used for every
    entry of ``algorithm_classes`` in ``run_single_experiment``.
    """
    wc_ops = Algorithm(cmdp, T, m, adv_data)
    return wc_ops


def make_optcmdp_algorithm(cmdp, T, m, adv_data=None):
    """Construct the OptCMDP learner (``src.OPTCmdp.OptcmdpAlgorithm``).

    Shares the ``(cmdp, T, m, adv_data)`` factory signature used for every
    entry of ``algorithm_classes`` in ``run_single_experiment``.
    """
    opt_cmdp = OptcmdpAlgorithm(cmdp, T, m, adv_data)
    return opt_cmdp


def make_optprimaldual_algorithm(cmdp, T, m, adv_data=None):
    """Construct the OptPrimalDual learner (``src.OPTPDCmdp.OptPrimalDualAlgorithm``).

    Shares the ``(cmdp, T, m, adv_data)`` factory signature used for every
    entry of ``algorithm_classes`` in ``run_single_experiment``.
    """
    primal_dual = OptPrimalDualAlgorithm(cmdp, T, m, adv_data)
    return primal_dual


def make_greedy(cmdp, T, m, adv_data=None):
    """Construct the greedy baseline (``src.GreedyAlgorithm.GreedyAlgorithm``).

    Shares the ``(cmdp, T, m, adv_data)`` factory signature used for every
    entry of ``algorithm_classes`` in ``run_single_experiment``.
    """
    greedy = GreedyAlgorithm(cmdp, T, m, adv_data)
    return greedy


def load_stochastic_samplers(X, A, T, m, reward_type, constraint_type, seed, force=False):
    """Load (or regenerate) the cached stochastic means and wrap them in samplers.

    Means are cached in ``data/reward_means.json`` (keys ``str((x, a))``,
    values drawn uniformly from [0, 1]) and ``data/constraint_means.json``
    (keys ``str((x, a, i))``, values drawn uniformly from [-1, 1]), rounded
    to 3 decimals. Both files are regenerated together when ``force`` is
    set, when either file is missing, or when a file used stochastically in
    this run is unreadable or has the wrong number of entries.

    Args:
        X: number of states.
        A: number of actions.
        T: horizon forwarded to both samplers.
        m: number of constraints.
        reward_type: ``"stoc"`` or ``"adv"``.
        constraint_type: ``"stoc"`` or ``"adv"``.
        seed: seed forwarded to both ``StochasticSampler`` instances.
        force: regenerate the means unconditionally.

    Returns:
        ``(reward_sampler, constraint_sampler)``: a ``"bernoulli"`` sampler
        over the reward means and a ``"bernoulli_sign"`` sampler over the
        constraint means.
    """
    reward_path = "data/reward_means.json"
    constraint_path = "data/constraint_means.json"

    def check_consistency(path, expected_len):
        # A missing or unparsable (e.g. half-written) cache counts as stale.
        if not os.path.exists(path):
            return False
        try:
            with open(path) as f:
                data = json.load(f)
        except json.JSONDecodeError:
            return False
        return len(data) == expected_len

    need_regenerate = (
        force
        # Both files are read below whatever the run type, so neither may be
        # missing -- even one whose quantity is adversarial in this run.
        or not os.path.exists(reward_path)
        or not os.path.exists(constraint_path)
        or (reward_type == "stoc" and not check_consistency(reward_path, X * A))
        or (constraint_type == "stoc" and not check_consistency(constraint_path, X * A * m))
    )

    if need_regenerate:
        reward_means = {}
        constraint_means = {}
        for x, a in product(range(X), range(A)):
            reward_means[(x, a)] = float(np.round(np.random.uniform(0, 1), 3))
            for i in range(m):
                constraint_means[(x, a, i)] = float(np.round(np.random.uniform(-1, 1), 3))
        os.makedirs("data", exist_ok=True)
        with open(reward_path, "w") as f:
            json.dump({str(k): v for k, v in reward_means.items()}, f, indent=2)
        with open(constraint_path, "w") as f:
            json.dump({str(k): v for k, v in constraint_means.items()}, f, indent=2)
        # Use the in-memory means rather than re-reading files that another
        # process might be rewriting concurrently.
    else:
        # Keys are tuple literals; literal_eval parses them without executing code.
        with open(reward_path) as f:
            reward_means = {ast.literal_eval(k): v for k, v in json.load(f).items()}
        with open(constraint_path) as f:
            constraint_means = {ast.literal_eval(k): v for k, v in json.load(f).items()}

    reward_sampler = StochasticSampler(reward_means, mode="bernoulli", T=T, seed=seed)
    constraint_sampler = StochasticSampler(constraint_means, mode="bernoulli_sign", T=T, seed=seed)
    return reward_sampler, constraint_sampler


def run_single_experiment(cmdp_path, reward_type, constraint_type, T, m, X, A, run_seed):
    """Run every applicable algorithm once on a freshly loaded CMDP.

    The CMDP at ``cmdp_path`` is patched with reward/constraint callables.
    ``"stoc"`` quantities are sampled from the cached stochastic means (see
    ``load_stochastic_samplers``, seeded with ``run_seed``). ``"adv"``
    quantities are read from an ``AdversarialDataGenerator`` attached as
    ``cmdp.adv_data``.

    Greedy and WC-OPS always run. OptCMDP runs whenever rewards or
    constraints are stochastic. OptPrimalDual runs only with stochastic
    rewards.

    Args:
        cmdp_path: path of the CMDP JSON file.
        reward_type: ``"stoc"`` or ``"adv"``.
        constraint_type: ``"stoc"`` or ``"adv"``.
        T: number of rounds.
        m: number of constraints.
        X: number of states.
        A: number of actions.
        run_seed: seed for the stochastic samplers of this run.

    Returns:
        ``(results, n_algorithms)``. ``results`` maps ``"regret_<name>"`` to
        a list of ``T`` per-round regret values and ``"violation_<name>"`` to
        the algorithm's ``viol_list``.

    Raises:
        RuntimeError: if constructing or running any algorithm raises. The
            original exception is chained as ``__cause__``.
    """
    cmdp = Cmdp(cmdp_path)

    adv_data = None
    if reward_type == "stoc" or constraint_type == "stoc":
        reward_sampler, constraint_sampler = load_stochastic_samplers(
            X, A, T, m, reward_type, constraint_type, run_seed
        )

    if reward_type == "stoc":
        cmdp.reward_type = "stochastic"
        cmdp._reward_sampler = reward_sampler
        cmdp.get_reward = lambda x, a, t: reward_sampler(x, a, t=t)
        cmdp.reward_mean = lambda x, a: reward_sampler.get_mean(x, a)
    else:
        cmdp.reward_type = "adversarial"

        def adversarial_reward(x, a, t):
            # cmdp.adv_data is resolved at call time; it is attached further below.
            return cmdp.adv_data.current_reward_vectors[x][a] * cmdp.adv_data.true_reward_vectors[x][a]

        cmdp.get_reward = adversarial_reward

    if constraint_type == "stoc":
        cmdp.constraint_type = "stochastic"
        cmdp._constraint_sampler = constraint_sampler
        cmdp.get_constraint = lambda x, a, i, t: constraint_sampler(x, a, i, t=t)
        cmdp.constraint_mean = lambda x, a, i: constraint_sampler.get_mean(x, a, i)
    else:
        cmdp.constraint_type = "adversarial"

        det_actions = {tuple(pair) for pair in cmdp.det_actions}
        rho_min = 0.05

        def adversarial_constraint(x, a, i, t):
            base_val = cmdp.adv_data.current_constraint_vectors[x][i][a] * cmdp.adv_data.true_constraint_vectors[x][i][a]
            # For pairs listed in cmdp.det_actions the value is capped at -rho_min.
            # NOTE(review): presumably this keeps those actions strictly
            # feasible (assuming "<= 0" means satisfied) -- confirm.
            if (x, a) in det_actions:
                return min(-rho_min, base_val)
            return base_val

        cmdp.get_constraint = adversarial_constraint

    if reward_type == "adv" or constraint_type == "adv":
        adv_data = AdversarialDataGenerator(
            X, A, m, eta=0.01,
            adv_reward=(reward_type == "adv"),
            adv_constraints=(constraint_type == "adv")
        )
        cmdp.adv_data = adv_data

    algorithm_classes = [("Greedy", make_greedy), ("WC-OPS", make_algorithm)]
    if cmdp.reward_type == "stochastic":
        algorithm_classes += [
            ("OptCMDP", make_optcmdp_algorithm),
            ("OptPrimalDual", make_optprimaldual_algorithm),
        ]
    elif cmdp.constraint_type == "stochastic":
        # Adversarial rewards: only OptCMDP, and only with stochastic constraints.
        algorithm_classes.append(("OptCMDP", make_optcmdp_algorithm))

    results = {}
    for name, AlgoFactory in algorithm_classes:
        try:
            algo = AlgoFactory(cmdp, T, m, adv_data)
            algo.run()
            results[f"regret_{name}"] = [algo.opt_term[t] - algo.cumul_reward_list[t] for t in range(T)]
            results[f"violation_{name}"] = algo.viol_list
        except Exception as e:
            # main() retries a run on RuntimeError; chain the cause so the
            # original traceback is preserved in the report.
            raise RuntimeError(f"Algorithm {name} failed due to LP error: {e}") from e

    return results, len(results) // 2


def ensure_cmdp_exists(path, X, A, L, constraint_type):
    """(Re)generate the CMDP description file at ``path``.

    Despite the name, the file is always regenerated through
    ``generate_cmdp_json``, with states ``0..X-1``, actions ``0..A-1``,
    ``L`` and ``constraint_type`` forwarded unchanged. The parent directory
    is created if needed.
    """
    directory = os.path.dirname(path)
    # A bare filename has dirname "" and os.makedirs("") raises.
    if directory:
        os.makedirs(directory, exist_ok=True)
    generate_cmdp_json(path, list(range(X)), list(range(A)), L, constraint_type)


def params_differ(X, A, L, m, cmdp_path="data/cmdp.json", constraint_path="data/constraint_means.json", reward_path="data/reward_means.json"):
    """Return True when the cached CMDP / means files don't match the requested sizes.

    The cache is reported as differing when any file is missing or when the
    CMDP at ``cmdp_path`` does not have ``X`` states, ``A`` actions and
    layer count ``L``. It also differs when the constraint keys do not cover
    exactly ``range(X) x range(A) x range(m)`` per coordinate, or the reward
    keys do not cover exactly ``range(X) x range(A)`` per coordinate.
    Unparsable files or malformed keys are also reported as differing, so
    callers regenerate them.
    """
    if not all(os.path.exists(p) for p in (cmdp_path, constraint_path, reward_path)):
        return True

    cmdp = Cmdp(cmdp_path)
    if len(cmdp.X) != X or len(cmdp.A) != A or cmdp.L != L:
        return True

    def coordinate_sets(path, arity):
        # Per-coordinate sets of the tuple keys in `path`, or None if the file
        # or any of its keys cannot be parsed as an `arity`-tuple.
        try:
            with open(path) as f:
                raw_keys = list(json.load(f).keys())
        except json.JSONDecodeError:
            return None
        keys = []
        for raw in raw_keys:
            try:
                # literal_eval parses "(x, a, i)" without executing code.
                key = ast.literal_eval(raw)
            except (ValueError, TypeError, SyntaxError):
                return None
            if not isinstance(key, tuple) or len(key) != arity:
                return None
            keys.append(key)
        return [{k[i] for k in keys} for i in range(arity)]

    if coordinate_sets(constraint_path, 3) != [set(range(X)), set(range(A)), set(range(m))]:
        return True
    if coordinate_sets(reward_path, 2) != [set(range(X)), set(range(A))]:
        return True
    return False


def _find_feasible_stochastic_means(args, X, A, force_generate):
    """Regenerate stochastic means until OptCMDP's offline problem is feasible.

    Each attempt loads (or regenerates) the means via
    ``load_stochastic_samplers``, patches a fresh ``Cmdp`` with the samplers
    and calls ``OptcmdpAlgorithm.compute_OPT``. While ``opt_solver.res`` is
    falsy, the constraint file is deleted and the next attempt forces
    regeneration. Loops until a feasible instance is found.
    """
    attempt = 0
    while True:
        attempt += 1
        reward_sampler, constraint_sampler = load_stochastic_samplers(
            X, A, args.T, args.m,
            args.reward_type, args.constraint_type,
            seed=np.random.randint(1_000_000),
            force=force_generate,
        )

        cmdp = Cmdp(args.cmdp_json)
        cmdp.reward_type = "stochastic"
        cmdp._reward_sampler = reward_sampler
        cmdp.get_reward = lambda x, a, t: reward_sampler(x, a, t=t)
        cmdp.reward_mean = lambda x, a: reward_sampler.get_mean(x, a)
        cmdp.constraint_type = "stochastic"
        cmdp._constraint_sampler = constraint_sampler
        cmdp.get_constraint = lambda x, a, i, t: constraint_sampler(x, a, i, t=t)
        cmdp.constraint_mean = lambda x, a, i: constraint_sampler.get_mean(x, a, i)

        opt_solver = OptcmdpAlgorithm(cmdp, args.T, args.m, None)
        opt_solver.compute_OPT()
        if opt_solver.res:
            return

        print(f"Attempt {attempt}: sampled instance is infeasible, regenerating means...")
        if os.path.exists("data/constraint_means.json"):
            os.remove("data/constraint_means.json")
        force_generate = True


def _run_repeats(args, X, A):
    """Collect ``args.n_repeats`` successful runs of ``run_single_experiment``.

    Runs execute in a ``ProcessPoolExecutor``, each with a fresh random seed.
    A run that raises ``RuntimeError`` is resubmitted with a new seed; any
    other exception propagates. Returns the list of per-run result dicts.
    """
    results_all = []
    with ProcessPoolExecutor() as executor:
        active_futures = {}

        def submit_job():
            run_seed = np.random.randint(1_000_000)
            fut = executor.submit(
                run_single_experiment,
                args.cmdp_json,
                args.reward_type,
                args.constraint_type,
                args.T,
                args.m,
                X,
                A,
                run_seed,
            )
            active_futures[fut] = run_seed

        # NOTE: `_max_workers` is a private attribute; it only sizes the first batch.
        for _ in range(min(executor._max_workers, args.n_repeats)):
            submit_job()

        while len(results_all) < args.n_repeats:
            future = next(as_completed(list(active_futures)))
            seed = active_futures.pop(future)
            try:
                result, _ = future.result()
            except RuntimeError as e:
                print(f"Repeat {len(results_all)+1} failed with seed {seed}: {e} — retrying...")
                submit_job()
                continue
            results_all.append(result)
            print(f"Repeat {len(results_all)}/{args.n_repeats} completed successfully.")
            # Keep just enough jobs in flight to reach n_repeats successes.
            if len(results_all) + len(active_futures) < args.n_repeats:
                submit_job()
    return results_all


def _summarize(results_all, T):
    """Aggregate per-run curves into mean and 95% confidence-interval columns.

    Returns ``(df, algorithm_names)``. ``algorithm_names`` follows the order
    in which algorithms appear in the first run's results, so CSV columns
    and plot colours are stable across invocations (a ``set`` is not).
    """
    # "regret_WC-OPS" -> "WC-OPS"; maxsplit=1 keeps names containing "_" intact.
    algorithm_names = list(dict.fromkeys(key.split("_", 1)[1] for key in results_all[0]))
    n_runs = len(results_all)
    # Standard error of the mean uses the sample std (ddof=1); a single run
    # has no spread estimate, so fall back to ddof=0 (CI of width 0).
    ddof = 1 if n_runs > 1 else 0
    half_width_scale = 1.96 / np.sqrt(n_runs)

    df_data = {"t": np.arange(T)}
    for name in algorithm_names:
        for metric in ("regret", "violation"):
            curves = np.array([r[f"{metric}_{name}"] for r in results_all])
            df_data[f"{metric}_mean_{name}"] = np.mean(curves, axis=0)
            df_data[f"{metric}_ci_{name}"] = half_width_scale * np.std(curves, axis=0, ddof=ddof)
    return pd.DataFrame(df_data), algorithm_names


def _plot_metric(df, algorithm_names, metric, label):
    """Plot mean +/- CI of ``metric`` for every algorithm and save it under results/."""
    os.makedirs("results", exist_ok=True)
    plt.figure()
    for name in algorithm_names:
        mean = df[f"{metric}_mean_{name}"]
        ci = df[f"{metric}_ci_{name}"]
        plt.plot(df["t"], mean, label=f"{name} {label}")
        plt.fill_between(df["t"], mean - ci, mean + ci, alpha=0.3)
    plt.xlabel("t")
    plt.ylabel(label)
    plt.title(f"Average {label}")
    plt.legend()
    plt.grid()
    plt.savefig(f"results/{metric}_plot_all.png")
    plt.show()


def main():
    """Command-line entry point: generate the CMDP, run all repeats, save CSV and plots.

    Steps:
      1. (Re)generate the CMDP file.
      2. Make sure the stochastic means are on disk. In the fully stochastic
         setting, resample them until the offline problem is feasible.
      3. Run ``--n_repeats`` successful experiments in parallel.
      4. Write mean/CI curves to ``--out_csv`` and plot them to
         ``results/regret_plot_all.png`` and ``results/violation_plot_all.png``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--cmdp_json", type=str, default="data/cmdp.json")
    parser.add_argument("--reward_type", type=str, choices=["stoc", "adv"], default="stoc")
    parser.add_argument("--constraint_type", type=str, choices=["stoc", "adv"], default="stoc")
    parser.add_argument("--T", type=int, default=500000)
    parser.add_argument("--m", type=int, default=3)
    parser.add_argument("--n_repeats", type=int, default=10)
    parser.add_argument("--out_csv", type=str, default="results/results.csv")
    parser.add_argument("--X", type=int, default=3)
    parser.add_argument("--A", type=int, default=3)
    parser.add_argument("--L", type=int, default=3)
    args = parser.parse_args()
    if args.n_repeats < 1:
        # Aggregation needs at least one completed run.
        parser.error("--n_repeats must be at least 1")

    ensure_cmdp_exists(args.cmdp_json, args.X, args.A, args.L, args.constraint_type)
    cmdp = Cmdp(args.cmdp_json)
    X, A = len(cmdp.X), len(cmdp.A)

    out_dir = os.path.dirname(args.out_csv)
    if out_dir:  # os.makedirs("") raises for a bare filename
        os.makedirs(out_dir, exist_ok=True)

    # Check the CMDP file actually in use, not the default path.
    force_generate = params_differ(args.X, args.A, args.L, args.m, cmdp_path=args.cmdp_json)

    if args.reward_type == "stoc" and args.constraint_type == "stoc":
        _find_feasible_stochastic_means(args, X, A, force_generate)
    elif args.reward_type == "stoc" or args.constraint_type == "stoc":
        # Write the means files once in the parent process, so worker
        # processes only read them instead of racing to regenerate them.
        load_stochastic_samplers(
            X, A, args.T, args.m,
            args.reward_type, args.constraint_type,
            seed=np.random.randint(1_000_000),
            force=force_generate,
        )

    results_all = _run_repeats(args, X, A)

    df, algorithm_names = _summarize(results_all, args.T)
    df.to_csv(args.out_csv, index=False)
    print(f"\nResults saved to {args.out_csv}")

    _plot_metric(df, algorithm_names, "regret", "Regret")
    _plot_metric(df, algorithm_names, "violation", "Violation")


if __name__ == "__main__":
    # Script entry point. The guard matters here because main() uses a
    # ProcessPoolExecutor: with the spawn/forkserver start methods, worker
    # processes re-import this module and must not re-run main().
    main()


