import numpy as np
from collections import defaultdict
from scipy.optimize import linprog


class OptPrimalDualAlgorithm:
    """Optimistic primal-dual learner for a layered constrained MDP (CMDP).

    Each round the learner:

    1. builds bonus-adjusted estimates of cost, constraint values and
       transitions from visit counts;
    2. evaluates the current policy under those estimates;
    3. takes an exponential-weights step on the Lagrangian
       ``Q_c + sum_i lambda_i * Q_d[i]``;
    4. plays the new policy for one episode;
    5. updates the Lagrange multipliers ``lambda_``.

    The ``Cmdp`` object is used through the following attributes and methods
    (inferred from usage here; verify against the environment class): ``X``,
    ``A``, ``L`` (number of layers), ``layers``, ``x0``, ``transitions``
    (dict keyed ``f"{x}_{a}"`` mapping ``str(x')`` to a probability),
    ``reward_type``, ``constraint_type``, ``reward_mean``,
    ``constraint_mean``, ``get_reward``, ``get_constraint`` and
    ``get_next_state``.
    """

    def __init__(self, Cmdp, T, m, adv_data=None, delta=0.01):
        """Store the environment and allocate all learner state.

        Args:
            Cmdp: layered CMDP environment (see class docstring).
            T: number of rounds ``K``.
            m: number of constraints ``I``.
            adv_data: accepted for interface compatibility; not used here.
            delta: confidence parameter used inside the bonus terms.
        """
        self.Cmdp = Cmdp
        self.X = Cmdp.X
        self.A = Cmdp.A
        self.K = T
        self.H = Cmdp.L
        self.I = m
        self.k = 0

        # Visit counts n[(x, a)] and transition counts m[(x, a, x')].
        self.n = defaultdict(int)
        self.m = defaultdict(int)

        # Running sums of observed costs and constraint values.
        self.c_sum = defaultdict(float)
        self.d_sum = defaultdict(lambda: defaultdict(float))

        # Bonus-adjusted estimates used for policy evaluation.
        self.c_tilde = defaultdict(float)
        self.d_tilde = defaultdict(lambda: defaultdict(float))
        self.delta = delta

        # Empirical estimates: p_bar[(x, a)][x'], c_bar[(x, a)], d_bar[(x, a)][i].
        self.p_bar = defaultdict(lambda: defaultdict(float))
        self.c_bar = defaultdict(float)
        self.d_bar = defaultdict(lambda: defaultdict(float))

        # Lagrange multipliers; rho and step sizes are set by
        # compute_slater_policy_and_rho().
        self.lambda_ = np.zeros(self.I)
        self.rho = None
        self.t_lambda = None
        self.t_k = None

        # Reporting / bookkeeping.
        self.cumul_reward = 0.0
        self.cumul_reward_list = []
        self.viol_list = []
        self.opt_term = defaultdict(float)
        self.violation = np.zeros(self.I)
        self.cost = defaultdict(float)
        self.unseen_cost = defaultdict(float)
        self.constraints = defaultdict(float)

        self.played_q_list = []
        self.true_q_list = []
        self.q_star = None
        self.q_star_marg = None
        self.c_q_star = None

    # ------------------------------------------------------------------
    # Occupancy-measure LP helpers
    # ------------------------------------------------------------------

    def _occupancy_variables(self):
        """Assign an LP column to every (x, a, x') triple between consecutive layers.

        Returns:
            ``(var_map, rev_var_map)``: ``var_map`` maps a triple to its column
            index, and ``rev_var_map`` lists the triples in column order.
        """
        var_map = {}
        rev_var_map = []
        for h in range(self.H - 1):
            X_h1 = self.Cmdp.layers[h + 1]
            for x in self.Cmdp.layers[h]:
                for a in self.A:
                    for x_prime in X_h1:
                        key = (x, a, x_prime)
                        var_map[key] = len(rev_var_map)
                        rev_var_map.append(key)
        return var_map, rev_var_map

    def _occupancy_equality_constraints(self, var_map, num_vars):
        """Build the equality rows that define a valid occupancy measure.

        The rows are built in three groups:

        * flow conservation for every state in layers ``1 .. H-2``;
        * total mass 1 on each layer transition;
        * ``q(x, a, x') = P(x' | x, a) * sum_x'' q(x, a, x'')`` under the true
          kernel ``Cmdp.transitions``.

        Returns:
            ``(A_eq_rows, b_eq_vals)`` as Python lists.
        """
        A_eq_rows, b_eq_vals = [], []

        # Flow conservation: mass entering x equals mass leaving x.
        for h in range(1, self.H - 1):
            X_prev = self.Cmdp.layers[h - 1]
            X_next = self.Cmdp.layers[h + 1]
            for x in self.Cmdp.layers[h]:
                row = np.zeros(num_vars)
                for x_prev in X_prev:
                    for a in self.A:
                        key = (x_prev, a, x)
                        if key in var_map:
                            row[var_map[key]] += 1
                for x_next in X_next:
                    for a in self.A:
                        key = (x, a, x_next)
                        if key in var_map:
                            row[var_map[key]] -= 1
                A_eq_rows.append(row)
                b_eq_vals.append(0.0)

        # Each layer transition carries total probability mass 1.
        for h in range(self.H - 1):
            row = np.zeros(num_vars)
            X_h1 = self.Cmdp.layers[h + 1]
            for x in self.Cmdp.layers[h]:
                for a in self.A:
                    for x_prime in X_h1:
                        key = (x, a, x_prime)
                        if key in var_map:
                            row[var_map[key]] += 1
            A_eq_rows.append(row)
            b_eq_vals.append(1.0)

        # Consistency with the true transition kernel.
        for h in range(self.H - 1):
            X_h1 = self.Cmdp.layers[h + 1]
            if len(X_h1) <= 1:
                # Skipped when the next layer has a single state
                # (presumably P(x' | x, a) = 1 there).
                continue
            for x in self.Cmdp.layers[h]:
                for a in self.A:
                    p_true = self.Cmdp.transitions.get(f"{x}_{a}", {})
                    for x_prime in X_h1:
                        if (x, a, x_prime) not in var_map:
                            continue
                        p_xprime = p_true.get(str(x_prime), 0.0)
                        row = np.zeros(num_vars)
                        row[var_map[(x, a, x_prime)]] -= 1
                        for xp in X_h1:
                            if (x, a, xp) in var_map:
                                row[var_map[(x, a, xp)]] += p_xprime
                        A_eq_rows.append(row)
                        b_eq_vals.append(0.0)

        return A_eq_rows, b_eq_vals

    def _constraint_rows(self, var_map, num_vars):
        """Return one LP row per constraint i.

        Each column ``(x, a, x')`` of row i holds ``Cmdp.constraint_mean(x, a, i)``.
        """
        rows = []
        for i in range(self.I):
            row = np.zeros(num_vars)
            for h in range(self.H - 1):
                X_h1 = self.Cmdp.layers[h + 1]
                for x in self.Cmdp.layers[h]:
                    for a in self.A:
                        mean_constraint = self.Cmdp.constraint_mean(x, a, i)
                        for x_prime in X_h1:
                            key = (x, a, x_prime)
                            if key in var_map:
                                row[var_map[key]] += mean_constraint
            rows.append(row)
        return rows

    # ------------------------------------------------------------------
    # Planning quantities
    # ------------------------------------------------------------------

    def compute_slater_policy_and_rho(self):
        """Compute the benchmark occupancy, a strictly feasible occupancy, rho and step sizes.

        The method sets the following attributes:

        * ``q_star`` and ``q_star_marg``: from ``compute_OPT(None)``.
        * ``c_q_star``: the expected cost of ``q_star``, where the cost of
          ``(x, a)`` is ``1 - reward_mean(x, a)``.
        * A "Slater" occupancy: any true-kernel occupancy whose every
          expected constraint is ``<= -1e-4``. This is not stored.
        * ``rho``: ``(c_slater - c_star) / min_i(-g_i(q_slater))``.
        * ``t_lambda`` and ``t_k``: step sizes derived from ``rho``.

        Raises:
            ValueError: if no strictly feasible occupancy is found.
        """
        self.q_star = self.compute_OPT(None)
        self.q_star_marg = self.get_marginal_q(self.q_star)
        self.c_q_star = sum(
            self.q_star_marg[(x, a)] * (1 - self.Cmdp.reward_mean(x, a))
            for (x, a) in self.q_star_marg
        )

        epsilon = 0.0001  # required strict-feasibility margin

        var_map, rev_var_map = self._occupancy_variables()
        num_vars = len(rev_var_map)
        A_eq_rows, b_eq_vals = self._occupancy_equality_constraints(var_map, num_vars)
        A_ub_rows = self._constraint_rows(var_map, num_vars)
        b_ub_vals = [-epsilon] * len(A_ub_rows)

        # Zero objective: any feasible point will do.
        res = linprog(
            np.zeros(num_vars),
            A_eq=np.array(A_eq_rows) if A_eq_rows else None,
            b_eq=np.array(b_eq_vals) if b_eq_vals else None,
            A_ub=np.array(A_ub_rows) if A_ub_rows else None,
            b_ub=np.array(b_ub_vals) if b_ub_vals else None,
            bounds=[(0, None)] * num_vars,
            method='highs',
        )

        if not res.success:
            raise ValueError("Could not find Slater feasible point.")

        # Clip tiny negative solver noise to 0.
        q_slater = {key: max(val, 0.0) for key, val in zip(rev_var_map, res.x)}
        q_slater_marg = self.get_marginal_q(q_slater)
        c_q_slater = sum(
            q_slater_marg[(x, a)] * (1 - self.Cmdp.reward_mean(x, a))
            for (x, a) in q_slater_marg
        )

        # Smallest slack among the constraints (>= epsilon by construction).
        denom = min(
            -sum(
                q_slater_marg.get((x, a), 0.0) * self.Cmdp.constraint_mean(x, a, i)
                for (x, a) in q_slater_marg
            )
            for i in range(self.I)
        )
        self.rho = (c_q_slater - self.c_q_star) / denom
        self.t_lambda = np.sqrt(self.H ** 2 * self.I * self.K) / (self.rho ** 2)
        self.t_k = np.sqrt(
            2 * np.log(len(self.A))
            / ((self.H ** 2) * self.K * (1 + self.I * self.rho) ** 2)
        )

    def uniform_policy(self):
        """Return the policy that plays every action with probability 1/|A| in every state."""
        policy = {}
        for h in range(self.H - 1):
            for x in self.Cmdp.layers[h]:
                for a in self.A:
                    policy[(x, a)] = 1.0 / len(self.A)
        return policy

    def init(self):
        """Reset the counts and set prior estimates for every (x, a) in layers 0..H-2.

        The priors are: ``n = 0``, ``c_bar = 0``, ``d_bar = -1`` for every
        constraint, and ``p_bar`` uniform over the next layer.
        """
        for h in range(self.H - 1):
            X_next = self.Cmdp.layers[h + 1]
            for x in self.Cmdp.layers[h]:
                for a in self.A:
                    self.n[(x, a)] = 0
                    self.c_bar[(x, a)] = 0
                    for x_prime in X_next:
                        self.m[(x, a, x_prime)] = 0
                        # p_bar is nested p_bar[(x, a)][x'], matching every reader;
                        # writing p_bar[(x, a, x')] would set a key nothing reads.
                        self.p_bar[(x, a)][x_prime] = 1 / len(X_next)
                    for i in range(self.I):
                        self.d_bar[(x, a)][i] = -1

    def compute_bonus(self):
        """Refresh the empirical estimates from the counts and return the exploration bonuses.

        The method updates ``c_bar``, ``d_bar`` and ``p_bar`` in place.
        Unvisited pairs use ``count = 1`` so that no division by zero occurs.

        Returns:
            ``(b_c, b_d)``: dicts keyed by ``(x, a)`` holding the bonus for
            the cost and for the constraints, respectively.
        """
        L_p = np.log((6 * len(self.X) * len(self.A) * self.H * self.K) / self.delta)
        L_c = 2 * np.log(
            (6 * len(self.X) * len(self.A) * self.H * (self.I + 1) * self.K) / self.delta
        )
        beta_p = defaultdict(float)
        b_c = defaultdict(float)
        b_d = defaultdict(float)

        for h in range(self.H - 1):
            X_h1 = self.Cmdp.layers[h + 1]
            for x in self.Cmdp.layers[h]:
                for a in self.A:
                    count = max(self.n[(x, a)], 1)

                    beta_c = np.sqrt(L_c / count)
                    self.c_bar[(x, a)] = self.c_sum[(x, a)] / count

                    for i in range(self.I):
                        self.d_bar[(x, a)][i] = self.d_sum[(x, a)][i] / count

                    for x_next in X_h1:
                        p_hat = self.m[(x, a, x_next)] / count
                        self.p_bar[(x, a)][x_next] = p_hat
                        beta_p[(x, a, x_next)] = (
                            2 * np.sqrt(p_hat * (1 - p_hat) * L_p / count)
                            + (14 * L_p) / (3 * count)
                        )

                    beta_p_sum = sum(beta_p[(x, a, x_next)] for x_next in X_h1)
                    b_c[(x, a)] = beta_c + self.H * beta_p_sum
                    b_d[(x, a)] = 2 * beta_c + self.H * beta_p_sum

        return b_c, b_d

    def compute_estimates(self, b_c, b_d):
        """Set ``c_tilde`` and ``d_tilde`` by subtracting the bonuses from the empirical means."""
        for h in range(self.H - 1):
            for x in self.Cmdp.layers[h]:
                for a in self.A:
                    self.c_tilde[(x, a)] = self.c_bar[(x, a)] - b_c[(x, a)]
                    for i in range(self.I):
                        self.d_tilde[(x, a)][i] = self.d_bar[(x, a)][i] - b_d[(x, a)]

    def evaluate_policy(self, policy):
        """Compute Q_c and Q_d[i] for ``policy`` by backward induction.

        The backward induction uses ``p_bar`` together with
        ``c_tilde`` / ``d_tilde``. Every Q value is clipped below at 0.
        The results are stored in ``self.Q_c`` and ``self.Q_d``.
        """
        self.Q_c = defaultdict(float)
        self.Q_d = [defaultdict(float) for _ in range(self.I)]

        V_c = defaultdict(float)
        V_d = [defaultdict(float) for _ in range(self.I)]
        for h in reversed(range(self.H - 1)):
            X_h = self.Cmdp.layers[h]
            X_h1 = self.Cmdp.layers[h + 1]
            for x in X_h:
                for a in self.A:
                    q_c = self.c_tilde[(x, a)]
                    for x_next in X_h1:
                        q_c += self.p_bar[(x, a)][x_next] * V_c[x_next]
                    self.Q_c[(x, a)] = max(q_c, 0.0)

                    for i in range(self.I):
                        q_d = self.d_tilde[(x, a)][i]
                        for x_next in X_h1:
                            q_d += self.p_bar[(x, a)][x_next] * V_d[i][x_next]
                        self.Q_d[i][(x, a)] = max(q_d, 0.0)

            for x in X_h:
                for a in self.A:
                    V_c[x] += self.Q_c[(x, a)] * policy[(x, a)]
                    for i in range(self.I):
                        V_d[i][x] += self.Q_d[i][(x, a)] * policy[(x, a)]

    def update_policy(self, policy):
        """Take one exponential-weights step on the Lagrangian Q values.

        The update is
        ``pi'(a|x) proportional to pi(a|x) * exp(-t_k * (Q_c + sum_i lambda_i Q_d[i]))``.

        Returns:
            The new policy as a dict keyed by ``(x, a)``.
        """
        policy_next = {}
        for h in range(self.H - 1):
            for x in self.Cmdp.layers[h]:
                total_Q = {
                    a: self.Q_c[(x, a)]
                    + sum(self.lambda_[i] * self.Q_d[i][(x, a)] for i in range(self.I))
                    for a in self.A
                }
                # Shift by the smallest Lagrangian value among actions with
                # positive weight before exponentiating. The normalized result
                # is mathematically unchanged, but exp() can no longer underflow
                # to 0 for every action, which would make denom 0 and the
                # policy NaN.
                shift = min(
                    (total_Q[a] for a in self.A if policy[(x, a)] > 0),
                    default=0.0,
                )
                weights = {
                    a: policy[(x, a)] * np.exp(-self.t_k * (total_Q[a] - shift))
                    for a in self.A
                }
                denom = sum(weights.values())
                for a in self.A:
                    policy_next[(x, a)] = weights[a] / denom
        return policy_next

    # ------------------------------------------------------------------
    # Interaction and model updates
    # ------------------------------------------------------------------

    def play_policy(self, policy):
        """Roll out ``policy`` for one episode starting from ``Cmdp.x0``.

        Returns:
            A list of ``(x, a, x_next, cost, constraint_values)`` tuples, where
            ``cost = 1 - Cmdp.get_reward(x, a, self.k)``.
        """
        trajectory = []
        x = self.Cmdp.x0

        for h in range(self.H - 1):
            probs = np.array([policy.get((x, a)) for a in self.A])
            a = np.random.choice(self.A, p=probs)
            x_next = self.Cmdp.get_next_state(x, a)
            c = 1 - self.Cmdp.get_reward(x, a, self.k)
            ds = [self.Cmdp.get_constraint(x, a, i, self.k) for i in range(self.I)]
            trajectory.append((x, a, x_next, c, ds))
            x = x_next

        return trajectory

    def update_model(self, trajectory):
        """Add one episode's observations to the visit counts and the running sums."""
        for x, a, x_next, c, ds in trajectory:
            self.n[(x, a)] += 1
            self.m[(x, a, x_next)] += 1
            self.c_sum[(x, a)] += c
            for i, d in enumerate(ds):
                self.d_sum[(x, a)][i] += d

    def compute_true_q(self, policy):
        """Return the occupancy measure q[(x, a, x')] of ``policy`` under the true transitions."""
        q = defaultdict(float)
        mu = defaultdict(float)
        mu[self.Cmdp.x0] = 1.0
        for h in range(self.H - 1):
            X_h1 = self.Cmdp.layers[h + 1]
            mu_next = defaultdict(float)
            for x in self.Cmdp.layers[h]:
                for a in self.A:
                    pi_val = policy.get((x, a), 0.0)
                    p_true = self.Cmdp.transitions.get(f"{x}_{a}", {})
                    for x_prime in X_h1:
                        q_val = mu[x] * pi_val * p_true.get(str(x_prime), 0.0)
                        q[(x, a, x_prime)] += q_val
                        mu_next[x_prime] += q_val
            mu = mu_next
        return q

    @staticmethod
    def get_marginal_q(q):
        """Sum an occupancy q[(x, a, x')] over x' to get q[(x, a)]."""
        q_marginal = defaultdict(float)
        for (x, a, _x_prime), val in q.items():
            q_marginal[(x, a)] += val
        return q_marginal

    def compute_OPT(self, total_reward):
        """Solve the occupancy-measure LP that maximizes expected reward.

        The constraints are those of a valid occupancy under the true kernel.
        When ``Cmdp.constraint_type == "stochastic"``, the LP also requires
        ``sum q * constraint_mean <= 0`` for every constraint.

        Args:
            total_reward: only used when ``Cmdp.reward_type`` is not
                ``"stochastic"``. It is a mapping ``(x, a) -> reward summed
                over the K rounds``; the objective uses its average over the
                K rounds.

        Returns:
            Dict ``(x, a, x') -> q``. The dict is empty if the LP fails, in
            which case the error is printed.

        Raises:
            ValueError: if rewards are not stochastic and ``total_reward`` is None.
        """
        stochastic_reward = self.Cmdp.reward_type == "stochastic"
        if not stochastic_reward and total_reward is None:
            raise ValueError(
                "compute_OPT needs total_reward when rewards are not stochastic."
            )

        var_map, rev_var_map = self._occupancy_variables()
        num_vars = len(rev_var_map)
        A_eq_rows, b_eq_vals = self._occupancy_equality_constraints(var_map, num_vars)

        A_ub_rows, b_ub_vals = [], []
        if self.Cmdp.constraint_type == "stochastic":
            A_ub_rows = self._constraint_rows(var_map, num_vars)
            b_ub_vals = [0.0] * len(A_ub_rows)

        A_eq = np.array(A_eq_rows) if A_eq_rows else None
        b_eq = np.array(b_eq_vals) if b_eq_vals else None
        A_ub = np.array(A_ub_rows) if A_ub_rows else None
        b_ub = np.array(b_ub_vals) if b_ub_vals else None

        # linprog minimizes, so store negated rewards.
        c = np.zeros(num_vars)
        for h in range(self.H - 1):
            X_h1 = self.Cmdp.layers[h + 1]
            for x in self.Cmdp.layers[h]:
                for a in self.A:
                    if stochastic_reward:
                        reward_val = self.Cmdp.reward_mean(x, a)
                    else:
                        # Key is the (x, a) tuple. The previous .get(x, a) passed
                        # `a` as the default and looked up `x` alone.
                        reward_val = total_reward.get((x, a), 0.0) / self.K
                    for x_prime in X_h1:
                        key = (x, a, x_prime)
                        if key in var_map:
                            c[var_map[key]] = -reward_val

        result = linprog(c,
                         A_ub=A_ub, b_ub=b_ub,
                         A_eq=A_eq, b_eq=b_eq,
                         bounds=[(0, None)] * num_vars,
                         method='highs',
                         options={'disp': False}
                         )
        q_opt = {}
        if result.success:
            for key, val in zip(rev_var_map, result.x):
                q_opt[key] = max(val, 0.0)
        else:
            print(f"Error: {result.message}")
        return q_opt

    def compute_regret(self):
        """Fill ``opt_term`` with the cumulative per-round reward of ``q_star_marg``.

        ``opt_term[k]`` accumulates the benchmark reward over rounds
        ``0 .. k``, and ``opt_term[K]`` repeats the final total.

        For non-stochastic rewards, the loss of ``(x, a)`` in round k comes
        from ``self.cost``. Entries missing there are taken from
        ``self.unseen_cost``, which this method fills with
        ``1 - Cmdp.get_reward(x, a, k)``.
        """
        for k in range(self.K):
            for h in range(self.H - 1):
                for x in self.Cmdp.layers[h]:
                    for a in self.A:
                        # Check round k itself (not self.k); otherwise
                        # unseen_cost entries could be missing below.
                        if (x, a, k) not in self.cost:
                            self.unseen_cost[(x, a, k)] = 1 - self.Cmdp.get_reward(x, a, k)

        self.opt_term[0] = 0.0
        for k in range(self.K):
            for (x, a), q_val in self.q_star_marg.items():
                if self.Cmdp.reward_type == "stochastic":
                    self.opt_term[k] += q_val * self.Cmdp.reward_mean(x, a)
                else:
                    loss = self.cost.get((x, a, k), self.unseen_cost.get((x, a, k)))
                    self.opt_term[k] += q_val * (1 - loss)
            # Carry the running total into the next round.
            self.opt_term[k + 1] = self.opt_term[k]

    def compute_cumul_viol(self, q, traj):
        """Add this round's expected constraint value under ``q`` to ``self.violation``.

        For (x, a) pairs visited in ``traj``, the constraint value is the one
        observed there. For every other pair it is
        ``Cmdp.get_constraint(x, a, i, self.k)``.

        Note: states and actions are used directly as numpy indices, so they
        must be integers in ``range(len(X))`` / ``range(len(A))``.
        """
        q_marginal = self.get_marginal_q(q)

        q_matrix = np.zeros((len(self.X), len(self.A)))
        for (x, a), val in q_marginal.items():
            q_matrix[x, a] = val

        for i in range(self.I):
            g_i = defaultdict(lambda: 0.0)
            for x, a, _, _, constraints in traj:
                g_i[(x, a)] = constraints[i]

            for h in range(self.H - 1):
                for x in self.Cmdp.layers[h]:
                    for a in self.A:
                        if (x, a) not in g_i:
                            g_i[(x, a)] = self.Cmdp.get_constraint(x, a, i, self.k)

            g_i_matrix = np.zeros((len(self.X), len(self.A)))
            for (x, a), val in g_i.items():
                g_i_matrix[x, a] = val

            self.violation[i] += np.sum(q_matrix * g_i_matrix)

    def run(self):
        """Run all K rounds, recording cumulative reward and violation, then compute ``opt_term``."""
        self.compute_slater_policy_and_rho()
        self.init()
        policy = self.uniform_policy()
        for k in range(self.K):
            self.k = k
            print(f"Round {k}")
            b_c, b_d = self.compute_bonus()
            self.compute_estimates(b_c, b_d)
            self.evaluate_policy(policy)
            policy_next = self.update_policy(policy)
            traj = self.play_policy(policy_next)
            self.update_model(traj)

            # Projected dual step: lambda_i <- clip(lambda_i + viol_i / t_lambda, 0, rho).
            for i in range(self.I):
                viol_i = sum(
                    self.d_bar[(x, a)][i] * policy_next.get((x, a), 0.0)
                    for (x, a) in policy_next
                )
                self.lambda_[i] = np.clip(self.lambda_[i] + viol_i / self.t_lambda, 0, self.rho)

            q_true = self.compute_true_q(policy_next)
            self.true_q_list.append(q_true)
            q_marg = self.get_marginal_q(q_true)
            incremental_reward = 0.0
            for (x, a), q_val in q_marg.items():
                if (x, a, k) not in self.cost:
                    self.unseen_cost[(x, a, k)] = 1 - self.Cmdp.get_reward(x, a, k)
                if self.Cmdp.reward_type == "stochastic":
                    incremental_reward += q_val * self.Cmdp.reward_mean(x, a)
                else:
                    loss = self.cost.get((x, a, k), self.unseen_cost.get((x, a, k)))
                    incremental_reward += q_val * (1 - loss)
            self.cumul_reward += incremental_reward
            self.cumul_reward_list.append(self.cumul_reward)

            self.compute_cumul_viol(q_true, traj)
            self.viol_list.append(np.max(self.violation))

            policy = policy_next

        self.compute_regret()