from collections import defaultdict
import numpy as np
from scipy.optimize import linprog


class OptcmdpAlgorithm:
    """Optimistic occupancy-measure learner for a layered constrained MDP.

    Rewards are called ``c`` and constraint values ``d`` throughout.  Every
    episode ``run`` performs:

    1. ``build_confidence_intervals`` -- empirical means and confidence
       intervals for rewards, constraint values and transition
       probabilities of each (x, a);
    2. ``compute_estimates`` -- optimistic estimates: the upper end of the
       reward interval and the lower end of each constraint interval;
    3. ``solve_extended_LP`` -- an LP over occupancy measures q(x, a, x')
       whose transitions may lie anywhere in the confidence set;
    4. the induced policy is played for one episode and the counts updated.

    After the last episode ``compute_regret`` solves the LP with the true
    model to obtain the comparator values stored in ``opt_term``.

    NOTE(review): the class name suggests the OptCMDP algorithm of Efroni
    et al. (2020); confirm the confidence radii against that reference.

    The ``Cmdp`` object is used through these members (inferred from this
    class, not from its definition): ``X``, ``A``, ``L`` (number of layers),
    ``x0``, ``layers`` (per-layer lists of states), ``transitions``
    (``{"<x>_<a>": {"<x'>": prob}}``), ``reward_type`` and
    ``constraint_type`` (compared against ``"stochastic"`` /
    ``"adversarial"``), ``get_next_state(x, a)``, ``get_reward(x, a, k)``,
    ``get_constraint(x, a, i, k)``, ``reward_mean(x, a)`` and
    ``constraint_mean(x, a, i)``.
    """

    def __init__(self, Cmdp, T, m, adv_data=None, delta=0.01):
        """Initialise counters, estimates and bookkeeping.

        Args:
            Cmdp: the environment (see the class docstring).
            T: number of episodes, stored as ``self.K``.
            m: number of constraints, stored as ``self.I``.
            adv_data: optional object whose ``get_adversarial_data(policy)``
                is called once per episode before the policy is played.
            delta: confidence parameter inside the log terms of the
                confidence radii.
        """
        self.Cmdp = Cmdp
        self.X = Cmdp.X
        self.A = Cmdp.A
        self.K = T
        self.k = 0  # index of the current episode
        self.H = Cmdp.L  # number of layers; an episode takes H - 1 steps
        self.I = m
        # Visit counts n[(x, a)] and transition counts m[(x, a, x')].
        self.n = defaultdict(int)
        self.m = defaultdict(int)
        self.adv_data = adv_data
        self.res = False  # success flag of the most recent compute_OPT solve

        # Running sums of observed rewards and constraint values.
        self.c_sum = defaultdict(float)
        self.d_sum = defaultdict(lambda: defaultdict(float))

        # Optimistic LP inputs keyed c_tilde[(x, a)] and d_tilde[(x, a, i)].
        # d_tilde is only ever used with flat keys, so its default must be a
        # float (a nested-dict default would end up inside an LP row).
        self.c_tilde = defaultdict(float)
        self.d_tilde = defaultdict(float)
        self.delta = delta

        # Empirical means: p_bar[(x, a)][x'], c_bar[(x, a)], d_bar[(x, a)][i].
        self.p_bar = defaultdict(lambda: defaultdict(float))
        self.c_bar = defaultdict(float)
        self.d_bar = defaultdict(lambda: defaultdict(float))

        # Confidence intervals stored as (lower, upper); default (0, 1).
        self.conf_p = defaultdict(lambda: defaultdict(lambda: (0, 1)))
        self.conf_c = defaultdict(lambda: (0, 1))
        self.conf_d = defaultdict(lambda: defaultdict(lambda: (0, 1)))

        self.h = 0  # NOTE(review): not read anywhere in this class

        # Learner-side bookkeeping.
        self.cumul_reward = 0.0
        self.cumul_reward_list = []
        self.viol_list = []
        self.opt_term = defaultdict(float)  # comparator's cumulative reward
        self.violation = np.zeros(self.I)  # cumulative value per constraint
        # Observed values keyed (x, a, k) / (x, a, i, k) for visited pairs,
        # and the "unseen" counterparts queried for unvisited pairs.
        self.reward = defaultdict(float)
        self.unseen_reward = defaultdict(float)
        self.constraints = defaultdict(float)
        self.unseen_constr = defaultdict(float)

        self.played_q_list = []  # LP solution of each episode
        self.true_q_list = []  # true occupancy of each episode's policy

    # ------------------------------------------------------------------
    # Small private helpers
    # ------------------------------------------------------------------

    def _state_actions(self):
        """Yield ``(h, x, a)`` for every state of layers 0..H-2 and every action."""
        for h in range(self.H - 1):
            for x in self.Cmdp.layers[h]:
                for a in self.A:
                    yield h, x, a

    def _episode_reward(self, x, a, k):
        """Reward of (x, a) in episode k: observed if visited, else the unseen one.

        Returns None when neither value has been recorded.
        """
        key = (x, a, k)
        return self.reward.get(key, self.unseen_reward.get(key))

    def _occupancy_index(self):
        """Assign one LP column to every (x, a, x') between consecutive layers.

        Returns:
            (var_map, rev_var_map): key -> column index, and the inverse list.
        """
        var_map = {}
        rev_var_map = []
        for h, x, a in self._state_actions():
            for x_prime in self.Cmdp.layers[h + 1]:
                key = (x, a, x_prime)
                var_map[key] = len(rev_var_map)
                rev_var_map.append(key)
        return var_map, rev_var_map

    @staticmethod
    def _successor_columns(var_map, x, a, successors):
        """Return the LP columns of q(x, a, y) for every y in ``successors``."""
        return [var_map[(x, a, y)] for y in successors if (x, a, y) in var_map]

    def _occupancy_equalities(self, var_map, num_vars):
        """Build the equality rows shared by both LPs.

        * flow conservation for each state x of an interior layer h
          (1 <= h <= H-2):
          ``sum_{x_prev, a} q(x_prev, a, x) - sum_{a, x_next} q(x, a, x_next) = 0``;
        * normalisation for each layer h (0 <= h <= H-2):
          ``sum_{x, a, x'} q(x, a, x') = 1``.

        Returns:
            (rows, rhs) lists, flow rows first.
        """
        layers = self.Cmdp.layers
        rows, rhs = [], []
        for h in range(1, self.H - 1):
            for x in layers[h]:
                row = np.zeros(num_vars)
                for x_prev in layers[h - 1]:
                    for a in self.A:
                        idx = var_map.get((x_prev, a, x))
                        if idx is not None:
                            row[idx] += 1
                for x_next in layers[h + 1]:
                    for a in self.A:
                        idx = var_map.get((x, a, x_next))
                        if idx is not None:
                            row[idx] -= 1
                rows.append(row)
                rhs.append(0.0)

        for h in range(self.H - 1):
            row = np.zeros(num_vars)
            for x in layers[h]:
                for a in self.A:
                    for x_prime in layers[h + 1]:
                        idx = var_map.get((x, a, x_prime))
                        if idx is not None:
                            row[idx] += 1
            rows.append(row)
            rhs.append(1.0)
        return rows, rhs

    def _weighted_row(self, var_map, num_vars, weight):
        """Return a row holding ``weight(x, a)`` on every column q(x, a, x').

        ``weight`` is called once per (x, a) of layers 0..H-2.
        """
        row = np.zeros(num_vars)
        for h, x, a in self._state_actions():
            w = weight(x, a)
            for x_prime in self.Cmdp.layers[h + 1]:
                idx = var_map.get((x, a, x_prime))
                if idx is not None:
                    row[idx] += w
        return row

    @staticmethod
    def _solve_lp(c, A_ub_rows, b_ub_vals, A_eq_rows, b_eq_vals):
        """Minimise ``c @ q`` subject to the given rows and ``q >= 0`` (HiGHS)."""
        return linprog(
            c,
            A_ub=np.array(A_ub_rows) if A_ub_rows else None,
            b_ub=np.array(b_ub_vals) if b_ub_vals else None,
            A_eq=np.array(A_eq_rows) if A_eq_rows else None,
            b_eq=np.array(b_eq_vals) if b_eq_vals else None,
            bounds=[(0, None)] * len(c),
            method="highs",
            options={"disp": False},
        )

    # ------------------------------------------------------------------
    # Policy handling and interaction
    # ------------------------------------------------------------------

    def q_to_policy(self, q):
        """Convert an occupancy measure into a stochastic policy.

        ``pi(a | x) = sum_{x'} q(x, a, x') / sum_{a', x'} q(x, a', x')``.
        States whose total occupancy is zero get ``1 / |A|`` for each
        action that appears in ``q``.

        Args:
            q: mapping (x, a, x') -> occupancy value.

        Returns:
            dict mapping (x, a) to a probability.
        """
        state_mass = defaultdict(float)
        pair_mass = defaultdict(float)
        for (x, a, _), val in q.items():
            state_mass[x] += val
            pair_mass[(x, a)] += val
        return {
            (x, a): mass / state_mass[x] if state_mass[x] > 0 else 1.0 / len(self.A)
            for (x, a), mass in pair_mass.items()
        }

    def play_policy(self, policy):
        """Roll out ``policy`` for one episode of H - 1 steps from ``Cmdp.x0``.

        Observed values are stored in ``self.reward[(x, a, k)]`` and
        ``self.constraints[(x, a, i, k)]``.  When rewards are adversarial,
        the reward of every unvisited (x, a) of layers 0..H-2 is queried into
        ``unseen_reward``.  If constraints are adversarial as well, the
        constraint values go into ``unseen_constr``.

        If the policy puts no (finite, positive) mass on the actions of the
        current state, for example because the LP failed and returned no
        solution, the action is drawn uniformly instead of normalising by 0.

        Returns:
            list of ``(x, a, x_next, reward, [constraint values])`` tuples.
        """
        traj = []
        x = self.Cmdp.x0
        n_actions = len(self.A)

        for _ in range(self.H - 1):
            probs = np.array([policy.get((x, a), 0.0) for a in self.A], dtype=float)
            total = probs.sum()
            if total > 0 and np.isfinite(total):
                probs = probs / total
            else:
                probs = np.full(n_actions, 1.0 / n_actions)

            a = np.random.choice(self.A, p=probs)
            x_next = self.Cmdp.get_next_state(x, a)

            reward = self.Cmdp.get_reward(x, a, self.k)
            constraints = [self.Cmdp.get_constraint(x, a, i, self.k) for i in range(self.I)]
            traj.append((x, a, x_next, reward, constraints))

            self.reward[(x, a, self.k)] = reward
            for i, value in enumerate(constraints):
                self.constraints[(x, a, i, self.k)] = value

            x = x_next

        if self.Cmdp.reward_type == "adversarial":
            for _, x_, a_ in self._state_actions():
                if (x_, a_, self.k) in self.reward:
                    continue
                self.unseen_reward[(x_, a_, self.k)] = self.Cmdp.get_reward(x_, a_, self.k)
                if self.Cmdp.constraint_type == "adversarial":
                    for i in range(self.I):
                        key = (x_, a_, i, self.k)
                        self.unseen_constr[key] = self.Cmdp.get_constraint(x_, a_, i, self.k)

        return traj

    def update_model(self, trajectory):
        """Accumulate visit counts and observed values from one trajectory.

        Args:
            trajectory: ``(x, a, x_next, reward, constraint_values)`` tuples
                as returned by ``play_policy``.
        """
        for x, a, x_next, c, ds in trajectory:
            self.n[(x, a)] += 1
            self.m[(x, a, x_next)] += 1
            self.c_sum[(x, a)] += c
            for i, d in enumerate(ds):
                self.d_sum[(x, a)][i] += d

    # ------------------------------------------------------------------
    # Estimation
    # ------------------------------------------------------------------

    def build_confidence_intervals(self):
        """Refresh empirical means and confidence intervals for every (x, a).

        With ``n = max(n[(x, a)], 1)`` visits:

        * reward: ``c_bar +/- sqrt(L_c / n)``, clipped to [0, 1];
        * constraint i: ``d_bar_i +/- 2 sqrt(L_c / n)``, clipped to [-1, 1];
        * transition to x': ``p_bar +/- (2 sqrt(p_bar (1 - p_bar) L_p / n)
          + 14 L_p / (3 n))``, clipped to [0, 1].
        """
        L_p = np.log((6 * len(self.X) * len(self.A) * self.H * self.K) / self.delta)
        L_c = 2 * np.log((6 * len(self.X) * len(self.A) * self.H * (self.I + 1) * self.K) / self.delta)

        for h, x, a in self._state_actions():
            count = max(self.n[(x, a)], 1)

            c_bar = self.c_sum[(x, a)] / count
            self.c_bar[(x, a)] = c_bar
            beta_c = np.sqrt(L_c / count)
            self.conf_c[(x, a)] = (np.clip(c_bar - beta_c, 0.0, 1.0),
                                   np.clip(c_bar + beta_c, 0.0, 1.0))

            for i in range(self.I):
                d_bar = self.d_sum[(x, a)][i] / count
                self.d_bar[(x, a)][i] = d_bar
                self.conf_d[(x, a)][i] = (np.clip(d_bar - 2 * beta_c, -1.0, 1.0),
                                          np.clip(d_bar + 2 * beta_c, -1.0, 1.0))

            for x_next in self.Cmdp.layers[h + 1]:
                p_bar = self.m[(x, a, x_next)] / count
                self.p_bar[(x, a)][x_next] = p_bar
                beta_p = 2 * np.sqrt(p_bar * (1 - p_bar) * L_p / count) + (14 * L_p) / (3 * count)
                self.conf_p[(x, a)][x_next] = (max(0, p_bar - beta_p), min(1, p_bar + beta_p))

    def compute_estimates(self):
        """Set optimistic LP inputs: reward upper bounds, constraint lower bounds."""
        for _, x, a in self._state_actions():
            self.c_tilde[(x, a)] = self.conf_c[(x, a)][1]
            for i in range(self.I):
                self.d_tilde[(x, a, i)] = self.conf_d[(x, a)][i][0]

    def init(self):
        """Zero the visit counts and ``c_bar``; set a uniform prior ``p_bar``.

        ``p_bar[(x, a)][x']`` starts at ``1 / |layer h+1|`` for every
        successor x' of a layer-h state. The nested key layout is the one
        ``build_confidence_intervals`` uses.
        """
        for h, x, a in self._state_actions():
            X_next = self.Cmdp.layers[h + 1]
            self.n[(x, a)] = 0
            self.c_bar[(x, a)] = 0
            for x_prime in X_next:
                self.m[(x, a, x_prime)] = 0
                self.p_bar[(x, a)][x_prime] = 1 / len(X_next)

    # ------------------------------------------------------------------
    # Occupancy measures and LPs
    # ------------------------------------------------------------------

    def solve_extended_LP(self):
        """Solve the optimistic ("extended") LP over occupancy measures.

        Maximises ``sum q(x, a, x') * c_tilde[(x, a)]`` subject to flow
        conservation, per-layer normalisation, the transition confidence set

            lb * sum_y q(x, a, y) <= q(x, a, x') <= ub * sum_y q(x, a, y),
            (lb, ub) = conf_p[(x, a)][x'],

        (added only for layers with more than one successor) and, for each
        constraint i, ``sum q(x, a, x') * d_tilde[(x, a, i)] <= 0``.

        Returns:
            dict mapping (x, a, x') to a non-negative occupancy value; an
            empty dict if the solver fails (its message is printed).
        """
        var_map, rev_var_map = self._occupancy_index()
        num_vars = len(rev_var_map)
        A_eq_rows, b_eq_vals = self._occupancy_equalities(var_map, num_vars)
        A_ub_rows, b_ub_vals = [], []

        layers = self.Cmdp.layers
        for h in range(self.H - 1):
            successors = layers[h + 1]
            if len(successors) <= 1:
                continue
            for x in layers[h]:
                for a in self.A:
                    succ_cols = self._successor_columns(var_map, x, a, successors)
                    for x_prime in successors:
                        idx = var_map.get((x, a, x_prime))
                        if idx is None:
                            continue
                        lb, ub = self.conf_p[(x, a)][x_prime]

                        # q(x, a, x') - ub * sum_y q(x, a, y) <= 0
                        upper = np.zeros(num_vars)
                        upper[idx] = 1.0
                        upper[succ_cols] -= ub
                        if not np.allclose(upper, 0.0):
                            A_ub_rows.append(upper)
                            b_ub_vals.append(0.0)

                        # lb * sum_y q(x, a, y) - q(x, a, x') <= 0
                        lower = np.zeros(num_vars)
                        lower[idx] = -1.0
                        lower[succ_cols] += lb
                        if not np.allclose(lower, 0.0):
                            A_ub_rows.append(lower)
                            b_ub_vals.append(0.0)

        for i in range(self.I):
            row = self._weighted_row(
                var_map, num_vars, lambda x, a, i=i: self.d_tilde.get((x, a, i), 0.0))
            if np.any(row != 0):
                A_ub_rows.append(row)
                b_ub_vals.append(0.0)

        # linprog minimises, so negate the optimistic reward.
        c = -self._weighted_row(var_map, num_vars, lambda x, a: self.c_tilde.get((x, a), 0.0))

        result = self._solve_lp(c, A_ub_rows, b_ub_vals, A_eq_rows, b_eq_vals)
        if not result.success:
            print(f"Error: {result.message}")
            return {}
        # Clip the tiny negative values the solver can return.
        return {key: max(val, 0.0) for key, val in zip(rev_var_map, result.x)}

    def compute_true_q(self, policy):
        """Return the occupancy measure of ``policy`` under ``Cmdp.transitions``.

        The state distribution ``mu`` starts as a point mass on ``Cmdp.x0``
        and is propagated layer by layer:
        ``q(x, a, x') = mu(x) * pi(a | x) * P(x' | x, a)``.

        Returns:
            defaultdict(float) keyed (x, a, x') covering every transition
            between consecutive layers.
        """
        q = defaultdict(float)
        mu = defaultdict(float)
        mu[self.Cmdp.x0] = 1.0
        for h in range(self.H - 1):
            mu_next = defaultdict(float)
            for x in self.Cmdp.layers[h]:
                for a in self.A:
                    pi_val = policy.get((x, a), 0.0)
                    p_row = self.Cmdp.transitions.get(f"{x}_{a}", {})
                    for x_prime in self.Cmdp.layers[h + 1]:
                        q_val = mu[x] * pi_val * p_row.get(str(x_prime), 0.0)
                        q[(x, a, x_prime)] += q_val
                        mu_next[x_prime] += q_val
            mu = mu_next
        return q

    @staticmethod
    def get_marginal_q(q):
        """Sum ``q(x, a, x')`` over x' to get ``q(x, a)``."""
        q_marginal = defaultdict(float)
        for (x, a, _), val in q.items():
            q_marginal[(x, a)] += val
        return q_marginal

    def compute_OPT(self, total_reward=None):
        """Solve the LP of the best fixed occupancy measure under the true model.

        Uses the true transitions in ``Cmdp.transitions``
        (``q(x, a, x') = P(x' | x, a) * sum_y q(x, a, y)``) and, as reward,
        ``Cmdp.reward_mean`` for stochastic rewards or
        ``total_reward[(x, a)] / K`` otherwise.  The constraint rows
        ``sum q * constraint_mean <= 0`` are added only for stochastic
        constraints.  Sets ``self.res`` to the solver's success flag.

        Args:
            total_reward: mapping (x, a) -> reward summed over all episodes;
                required unless rewards are stochastic.

        Returns:
            dict mapping (x, a, x') to the optimal occupancy, or {} on failure.

        Raises:
            ValueError: if rewards are not stochastic and no total_reward is given.
        """
        if self.Cmdp.reward_type != "stochastic" and total_reward is None:
            raise ValueError("total_reward is required when rewards are not stochastic")

        var_map, rev_var_map = self._occupancy_index()
        num_vars = len(rev_var_map)
        A_eq_rows, b_eq_vals = self._occupancy_equalities(var_map, num_vars)
        A_ub_rows, b_ub_vals = [], []

        layers = self.Cmdp.layers
        for h in range(self.H - 1):
            successors = layers[h + 1]
            if len(successors) <= 1:
                continue
            for x in layers[h]:
                for a in self.A:
                    p_true = self.Cmdp.transitions.get(f"{x}_{a}", {})
                    succ_cols = self._successor_columns(var_map, x, a, successors)
                    for x_prime in successors:
                        idx = var_map.get((x, a, x_prime))
                        if idx is None:
                            continue
                        # P(x' | x, a) * sum_y q(x, a, y) - q(x, a, x') = 0
                        row = np.zeros(num_vars)
                        row[idx] -= 1
                        row[succ_cols] += p_true.get(str(x_prime), 0.0)
                        A_eq_rows.append(row)
                        b_eq_vals.append(0.0)

        if self.Cmdp.constraint_type == "stochastic":
            for i in range(self.I):
                A_ub_rows.append(self._weighted_row(
                    var_map, num_vars, lambda x, a, i=i: self.Cmdp.constraint_mean(x, a, i)))
                b_ub_vals.append(0.0)

        if self.Cmdp.reward_type == "stochastic":
            reward_of = self.Cmdp.reward_mean
        else:
            def reward_of(x, a):
                # Per-episode average of the realised rewards.
                return total_reward.get((x, a), 0.0) / self.K

        # linprog minimises, so negate the reward.
        c = -self._weighted_row(var_map, num_vars, reward_of)

        result = self._solve_lp(c, A_ub_rows, b_ub_vals, A_eq_rows, b_eq_vals)
        self.res = result.success
        if not result.success:
            print(f"Error: {result.message}")
            return {}
        return dict(zip(rev_var_map, result.x))

    # ------------------------------------------------------------------
    # Evaluation
    # ------------------------------------------------------------------

    def compute_regret(self):
        """Fill ``opt_term`` with the comparator's cumulative expected reward.

        Aggregates every (x, a)'s realised reward over the K episodes, solves
        ``compute_OPT`` on it, and stores in ``opt_term[k]`` the comparator's
        expected reward summed over episodes 0..k.  ``opt_term[K]`` repeats
        the final total.
        """
        total_reward = defaultdict(float)
        for k in range(self.K):
            for _, x, a in self._state_actions():
                key = (x, a, k)
                # Query rewards of pairs not visited in episode k (checking
                # episode k, not self.k).  Values cached earlier by
                # play_policy()/run() are reused so the comparator sees the
                # same reward sequence as the learner.
                if key not in self.reward and key not in self.unseen_reward:
                    self.unseen_reward[key] = self.Cmdp.get_reward(x, a, k)
                total_reward[(x, a)] += self._episode_reward(x, a, k)

        q_star_marg = self.get_marginal_q(self.compute_OPT(total_reward))

        self.opt_term[0] = 0.0
        for k in range(self.K):
            for (x, a), q_val in q_star_marg.items():
                if self.Cmdp.reward_type == "stochastic":
                    self.opt_term[k] += q_val * self.Cmdp.reward_mean(x, a)
                else:
                    self.opt_term[k] += q_val * self._episode_reward(x, a, k)
            self.opt_term[k + 1] = self.opt_term[k]

    def compute_cumul_viol(self, q, traj):
        """Add this episode's expected constraint values to ``self.violation``.

        For each constraint i: ``violation[i] += sum_{x,a} q(x, a) * g_i(x, a)``.
        Here q(x, a) is the marginal of ``q``. g_i(x, a) is the value observed
        in ``traj`` for visited pairs. For every other pair of layers 0..H-2
        it is ``unseen_constr[(x, a, i, k)]``, queried from the environment if
        not cached yet.

        Args:
            q: mapping (x, a, x') -> occupancy value.
            traj: trajectory as returned by ``play_policy``.
        """
        q_marginal = self.get_marginal_q(q)
        for i in range(self.I):
            g_i = {(x, a): constraints[i] for x, a, _, _, constraints in traj}
            for _, x, a in self._state_actions():
                if (x, a) in g_i:
                    continue
                key = (x, a, i, self.k)
                if key not in self.unseen_constr:
                    self.unseen_constr[key] = self.Cmdp.get_constraint(x, a, i, self.k)
                # Cached values (e.g. filled by play_policy) must count too.
                g_i[(x, a)] = self.unseen_constr[key]
            self.violation[i] += sum(
                q_val * g_i.get(pair, 0.0) for pair, q_val in q_marginal.items())

    def run(self):
        """Run all ``K`` episodes, then compute the comparator via ``compute_regret``.

        Per episode this appends the LP solution to ``played_q_list``, the
        true occupancy of the played policy to ``true_q_list``, the learner's
        cumulative expected reward to ``cumul_reward_list`` and the largest
        cumulative constraint value to ``viol_list`` (0.0 without constraints).
        """
        self.init()
        print("-----OPTCMDP Algorithm-----")
        for k in range(self.K):
            self.k = k
            print(f"episode {k}")

            self.build_confidence_intervals()
            self.compute_estimates()

            q = self.solve_extended_LP()
            policy = self.q_to_policy(q)

            if self.adv_data is not None:
                self.adv_data.get_adversarial_data(policy)

            traj = self.play_policy(policy)
            self.update_model(traj)

            q_true = self.compute_true_q(policy)
            self.played_q_list.append(q)
            self.true_q_list.append(q_true)

            # Expected reward of the played policy under Cmdp.transitions.
            incremental_reward = 0.0
            for (x, a), q_val in self.get_marginal_q(q_true).items():
                if (x, a, k) not in self.reward:
                    self.unseen_reward[(x, a, k)] = self.Cmdp.get_reward(x, a, k)
                if self.Cmdp.reward_type == "stochastic":
                    reward_val = self.Cmdp.reward_mean(x, a)
                else:
                    reward_val = self._episode_reward(x, a, k)
                incremental_reward += q_val * reward_val

            self.cumul_reward += incremental_reward
            self.cumul_reward_list.append(self.cumul_reward)

            self.compute_cumul_viol(q_true, traj)
            # np.max raises on an empty array, i.e. when there are no constraints.
            self.viol_list.append(np.max(self.violation) if self.I > 0 else 0.0)

        self.compute_regret()

