from collections import defaultdict
import numpy as np
from scipy.optimize import linprog


class GreedyAlgorithm:
    def __init__(self, Cmdp, T, m, adv_data=None, delta=0.01):
        """Set up an empty learner for the given layered constrained MDP.

        Args:
            Cmdp: environment object; its ``X``, ``A`` and ``L`` attributes
                are copied to ``self.X``, ``self.A`` and ``self.H``.
            T: number of episodes, stored as ``self.K``.
            m: number of constraints, stored as ``self.I``.
            adv_data: optional object stored as ``self.adv_data``.
            delta: accepted for interface compatibility; not used here.
        """
        # Environment handle and problem dimensions.
        self.Cmdp = Cmdp
        self.X, self.A, self.H = Cmdp.X, Cmdp.A, Cmdp.L
        self.K, self.I = T, m
        self.adv_data = adv_data

        # Episode / step cursors.
        self.k = 0
        self.h = 0

        # Visit counters: n[(x, a)] and m[(x, a, x_next)].
        self.n = defaultdict(int)
        self.m = defaultdict(int)

        # Running sums of observed reward and per-constraint values.
        self.c_sum = defaultdict(float)
        self.d_sum = defaultdict(lambda: defaultdict(float))

        # Current estimates of transitions, reward and constraints.
        self.p_bar = defaultdict(lambda: defaultdict(float))
        self.c_bar = defaultdict(float)
        self.d_bar = defaultdict(lambda: defaultdict(float))

        # Observed and queried feedback, keyed by episode.
        self.reward = defaultdict(float)
        self.unseen_reward = defaultdict(float)
        self.constraints = defaultdict(float)
        self.unseen_constr = defaultdict(float)

        # Performance bookkeeping.
        self.cumul_reward = 0.0
        self.cumul_reward_list = []
        self.viol_list = []
        self.opt_term = defaultdict(float)
        self.violation = np.zeros(self.I)

        # Per-episode history of occupancy measures.
        self.played_q_list = []
        self.true_q_list = []

    def init_uniform_q(self):
        """Return an occupancy measure that is uniform within each layer.

        For every pair of consecutive layers, each (x, a, x') triple gets
        the same mass ``1 / (|layer| * |A| * |next layer|)``.

        Returns:
            defaultdict(float) keyed by ``(x, a, x_prime)``.
        """
        occupancy = defaultdict(float)
        for depth in range(self.H - 1):
            sources = self.Cmdp.layers[depth]
            targets = self.Cmdp.layers[depth + 1]
            n_triples = len(sources) * len(self.A) * len(targets)
            for state in sources:
                for action in self.A:
                    for successor in targets:
                        occupancy[(state, action, successor)] = 1.0 / n_triples
        return occupancy

    def q_to_policy(self, q):
        """Turn an occupancy measure into action probabilities per state.

        Args:
            q: mapping ``(x, a, x_prime) -> mass``.

        Returns:
            dict ``(x, a) -> probability``. For a state with positive total
            mass the probability is the (x, a) mass divided by the state
            mass; for a state whose total mass is zero every listed action
            gets ``1 / len(self.A)``.
        """
        state_mass = defaultdict(float)
        for (state, _action, _successor), mass in q.items():
            state_mass[state] += mass

        policy = {}
        for (state, action, _successor), mass in q.items():
            pair = (state, action)
            total = state_mass[state]
            if total != 0:
                policy[pair] = policy.get(pair, 0.0) + mass / total
            else:
                policy[pair] = 1.0 / len(self.A)
        return policy

    def play_policy(self, policy):
        """Roll out one episode under ``policy`` and record the feedback.

        Starting from ``self.Cmdp.x0``, an action is sampled for each of the
        ``self.H - 1`` steps, the environment is queried for the next state,
        the reward and all ``self.I`` constraint values, and the observed
        values are stored in ``self.reward`` / ``self.constraints`` under the
        current episode index ``self.k``.

        If the policy puts no positive mass on the current state (for
        instance when it is empty because the LP in ``update_q`` failed),
        the action is drawn uniformly instead of dividing by zero.

        When ``self.Cmdp.reward_type == "adversarial"`` the reward of every
        (x, a) pair not visited this episode is also queried and stored in
        ``self.unseen_reward``; if constraints are adversarial too, their
        values for those pairs go to ``self.unseen_constr``.

        Args:
            policy: mapping ``(x, a) -> weight``; missing pairs count as 0.

        Returns:
            list of ``(x, a, x_next, reward, constraints)`` tuples, one per
            step, where ``constraints`` is a list of length ``self.I``.
        """
        traj = []
        x = self.Cmdp.x0

        for _ in range(self.H - 1):
            probs = np.array([policy.get((x, a), 0.0) for a in self.A])
            total = probs.sum()
            if total > 0:
                probs = probs / total
            else:
                # No usable mass for this state (zero or NaN total):
                # normalising would yield NaN and np.random.choice would
                # raise, so fall back to a uniform draw.
                probs = np.full(len(self.A), 1.0 / len(self.A))

            a = np.random.choice(self.A, p=probs)
            x_next = self.Cmdp.get_next_state(x, a)

            reward = self.Cmdp.get_reward(x, a, self.k)
            constraints = [
                self.Cmdp.get_constraint(x, a, i, self.k) for i in range(self.I)
            ]
            traj.append((x, a, x_next, reward, constraints))

            self.reward[(x, a, self.k)] = reward
            for i, value in enumerate(constraints):
                self.constraints[(x, a, i, self.k)] = value
            x = x_next

        if self.Cmdp.reward_type == "adversarial":
            # Full-information bookkeeping: query feedback for every pair
            # that was not visited during this episode.
            for h in range(self.H - 1):
                for x_ in self.Cmdp.layers[h]:
                    for a_ in self.A:
                        if (x_, a_, self.k) in self.reward:
                            continue
                        self.unseen_reward[(x_, a_, self.k)] = self.Cmdp.get_reward(x_, a_, self.k)
                        if self.Cmdp.constraint_type == "adversarial":
                            for i in range(self.I):
                                key = (x_, a_, i, self.k)
                                self.unseen_constr[key] = self.Cmdp.get_constraint(x_, a_, i, self.k)

        return traj

    def update_model(self, trajectory):
        """Fold one episode's observations into the running statistics.

        Args:
            trajectory: iterable of ``(x, a, x_next, c, ds)`` tuples, where
                ``c`` is added to ``self.c_sum[(x, a)]`` and ``ds[i]`` is
                added to ``self.d_sum[(x, a)][i]``. The visit counters
                ``self.n`` and ``self.m`` are incremented for each step.
        """
        for step in trajectory:
            state, action, successor, observed_c, observed_ds = step
            pair = (state, action)
            self.n[pair] += 1
            self.m[(state, action, successor)] += 1
            self.c_sum[pair] += observed_c
            for idx, value in enumerate(observed_ds):
                self.d_sum[pair][idx] += value

    def update_q(self, q):
        """Solve the LP over occupancy measures using the current estimates.

        The decision variables are one value per ``(x, a, x_prime)`` triple
        of consecutive layers. The LP maximises the estimated reward
        ``self.c_bar`` subject to flow conservation, unit mass per layer,
        consistency with the estimated transitions ``self.p_bar`` and the
        estimated constraints ``self.d_bar`` being at most zero.

        Args:
            q: previous occupancy measure. Not read by this method.

        Returns:
            dict ``(x, a, x_prime) -> value`` (negative solver values are
            clipped to 0) when the solver succeeds; otherwise the solver
            message is printed and an empty dict is returned.
        """
        # Map each (x, a, x_prime) triple to a column index and back.
        var_map = {}
        rev_var_map = []
        var_idx = 0

        for h in range(self.H - 1):
            X_h = self.Cmdp.layers[h]
            X_h1 = self.Cmdp.layers[h + 1]
            for x in X_h:
                for a in self.A:
                    for x_prime in X_h1:
                        key = (x, a, x_prime)
                        var_map[key] = var_idx
                        rev_var_map.append(key)
                        var_idx += 1

        num_vars = var_idx
        A_eq_rows = []
        b_eq_vals = []
        A_ub_rows = []
        b_ub_vals = []

        # Flow conservation for layers 1 .. H-2: mass entering a state
        # minus mass leaving it equals zero.
        for h in range(1, self.H - 1):
            X_h = self.Cmdp.layers[h]
            X_prev = self.Cmdp.layers[h - 1]
            X_next = self.Cmdp.layers[h + 1]

            for x in X_h:
                row = np.zeros(num_vars)
                for x_prev in X_prev:
                    for a in self.A:
                        key = (x_prev, a, x)
                        if key in var_map:
                            idx = var_map[key]
                            row[idx] += 1
                for x_next in X_next:
                    for a in self.A:
                        key = (x, a, x_next)
                        if key in var_map:
                            idx = var_map[key]
                            row[idx] -= 1
                A_eq_rows.append(row)
                b_eq_vals.append(0.0)

        # Normalisation: the variables of each layer sum to one.
        for h in range(self.H - 1):
            row = np.zeros(num_vars)
            X_h = self.Cmdp.layers[h]
            X_h1 = self.Cmdp.layers[h + 1]
            for x in X_h:
                for a in self.A:
                    for x_prime in X_h1:
                        key = (x, a, x_prime)
                        if key in var_map:
                            idx = var_map[key]
                            row[idx] += 1
            A_eq_rows.append(row)
            b_eq_vals.append(1.0)

        # Transition consistency with the estimate p_bar:
        #   p_bar(x'|x,a) * sum_xp q(x,a,xp) - q(x,a,x') = 0.
        # Skipped when the next layer has a single state, and for triples
        # whose estimated probability is exactly zero.
        for h in range(self.H - 1):
            X_h = self.Cmdp.layers[h]
            X_h1 = self.Cmdp.layers[h + 1]
            if len(X_h1) > 1:
                for x in X_h:
                    for a in self.A:
                        for x_prime in X_h1:
                            if self.p_bar[(x, a)][x_prime] == 0:
                                continue
                            if (x, a, x_prime) in var_map:
                                row = np.zeros(num_vars)
                                idx = var_map[(x, a, x_prime)]
                                row[idx] -= 1
                                for xp in X_h1:
                                    if (x, a, xp) in var_map:
                                        idx_sum = var_map[(x, a, xp)]
                                        row[idx_sum] += self.p_bar[(x, a)][x_prime]
                                A_eq_rows.append(row)
                                b_eq_vals.append(0.0)

        # One inequality per constraint i: sum of d_bar * q <= 0.
        # Rows that are entirely zero are left out.
        for i in range(self.I):
            row = np.zeros(num_vars)
            for h in range(self.H - 1):
                X_h = self.Cmdp.layers[h]
                for x in X_h:
                    for a in self.A:
                        d_lb = self.d_bar[(x, a)][i]
                        X_k1 = self.Cmdp.layers[h + 1]
                        for x_prime in X_k1:
                            key = (x, a, x_prime)
                            if key in var_map:
                                idx = var_map[key]
                                row[idx] += d_lb
            if not np.all(row == 0):
                A_ub_rows.append(row)
                b_ub_vals.append(0.0)

        # All variables are non-negative and unbounded above.
        bounds = [(0, None)] * num_vars

        # linprog receives None for an empty constraint family.
        A_eq = np.array(A_eq_rows) if A_eq_rows else None
        b_eq = np.array(b_eq_vals) if b_eq_vals else None
        A_ub = np.array(A_ub_rows) if A_ub_rows else None
        b_ub = np.array(b_ub_vals) if b_ub_vals else None

        # linprog minimises, so the estimated reward enters with a minus sign.
        c = np.zeros(num_vars)
        for h in range(self.H - 1):
            X_h = self.Cmdp.layers[h]
            X_h1 = self.Cmdp.layers[h + 1]
            for x in X_h:
                for a in self.A:
                    reward_val = self.c_bar.get((x, a), 0.0)
                    for x_prime in X_h1:
                        key = (x, a, x_prime)
                        if key in var_map:
                            idx = var_map[key]
                            c[idx] = -reward_val

        result = linprog(c,
                         A_ub=A_ub, b_ub=b_ub,
                         A_eq=A_eq, b_eq=b_eq,
                         bounds=bounds,
                         method='highs',
                         options={'disp': False}
                         )

        q_sol = {}
        if result.success:
            optimal_q_vector = result.x
            for idx, val in enumerate(optimal_q_vector):
                key = rev_var_map[idx]
                # Clip tiny negative values the solver may return.
                q_sol[key] = max(val, 0.0)
        else:
            # Failure is only reported; the caller receives an empty dict.
            print(f"Error: {result.message}")
        return q_sol

    def compute_estimates(self):
        """Refresh the empirical estimates from the running statistics.

        For every (x, a) pair in layers ``0 .. H-2`` the sums in
        ``self.c_sum``, ``self.d_sum`` and the transition counts in
        ``self.m`` are divided by ``max(self.n[(x, a)], 1)`` and written to
        ``self.c_bar``, ``self.d_bar`` and ``self.p_bar``.
        """
        for depth in range(self.H - 1):
            current_layer = self.Cmdp.layers[depth]
            next_layer = self.Cmdp.layers[depth + 1]
            for state in current_layer:
                for action in self.A:
                    pair = (state, action)
                    # Never divide by zero for pairs that were not visited.
                    visits = max(self.n[pair], 1)

                    self.c_bar[pair] = self.c_sum[pair] / visits

                    for idx in range(self.I):
                        self.d_bar[pair][idx] = self.d_sum[pair][idx] / visits

                    for successor in next_layer:
                        self.p_bar[pair][successor] = (
                            self.m[(state, action, successor)] / visits
                        )

    def init(self):
        """Reset counters and estimates for every (x, a) pair.

        Counts and the reward estimate are set to 0, the transition
        estimate to the uniform distribution over the next layer, and both
        the constraint estimate and the constraint sum to -1 for each of
        the ``self.I`` constraints.
        """
        for depth in range(self.H - 1):
            current_layer = self.Cmdp.layers[depth]
            next_layer = self.Cmdp.layers[depth + 1]
            for state in current_layer:
                for action in self.A:
                    pair = (state, action)
                    self.n[pair] = 0
                    self.c_bar[pair] = 0
                    for successor in next_layer:
                        self.m[(state, action, successor)] = 0
                        self.p_bar[pair][successor] = 1 / len(next_layer)

                    for idx in range(self.I):
                        self.d_bar[pair][idx] = -1
                        self.d_sum[pair][idx] = -1

    def compute_true_q(self, policy):
        """Compute the occupancy measure ``policy`` induces on the true MDP.

        State reach probabilities are propagated layer by layer from
        ``self.Cmdp.x0`` using ``self.Cmdp.transitions``, which is looked up
        with the string key ``f"{x}_{a}"`` and then ``str(x_prime)``;
        missing entries count as probability 0.

        Args:
            policy: mapping ``(x, a) -> probability``; missing pairs are 0.

        Returns:
            defaultdict(float) keyed by ``(x, a, x_prime)``.
        """
        occupancy = defaultdict(float)
        reach = defaultdict(float)
        reach[self.Cmdp.x0] = 1.0

        for depth in range(self.H - 1):
            current_layer = self.Cmdp.layers[depth]
            next_layer = self.Cmdp.layers[depth + 1]
            reach_next = defaultdict(float)
            for state in current_layer:
                for action in self.A:
                    action_prob = policy.get((state, action), 0.0)
                    successors = self.Cmdp.transitions.get(f"{state}_{action}", {})
                    for successor in next_layer:
                        trans_prob = successors.get(str(successor), 0.0)
                        mass = reach[state] * action_prob * trans_prob
                        occupancy[(state, action, successor)] += mass
                        reach_next[successor] += mass
            reach = reach_next
        return occupancy

    @staticmethod
    def get_marginal_q(q):
        q_marginal = defaultdict(float)
        for (x, a, x_prime), val in q.items():
            q_marginal[(x, a)] += val
        return q_marginal

    def compute_OPT(self, total_reward):
        """Solve the benchmark LP over occupancy measures on the true model.

        Same variable layout and structural constraints as the LP in
        ``update_q``, but the transition rows use ``self.Cmdp.transitions``
        and the reward / constraint coefficients come from the environment.

        Args:
            total_reward: mapping ``(x, a) -> summed reward``; only read
                when ``self.Cmdp.reward_type`` is not ``"stochastic"``, in
                which case it is divided by ``self.K``.

        Returns:
            dict ``(x, a, x_prime) -> value`` taken directly from the solver
            when it succeeds; otherwise the solver message is printed and an
            empty dict is returned.
        """
        # Map each (x, a, x_prime) triple to a column index and back.
        var_map = {}
        rev_var_map = []
        var_idx = 0

        for h in range(self.H - 1):
            X_h = self.Cmdp.layers[h]
            X_h1 = self.Cmdp.layers[h + 1]
            for x in X_h:
                for a in self.A:
                    for x_prime in X_h1:
                        key = (x, a, x_prime)
                        var_map[key] = var_idx
                        rev_var_map.append(key)
                        var_idx += 1

        num_vars = var_idx
        A_eq_rows = []
        b_eq_vals = []
        A_ub_rows = []
        b_ub_vals = []

        # Flow conservation for layers 1 .. H-2: inflow minus outflow is zero.
        for h in range(1, self.H - 1):
            X_h = self.Cmdp.layers[h]
            X_prev = self.Cmdp.layers[h - 1]
            X_next = self.Cmdp.layers[h + 1]

            for x in X_h:
                row = np.zeros(num_vars)
                for x_prev in X_prev:
                    for a in self.A:
                        key = (x_prev, a, x)
                        if key in var_map:
                            idx = var_map[key]
                            row[idx] += 1
                for x_next in X_next:
                    for a in self.A:
                        key = (x, a, x_next)
                        if key in var_map:
                            idx = var_map[key]
                            row[idx] -= 1
                A_eq_rows.append(row)
                b_eq_vals.append(0.0)

        # Normalisation: the variables of each layer sum to one.
        for h in range(self.H - 1):
            row = np.zeros(num_vars)
            X_h = self.Cmdp.layers[h]
            X_h1 = self.Cmdp.layers[h + 1]
            for x in X_h:
                for a in self.A:
                    for x_prime in X_h1:
                        key = (x, a, x_prime)
                        if key in var_map:
                            idx = var_map[key]
                            row[idx] += 1
            A_eq_rows.append(row)
            b_eq_vals.append(1.0)

        # Transition consistency with the environment's transition table:
        #   P(x'|x,a) * sum_xp q(x,a,xp) - q(x,a,x') = 0.
        # The table is keyed by the strings f"{x}_{a}" and str(x_prime);
        # missing entries count as probability 0. Skipped when the next
        # layer has a single state.
        for h in range(self.H - 1):
            X_h = self.Cmdp.layers[h]
            X_h1 = self.Cmdp.layers[h + 1]
            if len(X_h1) > 1:
                for x in X_h:
                    for a in self.A:
                        p_true = self.Cmdp.transitions.get(f"{x}_{a}", {})
                        for x_prime in X_h1:
                            if (x, a, x_prime) in var_map:
                                row = np.zeros(num_vars)
                                idx = var_map[(x, a, x_prime)]
                                row[idx] -= 1
                                for xp in X_h1:
                                    if (x, a, xp) in var_map:
                                        idx_sum = var_map[(x, a, xp)]
                                        row[idx_sum] += p_true.get(str(x_prime), 0.0)
                                A_eq_rows.append(row)
                                b_eq_vals.append(0.0)

        # Constraint rows are only added for stochastic constraints, using
        # the environment's mean constraint values: sum of mean * q <= 0.
        if self.Cmdp.constraint_type == "stochastic":
            for i in range(self.I):
                row = np.zeros(num_vars)
                for h in range(self.H - 1):
                    X_h = self.Cmdp.layers[h]
                    for x in X_h:
                        for a in self.A:
                            mean_constraint = self.Cmdp.constraint_mean(x, a, i)
                            X_k1 = self.Cmdp.layers[h + 1]
                            for x_prime in X_k1:
                                key = (x, a, x_prime)
                                if key in var_map:
                                    idx = var_map[key]
                                    row[idx] += mean_constraint
                A_ub_rows.append(row)
                b_ub_vals.append(0.0)

        # All variables are non-negative and unbounded above.
        bounds = [(0, None)] * num_vars

        # linprog receives None for an empty constraint family.
        A_eq = np.array(A_eq_rows) if A_eq_rows else None
        b_eq = np.array(b_eq_vals) if b_eq_vals else None
        A_ub = np.array(A_ub_rows) if A_ub_rows else None
        b_ub = np.array(b_ub_vals) if b_ub_vals else None

        # linprog minimises, so the reward enters with a minus sign. The
        # coefficient is the mean reward for stochastic rewards, otherwise
        # the per-episode average of total_reward.
        c = np.zeros(num_vars)
        for h in range(self.H - 1):
            X_h = self.Cmdp.layers[h]
            X_h1 = self.Cmdp.layers[h + 1]
            for x in X_h:
                for a in self.A:
                    if self.Cmdp.reward_type == "stochastic":
                        reward_val = self.Cmdp.reward_mean(x, a)
                    else:
                        reward_val = total_reward.get((x, a), 0.0) / self.K
                    for x_prime in X_h1:
                        key = (x, a, x_prime)
                        if key in var_map:
                            idx = var_map[key]
                            c[idx] = -reward_val

        result = linprog(c,
                         A_ub=A_ub, b_ub=b_ub,
                         A_eq=A_eq, b_eq=b_eq,
                         bounds=bounds,
                         method='highs',
                         options={'disp': False}
                         )
        q_opt = {}
        if result.success:
            optimal_q_vector = result.x
            for idx, val in enumerate(optimal_q_vector):
                key = rev_var_map[idx]
                q_opt[key] = val
        else:
            # Failure is only reported; the caller receives an empty dict.
            print(f"Error: {result.message}")
        return q_opt

    def compute_rho(self):
        """Compute the margin ``rho`` and the scaling ``alpha = 1 / (1 + rho)``.

        For the (x, a) pairs listed in ``self.Cmdp.det_actions`` the negated
        constraint value is looked up for every episode and constraint
        (first in ``self.constraints``, then in ``self.unseen_constr``) and
        the smallest one is taken. ``rho`` is that minimum, floored at 0.05.

        Returns:
            tuple ``(rho, alpha)``.
        """
        def negated_constraint(x, a, i, t):
            key = (x, a, i, t)
            return -self.constraints.get(key, self.unseen_constr.get(key))

        rho = 0.05
        for path in [self.Cmdp.det_actions]:
            worst_on_path = float("inf")
            for x, a in path:
                observed = [
                    negated_constraint(x, a, i, t)
                    for t in range(self.K)
                    for i in range(self.I)
                ]
                worst_for_pair = min([float("inf")] + observed)
                worst_on_path = min(worst_on_path, worst_for_pair)
                # Once the path cannot improve on rho, stop scanning it.
                if worst_on_path <= rho:
                    break
            rho = max(rho, worst_on_path)

        return rho, 1.0 / (1.0 + rho)

    def compute_regret(self):
        """Fill ``self.opt_term`` with the benchmark's cumulative reward.

        Steps:
          1. Sum, over all ``self.K`` episodes, the reward of every (x, a)
             pair in layers ``0 .. H-2``. Rewards that were not observed
             are queried from the environment and cached in
             ``self.unseen_reward``.
          2. Solve the benchmark LP (``compute_OPT``) and marginalise the
             solution over next states.
          3. Accumulate the benchmark's per-episode reward so that
             ``self.opt_term[k]`` holds the cumulative value up to and
             including episode ``k``. With non-stochastic constraints each
             term is scaled by ``alpha`` from ``compute_rho``.

        The per-episode reward is looked up with the episode index of the
        loop (not ``self.k``), and an unobserved reward is always replaced
        by the queried value rather than by 0, so the benchmark is valued
        with the same rewards that define its LP objective.
        """
        def reward_at(x, a, k):
            # Observed reward if available, otherwise the queried (and
            # cached) reward of that pair in episode k.
            key = (x, a, k)
            if key in self.reward:
                return self.reward[key]
            if key not in self.unseen_reward:
                self.unseen_reward[key] = self.Cmdp.get_reward(x, a, k)
            return self.unseen_reward[key]

        total_reward = defaultdict(float)
        for k in range(self.K):
            for h in range(self.H - 1):
                for x in self.Cmdp.layers[h]:
                    for a in self.A:
                        total_reward[(x, a)] += reward_at(x, a, k)

        q_star = self.compute_OPT(total_reward)
        q_star_marg = self.get_marginal_q(q_star)

        stochastic_constraints = self.Cmdp.constraint_type == "stochastic"
        stochastic_reward = self.Cmdp.reward_type == "stochastic"

        if not stochastic_constraints:
            _, alpha = self.compute_rho()

        self.opt_term[0] = 0.0
        for k in range(self.K):
            for (x, a), q_val in q_star_marg.items():
                if stochastic_constraints:
                    if stochastic_reward:
                        self.opt_term[k] += q_val * self.Cmdp.reward_mean(x, a)
                    else:
                        self.opt_term[k] += q_val * reward_at(x, a, k)
                else:
                    self.opt_term[k] += (q_val * reward_at(x, a, k)) * alpha
            # Seed the next entry so that opt_term is cumulative.
            self.opt_term[k + 1] = self.opt_term[k]

    def compute_cumul_viol(self, q, traj):
        """Add this episode's expected constraint values to ``self.violation``.

        For each constraint ``i`` the value of every (x, a) pair in layers
        ``0 .. H-2`` is taken from the trajectory when the pair was visited,
        otherwise from ``self.unseen_constr`` (querying the environment and
        caching the answer when it is not there yet). The inner product of
        these values with the state-action marginal of ``q`` is added to
        ``self.violation[i]``.

        A value that is already cached in ``self.unseen_constr`` is used as
        is; it is no longer treated as 0.

        Args:
            q: mapping ``(x, a, x_prime) -> mass``.
            traj: list of ``(x, a, x_next, reward, constraints)`` tuples for
                the current episode ``self.k``.

        Note:
            States and actions are used directly as array indices, so they
            must be valid integer indices below ``len(self.X)`` and
            ``len(self.A)`` respectively.
        """
        shape = (len(self.X), len(self.A))

        q_matrix = np.zeros(shape)
        for (x, a), val in self.get_marginal_q(q).items():
            q_matrix[x, a] = val

        for i in range(self.I):
            g_i = {}
            for x, a, _, _, constraints in traj:
                g_i[(x, a)] = constraints[i]

            for h in range(self.H - 1):
                for x in self.Cmdp.layers[h]:
                    for a in self.A:
                        if (x, a) in g_i:
                            continue
                        key = (x, a, i, self.k)
                        if key not in self.unseen_constr:
                            self.unseen_constr[key] = self.Cmdp.get_constraint(x, a, i, self.k)
                        # Use the cached value whether it was stored just
                        # now or earlier in the episode.
                        g_i[(x, a)] = self.unseen_constr[key]

            g_i_matrix = np.zeros(shape)
            for (x, a), val in g_i.items():
                g_i_matrix[x, a] = val

            self.violation[i] += np.sum(q_matrix * g_i_matrix)

    def run(self):
        """Run the greedy algorithm for ``self.K`` episodes.

        Each episode solves the LP on the current estimates, converts the
        solution to a policy, plays it, updates the statistics and
        estimates, and records the cumulative expected reward (under the
        true occupancy measure of the played policy) and the largest
        cumulative constraint violation. ``compute_regret`` is called once
        after the last episode.
        """
        self.init()
        q = self.init_uniform_q()
        print(f"-----Greedy Algorithm-----")

        for episode in range(self.K):
            self.k = episode
            print(f"episode {episode}")

            # Plan on the current estimates and derive the policy to play.
            q = self.update_q(q)
            policy = self.q_to_policy(q)

            if self.adv_data is not None:
                self.adv_data.get_adversarial_data(policy)

            # Interact with the environment and learn from the outcome.
            trajectory = self.play_policy(policy)
            self.update_model(trajectory)
            self.compute_estimates()

            # Record what was planned and what the policy really induces.
            true_occupancy = self.compute_true_q(policy)
            self.played_q_list.append(q)
            self.true_q_list.append(true_occupancy)

            episode_reward = 0.0
            for (x, a), weight in self.get_marginal_q(true_occupancy).items():
                key = (x, a, episode)
                if key not in self.reward:
                    self.unseen_reward[key] = self.Cmdp.get_reward(x, a, episode)
                if self.Cmdp.reward_type == "stochastic":
                    episode_reward += weight * self.Cmdp.reward_mean(x, a)
                else:
                    episode_reward += weight * self.reward.get(key, self.unseen_reward.get(key))

            self.cumul_reward += episode_reward
            self.cumul_reward_list.append(self.cumul_reward)

            self.compute_cumul_viol(true_occupancy, trajectory)
            self.viol_list.append(np.max(self.violation))

        self.compute_regret()
