import numpy as np
from collections import defaultdict
import cvxpy as cp
from scipy.optimize import linprog


class Algorithm:
    """Online learner for a loop-free (layered) constrained MDP.

    Each round the learner:
      1. turns its occupancy measure ``q`` into a policy and plays one episode;
      2. builds importance-weighted loss estimates (loss divided by an upper
         occupancy bound plus ``gamma``);
      3. refreshes empirical transition estimates and their confidence set;
      4. computes the next ``q`` by minimising
         ``eta * <loss_hat, q> + KL(q || q_prev)`` over occupancy measures
         consistent with the confidence set and with the bonus-adjusted
         constraint estimates.

    NOTE(review): the ``Cmdp`` object is used through ``X``, ``A``, ``L``,
    ``layers``, ``x0``, ``state_to_layer``, ``transitions`` (dict keyed
    ``f"{x}_{a}"`` -> dict keyed ``str(x')`` -> probability), ``reward_type``,
    ``constraint_type``, ``det_actions``, ``get_next_state``, ``get_reward``,
    ``get_constraint`` and ``constraint_mean``; its exact contract lives
    elsewhere -- confirm against the CMDP implementation.
    """

    def __init__(self, Cmdp, T, m, adv_data=None, delta=0.01):
        """Store the problem, horizon and hyper-parameters.

        Args:
            Cmdp: the constrained MDP to interact with.
            T: number of rounds (episodes).
            m: number of constraints.
            adv_data: optional object whose ``get_adversarial_data(policy)``
                is called before each round.
            delta: confidence parameter used in all log terms.
        """
        # Confidence-set bound parameters (cvxpy), filled by setup_q_problem.
        self.upper_bounds = {}
        self.lower_bounds = {}
        self.Cmdp = Cmdp
        self.X = Cmdp.X
        self.A = Cmdp.A
        self.T = T
        self.L = Cmdp.L
        self.m = m
        # Implicit-exploration parameter; also used as the OMD learning rate.
        self.gamma = np.sqrt(self.L * np.log((len(self.X) * len(self.A)) / delta) / (T * len(self.X) * len(self.A)))
        self.eta = self.gamma
        self.b = np.zeros(m)

        # (x, a, x') -> (lower, upper) transition-probability bounds.
        self.confidence_set = {}
        # Empirical transition probabilities and confidence radii.
        self.p_bar = defaultdict(float)
        self.eps = defaultdict(float)
        # Scratch values of the backward pass in compute_upperbound.
        self.f = defaultdict(float)

        # Visit counts: N_t[(x, a)] and M_t[(x, a, x')].
        self.N_t = defaultdict(int)
        self.M_t = defaultdict(int)

        self.delta = delta

        # Observed losses keyed (x, a, t); losses of unvisited pairs in unseen_loss.
        self.loss = defaultdict(float)
        self.unseen_loss = defaultdict(float)
        # Importance-weighted loss estimates of the latest round, keyed (x, a).
        self.est_loss = defaultdict(float)
        self.empirical_loss_mean = defaultdict(float)
        self.empirical_loss_sum = defaultdict(float)

        # Observed constraint values keyed (x, a, i, t); running estimates keyed
        # (x, a, i); values of unvisited pairs keyed (x, a, i, t).
        self.constraints = defaultdict(float)
        self.est_constraints = defaultdict(float)
        self.unseen_constr = defaultdict(float)

        self.policies = []

        self.cumul_reward_list = []
        # Cumulative benchmark reward, filled by compute_regret.
        self.opt_term = defaultdict(float)
        self.viol_list = []
        self.cumul_reward = 0.0
        self.t = 0
        self.regret = 0.0
        self.violation = np.zeros(self.m)

        # cvxpy objects built by setup_q_problem.
        self.q_var = None
        self.q_hat_flat = None
        self.loss_vector = None
        self.q_objective = None
        self.q_problem = None

        self.played_q_list = []
        self.true_q_list = []

        self.adv_data = adv_data

    def init_uniform_q(self):
        """Return the uniform occupancy measure.

        Every triple ``(x, a, x')`` between layers ``k`` and ``k + 1`` gets
        mass ``1 / (|X_k| * |A| * |X_{k+1}|)``, so each layer sums to 1.
        """
        q = defaultdict(float)
        for k in range(self.L - 1):
            X_k = self.Cmdp.layers[k]
            X_k1 = self.Cmdp.layers[k + 1]
            mass = 1.0 / (len(X_k) * len(self.A) * len(X_k1))
            for x in X_k:
                for a in self.A:
                    for x_prime in X_k1:
                        q[(x, a, x_prime)] = mass
        return q

    def q_to_policy(self, q):
        """Convert an occupancy measure into a stochastic policy.

        ``policy[(x, a)] = sum_x' q(x, a, x') / sum_{a, x'} q(x, a, x')``.
        For a state whose total mass is zero, every action that appears in
        ``q`` gets probability ``1 / |A|``.
        """
        denom = defaultdict(float)
        for (x, _a, _x_prime), val in q.items():
            denom[x] += val
        policy = {}
        for (x, a, _x_prime), val in q.items():
            if denom[x] == 0:
                policy[(x, a)] = 1.0 / len(self.A)
            else:
                policy[(x, a)] = policy.get((x, a), 0.0) + val / denom[x]
        return policy

    def play_policy(self, policy):
        """Play one episode with ``policy`` and record the round's feedback.

        Side effects:
            - ``self.loss[(x, a, t)]`` and ``self.constraints[(x, a, i, t)]``
              are set for visited pairs.
            - ``self.est_constraints`` and ``self.empirical_loss_sum`` are
              updated.
            - ``self.est_loss`` is replaced by this round's importance-weighted
              estimates.
            - For adversarial rewards, losses of unvisited pairs are recorded in
              ``self.unseen_loss``. Under adversarial constraints, their
              constraint values are also recorded in ``self.unseen_constr``.

        Returns:
            list of ``(x, a, x_next, loss, constraints)`` tuples, one per layer.
        """
        traj = []
        x = self.Cmdp.x0
        uniform = np.full(len(self.A), 1.0 / len(self.A))

        est_loss = defaultdict(float)
        for _ in range(self.L - 1):
            probs = np.array([policy.get((x, a), 0.0) for a in self.A], dtype=float)
            total = probs.sum()
            # An all-zero (or non-finite) row would normalise to NaNs and make
            # np.random.choice raise; fall back to the uniform distribution.
            probs = probs / total if np.isfinite(total) and total > 0 else uniform

            a = np.random.choice(self.A, p=probs)
            x_next = self.Cmdp.get_next_state(x, a)
            loss = 1 - self.Cmdp.get_reward(x, a, self.t)
            self.loss[(x, a, self.t)] = loss

            # Importance weighting by the upper occupancy bound (plus gamma).
            u_t = self.compute_upperbound(x, a, policy)
            est_loss[(x, a)] = loss / (u_t + self.gamma)

            constraints = []
            for i in range(self.m):
                constraint_value = self.Cmdp.get_constraint(x, a, i, self.t)
                constraints.append(constraint_value)
                self.constraints[(x, a, i, self.t)] = constraint_value

            for i in range(self.m):
                beta = self.compute_beta(x, a, i)
                key = (x, a, i)
                self.est_constraints[key] = (1 - beta) * self.est_constraints.get(key, 0) + beta * constraints[i]

            self.empirical_loss_sum[(x, a)] += np.clip(est_loss[(x, a)], 0.0, 1.0)

            traj.append((x, a, x_next, loss, constraints))
            x = x_next

        if self.Cmdp.reward_type == "adversarial":
            for k in range(self.L - 1):
                for x_ in self.Cmdp.layers[k]:
                    for a_ in self.A:
                        if (x_, a_, self.t) not in self.loss:
                            self.unseen_loss[(x_, a_, self.t)] = 1 - self.Cmdp.get_reward(x_, a_, self.t)
                            if self.Cmdp.constraint_type == "adversarial":
                                for i in range(self.m):
                                    key = (x_, a_, i, self.t)
                                    self.unseen_constr[key] = self.Cmdp.get_constraint(x_, a_, i, self.t)

        self.est_loss = est_loss
        return traj

    def init_counters(self):
        """Reset all visit counters ``N_t`` and ``M_t`` to zero."""
        for x in self.X:
            for a in self.A:
                self.N_t[(x, a)] = 0
                for x_prime in self.X:
                    self.M_t[(x, a, x_prime)] = 0

    def update_counters(self, traj):
        """Increment visit counters with the transitions in ``traj``."""
        for x, a, x_next, _, _ in traj:
            self.N_t[(x, a)] += 1
            self.M_t[(x, a, x_next)] += 1

    def compute_Gamma(self, i):
        """Return the clipped excess of constraint ``i`` over a concentration threshold.

        The value is ``min(max(0, S - thr), thr)``, where ``S`` sums the
        observed values of constraint ``i`` in round ``self.t``.

        NOTE(review): ``S`` only sums round ``self.t`` -- confirm whether a
        cumulative sum was intended.
        """
        log_term = np.log((2 * self.m * (self.T ** 2) * len(self.X) * len(self.A)) / self.delta)
        threshold = 19 * self.L * len(self.X) * np.sqrt(2 * self.t * len(self.A) * log_term)

        violations_sum = sum(self.constraints.get((x, a, i, self.t), 0) for x in self.X for a in self.A)
        return min(max(0, violations_sum - threshold), threshold)

    def compute_beta(self, x, a, i):
        """Step size for the running estimate of constraint ``i`` at ``(x, a)``."""
        return (1 / max(1, self.N_t[(x, a)])) * (1 + self.compute_Gamma(i))

    def greedy(self, p_bar, eps, X_k):
        """Maximise ``sum_x p(x) f(x)`` over ``p`` within ``eps`` of ``p_bar``.

        Mass is repeatedly moved from the state with the lowest ``self.f`` to
        the one with the highest, limited by the remaining radius and by
        ``[0, 1]``. ``p_bar`` and ``eps`` are modified in place.

        Returns:
            the maximised expectation of ``self.f`` over ``X_k``.
        """
        n = len(X_k)
        sigma = sorted(X_k, key=lambda s: self.f[s])
        j_minus, j_plus = 1, n

        while j_minus < j_plus:
            x_minus = sigma[j_minus - 1]
            x_plus = sigma[j_plus - 1]

            delta_minus = min(p_bar[x_minus], eps[x_minus])
            delta_plus = min(1 - p_bar[x_plus], eps[x_plus])
            shift = min(delta_minus, delta_plus)

            p_bar[x_minus] -= shift
            p_bar[x_plus] += shift

            if delta_minus <= delta_plus:
                eps[x_plus] -= shift
                j_minus += 1
            else:
                eps[x_minus] -= shift
                j_plus -= 1

        return sum(p_bar[s] * self.f[s] for s in sigma)

    def compute_upperbound(self, x, a, policy):
        """Return an upper bound on the occupancy of ``(x, a)`` under ``policy``.

        Backward pass from the layer of ``x``. ``f`` starts as the indicator of
        ``x`` on that layer. For each earlier state ``s``, ``f(s)`` becomes the
        policy-weighted sum over actions of ``greedy``'s value: the largest
        expected ``f`` over transitions within ``eps`` of ``p_bar``. The result
        is ``policy(a | x) * f(x0)``. Pairs missing from ``policy`` are treated
        as probability 0.
        """
        kx = self.Cmdp.state_to_layer[x]

        for xt in self.Cmdp.layers[kx]:
            self.f[xt] = 1.0 if xt == x else 0.0

        for k in range(kx - 1, -1, -1):
            X_k1 = self.Cmdp.layers[k + 1]
            for xt in self.Cmdp.layers[k]:
                val = 0.0
                for a_ in self.A:
                    # greedy mutates its inputs, so build fresh dicts per call.
                    p_bar = {s: float(self.p_bar.get((xt, a_, s), 0.0)) for s in X_k1}
                    eps = {s: float(self.eps.get((xt, a_, s), 0.0)) for s in X_k1}
                    val += policy.get((xt, a_), 0.0) * self.greedy(p_bar, eps, X_k1)
                self.f[xt] = val
        return policy.get((x, a), 0.0) * self.f.get(self.Cmdp.x0, 0.0)

    @staticmethod
    def unn_KL_div(q_1, q_2):
        """Unnormalised KL divergence ``sum q1 log(q1/q2) - q1 + q2`` (cvxpy expr).

        ``cp.kl_div(x, y)`` already equals ``x log(x/y) - x + y`` elementwise,
        so no extra linear correction term is needed.
        """
        return cp.sum(cp.kl_div(q_1, q_2))

    def get_bonus(self, x, a):
        """Confidence bonus subtracted from the constraint estimate of ``(x, a)``."""
        log_term = np.log((2 * self.m * len(self.X) * len(self.A) * self.T) / self.delta)
        return np.sqrt(2 * log_term / max(1, self.N_t[(x, a)]))

    def setup_q_problem(self):
        """Build the parametrised cvxpy problem solved by ``update_q``.

        Variables are ``q(x, a, x') >= 0`` for all consecutive-layer triples.
        Constraints are:
            - flow conservation on inner layers;
            - unit mass per layer;
            - confidence-set bounds ``lb * sum_x' q <= q <= ub * sum_x' q``,
              only when the next layer has more than one state;
            - one linear constraint ``sum c_i * q <= 0`` per constraint ``i``.
        The objective is ``eta * <loss, q> + KL(q || q_hat)``. Bounds,
        constraint coefficients, losses and ``q_hat`` are cvxpy Parameters, so
        the problem is built once and re-solved every round.
        """
        A = self.A
        self.q_var = {}
        self.q_hat = {}
        self.loss_vector = {}
        self.ci_constraints = []
        self.constraint_params = {}
        self.q_keys_ordered = []

        constraints = []

        for k in range(self.L - 1):
            X_k = self.Cmdp.layers[k]
            X_k1 = self.Cmdp.layers[k + 1]
            for x in X_k:
                for a in A:
                    for x_prime in X_k1:
                        key = (x, a, x_prime)
                        self.q_var[key] = cp.Variable(nonneg=True)
                        self.q_keys_ordered.append(key)

        # Flow conservation: mass entering x equals mass leaving x.
        for k in range(1, self.L - 1):
            X_k = self.Cmdp.layers[k]
            X_prev = self.Cmdp.layers[k - 1]
            X_next = self.Cmdp.layers[k + 1]
            for x in X_k:
                incoming = cp.sum([self.q_var[(x_prev, a, x)]
                                   for x_prev in X_prev for a in A if (x_prev, a, x) in self.q_var])
                outgoing = cp.sum([self.q_var[(x, a, x_next)]
                                   for x_next in X_next for a in A if (x, a, x_next) in self.q_var])
                constraints.append(incoming == outgoing)

        # Per-layer normalisation and confidence-set bounds.
        for k in range(self.L - 1):
            total = 0
            X_k = self.Cmdp.layers[k]
            X_k1 = self.Cmdp.layers[k + 1]
            for x in X_k:
                for a in A:
                    total_out = cp.sum([self.q_var[(x, a, x_prime)]
                                        for x_prime in X_k1 if (x, a, x_prime) in self.q_var])
                    for x_prime in X_k1:
                        if (x, a, x_prime) in self.q_var:
                            # With a single successor the ratio is 1, so no bound is needed.
                            if len(X_k1) > 1:
                                lb_param = cp.Parameter(nonneg=True)
                                ub_param = cp.Parameter(nonneg=True)

                                self.lower_bounds[(x, a, x_prime)] = lb_param
                                self.upper_bounds[(x, a, x_prime)] = ub_param

                                constraints.append(self.q_var[(x, a, x_prime)] >= lb_param * total_out)
                                constraints.append(self.q_var[(x, a, x_prime)] <= ub_param * total_out)
                            total += self.q_var[(x, a, x_prime)]
            constraints.append(total == 1)

        # Bonus-adjusted constraint estimates: sum_{x,a,x'} c_i(x,a) q(x,a,x') <= 0.
        for i in range(self.m):
            expr = 0
            for k in range(self.L - 1):
                X_k = self.Cmdp.layers[k]
                X_k1 = self.Cmdp.layers[k + 1]
                for x in X_k:
                    for a in A:
                        for x_prime in X_k1:
                            if (x, a, x_prime) in self.q_var:
                                param = cp.Parameter()
                                self.constraint_params[(i, x, a, x_prime)] = param
                                expr += param * self.q_var[(x, a, x_prime)]
            constraints.append(expr <= 0)

        q_var_list = [self.q_var[key] for key in self.q_keys_ordered]
        self.q_hat_flat = cp.Parameter(len(q_var_list))
        self.loss_vector_flat = cp.Parameter(len(q_var_list))
        q_flat = cp.hstack(q_var_list)
        loss_expr = cp.sum(cp.multiply(self.loss_vector_flat, q_flat))

        self.q_objective = cp.Minimize(self.eta * loss_expr + self.unn_KL_div(q_flat, self.q_hat_flat))
        self.q_problem = cp.Problem(self.q_objective, constraints)

    def update_q(self, q_hat, solver=None):
        """Run one KL-regularised update from ``q_hat`` and return the new ``q``.

        Loads the latest loss estimates, the bonus-adjusted constraint
        coefficients ``est_constraints - bonus`` and the confidence-set bounds
        into the Parameters created by ``setup_q_problem``, then solves.

        Args:
            q_hat: previous occupancy measure, keyed by ``(x, a, x')``.
            solver: cvxpy solver name; defaults to ``cp.MOSEK``.

        Returns:
            dict ``(x, a, x') -> float`` with the new occupancy. Tiny negative
            round-off from the solver is clipped to 0, so the next KL step
            stays in its domain.

        Raises:
            RuntimeError: if the solver status is not optimal/optimal_inaccurate.
        """
        loss_list = []
        q_hat_list = []
        last_loss = self.est_loss

        for key in self.q_keys_ordered:
            x, a, _x_prime = key
            q_hat_list.append(q_hat.get(key, 0.0))
            loss_list.append(last_loss.get((x, a), 0.0))

        self.q_hat_flat.value = np.array(q_hat_list)
        self.loss_vector_flat.value = np.array(loss_list)

        for (i, x, a, _x_prime), param in self.constraint_params.items():
            est_c = self.est_constraints.get((x, a, i), 0.0)
            param.value = est_c - self.get_bonus(x, a)

        for key, (lb_val, ub_val) in self.confidence_set.items():
            if key in self.lower_bounds:
                self.lower_bounds[key].value = lb_val
                self.upper_bounds[key].value = ub_val

        solver = cp.MOSEK if solver is None else solver
        self.q_problem.solve(solver=solver, verbose=False)

        if self.q_problem.status not in [cp.OPTIMAL, cp.OPTIMAL_INACCURATE]:
            raise RuntimeError(f"CVXPY failed with status {self.q_problem.status}")

        return {key: max(0.0, float(self.q_var[key].value)) for key in self.q_keys_ordered}

    def init_confidence_set(self):
        """Initialise the trivial confidence set ``[0, 1]`` and empirical estimates.

        ``p_bar`` is set to ``M_t / max(1, N_t)`` and ``eps`` to 0 for every
        consecutive-layer triple.
        """
        self.confidence_set = {}

        for k in range(self.L - 1):
            X_k1 = self.Cmdp.layers[k + 1]
            for x in self.Cmdp.layers[k]:
                for a in self.A:
                    for x_prime in X_k1:
                        key = (x, a, x_prime)
                        self.confidence_set[key] = (0.0, 1.0)
                        self.p_bar[key] = self.M_t[key] / max(1, self.N_t[(x, a)])
                        self.eps[key] = 0

    def update_confidence_set(self):
        """Recompute ``p_bar``, the Bernstein-style radius ``eps`` and the bounds.

        ``eps = 2 sqrt(p_bar * log / max(1, N-1)) + 14 log / (3 max(1, N-1))``
        with ``log = ln(T |X| |A| / delta)``. Bounds are clipped to ``[0, 1]``.
        """
        log_term = np.log(self.T * len(self.X) * len(self.A) / self.delta)
        P_t = {}
        for k in range(self.L - 1):
            X_k1 = self.Cmdp.layers[k + 1]
            for x in self.Cmdp.layers[k]:
                for a in self.A:
                    n_xa = self.N_t[(x, a)]
                    n_minus = max(1, n_xa - 1)
                    for x_prime in X_k1:
                        key = (x, a, x_prime)
                        p_bar = self.M_t[key] / max(1, n_xa)
                        eps = (
                            2 * np.sqrt(p_bar * log_term / n_minus) +
                            (14 * log_term) / (3 * n_minus)
                        )

                        P_t[key] = (max(0, p_bar - eps), min(1, p_bar + eps))
                        self.p_bar[key] = p_bar
                        self.eps[key] = eps

        self.confidence_set = P_t

    def compute_cumul_viol(self, q, traj):
        """Add this round's violation ``sum q(x,a,.) g_i(x,a)`` to ``self.violation``.

        ``g_i(x, a)`` comes from ``traj`` for visited pairs. For the other
        pairs it comes from ``self.unseen_constr``, which is queried from the
        CMDP and cached on first use.

        NOTE(review): ``q_matrix[x, a]`` indexing assumes states and actions
        are integers in ``[0, |X|)`` and ``[0, |A|)``.
        """
        q_matrix = np.zeros((len(self.X), len(self.A)))
        for (x, a, _x_prime), val in q.items():
            q_matrix[x, a] += val

        for i in range(self.m):
            g_i = {}
            for x, a, _, _, constraints in traj:
                g_i[(x, a)] = constraints[i]

            for k in range(self.L - 1):
                for x in self.Cmdp.layers[k]:
                    for a in self.A:
                        if (x, a) in g_i:
                            continue
                        key = (x, a, i, self.t)
                        if key not in self.unseen_constr:
                            self.unseen_constr[key] = self.Cmdp.get_constraint(x, a, i, self.t)
                        # Use the cached value even when it was recorded earlier
                        # (e.g. by play_policy); otherwise the pair counts as 0.
                        g_i[(x, a)] = self.unseen_constr[key]

            g_i_matrix = np.zeros((len(self.X), len(self.A)))
            for (x, a), val in g_i.items():
                g_i_matrix[x, a] = val

            self.violation[i] += np.sum(q_matrix * g_i_matrix)

    def compute_true_q(self, policy):
        """Return the occupancy measure of ``policy`` under the true transitions.

        Forward pass from ``x0``. Missing policy entries or transition
        probabilities are treated as 0.
        """
        q = defaultdict(float)
        mu = defaultdict(float)
        mu[self.Cmdp.x0] = 1.0

        for k in range(self.L - 1):
            X_k1 = self.Cmdp.layers[k + 1]
            mu_next = defaultdict(float)

            for x in self.Cmdp.layers[k]:
                for a in self.A:
                    pi_val = policy.get((x, a), 0.0)
                    P_x_a = self.Cmdp.transitions.get(f"{x}_{a}", {})
                    for x_prime in X_k1:
                        q_val = mu[x] * pi_val * P_x_a.get(str(x_prime), 0.0)
                        q[(x, a, x_prime)] += q_val
                        mu_next[x_prime] += q_val
            mu = mu_next
        return q

    def compute_OPT(self, total_reward):
        """Solve the LP for the best fixed occupancy measure in hindsight.

        Maximises ``sum q(x,a,x') * total_reward[(x,a)] / T`` over ``q >= 0``
        subject to:
            - flow conservation;
            - unit mass per layer;
            - consistency with ``self.Cmdp.transitions``;
            - for stochastic constraints, ``sum q * constraint_mean <= 0``.
        Missing rewards or transition probabilities are treated as 0.

        Returns:
            dict ``(x, a, x') -> occupancy`` (clipped at 0). On LP failure the
            solver message is printed and an empty dict is returned.
        """
        var_map = {}
        rev_var_map = []

        for k in range(self.L - 1):
            X_k1 = self.Cmdp.layers[k + 1]
            for x in self.Cmdp.layers[k]:
                for a in self.A:
                    for x_prime in X_k1:
                        key = (x, a, x_prime)
                        var_map[key] = len(rev_var_map)
                        rev_var_map.append(key)

        num_vars = len(rev_var_map)
        A_eq_rows = []
        b_eq_vals = []
        A_ub_rows = []
        b_ub_vals = []

        # Flow conservation on inner layers.
        for k in range(1, self.L - 1):
            X_k = self.Cmdp.layers[k]
            X_prev = self.Cmdp.layers[k - 1]
            X_next = self.Cmdp.layers[k + 1]

            for x in X_k:
                row = np.zeros(num_vars)
                for x_prev in X_prev:
                    for a in self.A:
                        key = (x_prev, a, x)
                        if key in var_map:
                            row[var_map[key]] += 1
                for x_next in X_next:
                    for a in self.A:
                        key = (x, a, x_next)
                        if key in var_map:
                            row[var_map[key]] -= 1
                A_eq_rows.append(row)
                b_eq_vals.append(0.0)

        # Unit mass per layer.
        for k in range(self.L - 1):
            row = np.zeros(num_vars)
            X_k1 = self.Cmdp.layers[k + 1]
            for x in self.Cmdp.layers[k]:
                for a in self.A:
                    for x_prime in X_k1:
                        key = (x, a, x_prime)
                        if key in var_map:
                            row[var_map[key]] += 1
            A_eq_rows.append(row)
            b_eq_vals.append(1.0)

        # q(x,a,x') = P(x'|x,a) * sum_xp q(x,a,xp)  (true transitions).
        for k in range(self.L - 1):
            X_k1 = self.Cmdp.layers[k + 1]
            if len(X_k1) > 1:
                for x in self.Cmdp.layers[k]:
                    for a in self.A:
                        p_true = self.Cmdp.transitions.get(f"{x}_{a}", {})
                        for x_prime in X_k1:
                            if (x, a, x_prime) in var_map:
                                p_val = p_true.get(str(x_prime), 0.0)
                                row = np.zeros(num_vars)
                                row[var_map[(x, a, x_prime)]] -= 1
                                for xp in X_k1:
                                    if (x, a, xp) in var_map:
                                        row[var_map[(x, a, xp)]] += p_val
                                A_eq_rows.append(row)
                                b_eq_vals.append(0.0)

        if self.Cmdp.constraint_type == "stochastic":
            for i in range(self.m):
                row = np.zeros(num_vars)
                for k in range(self.L - 1):
                    X_k1 = self.Cmdp.layers[k + 1]
                    for x in self.Cmdp.layers[k]:
                        for a in self.A:
                            mean_constraint = self.Cmdp.constraint_mean(x, a, i)
                            for x_prime in X_k1:
                                key = (x, a, x_prime)
                                if key in var_map:
                                    row[var_map[key]] += mean_constraint
                A_ub_rows.append(row)
                b_ub_vals.append(0.0)

        bounds = [(0, None)] * num_vars

        A_eq = np.array(A_eq_rows) if A_eq_rows else None
        b_eq = np.array(b_eq_vals) if b_eq_vals else None
        A_ub = np.array(A_ub_rows) if A_ub_rows else None
        b_ub = np.array(b_ub_vals) if b_ub_vals else None

        # linprog minimises, so negate the average reward.
        c = np.zeros(num_vars)
        for k in range(self.L - 1):
            X_k1 = self.Cmdp.layers[k + 1]
            for x in self.Cmdp.layers[k]:
                for a in self.A:
                    reward_val = total_reward.get((x, a), 0.0) / self.T
                    for x_prime in X_k1:
                        key = (x, a, x_prime)
                        if key in var_map:
                            c[var_map[key]] = -reward_val

        result = linprog(c,
                         A_ub=A_ub, b_ub=b_ub,
                         A_eq=A_eq, b_eq=b_eq,
                         bounds=bounds,
                         method='highs',
                         options={'disp': False})

        q_opt = {}
        if result.success:
            for idx, val in enumerate(result.x):
                q_opt[rev_var_map[idx]] = max(0, val)
        else:
            print(f"Error: {result.message}")
        return q_opt

    def compute_rho(self):
        """Return ``(rho, alpha)`` with ``alpha = 1 / (1 + rho)``.

        ``rho`` is the largest, over the CMDP's ``det_actions`` pairs, of the
        minimum over rounds and constraints of ``-g``. It is floored at 0.05.
        """
        def neg_g(x, a, i, t):
            return - self.constraints.get((x, a, i, t), self.unseen_constr.get((x, a, i, t)))

        rho_min = 0.05
        rho = rho_min

        det_actions = {tuple(pair) for pair in self.Cmdp.det_actions}

        for (x, a) in det_actions:
            pair_min = float('inf')
            for t in range(self.T):
                for i in range(self.m):
                    pair_min = min(pair_min, neg_g(x, a, i, t))
            rho = max(rho, pair_min)

        alpha = 1.0 / (1.0 + rho)
        return rho, alpha

    def compute_regret(self):
        """Compute the cumulative benchmark ``opt_term`` and set ``self.regret``.

        Per-pair total rewards over all rounds are accumulated. Losses of pairs
        that were never observed are queried once and cached in
        ``unseen_loss``. The best fixed occupancy ``q_star`` from
        ``compute_OPT`` is then evaluated round by round:
        ``opt_term[t + 1]`` holds the running total after round ``t``.
        For non-stochastic constraints the benchmark is scaled by ``alpha``
        from ``compute_rho``. Finally ``self.regret = opt_term[T] - cumul_reward``.
        """
        total_reward = defaultdict(float)

        for t in range(self.T):
            for k in range(self.L - 1):
                for x in self.Cmdp.layers[k]:
                    for a in self.A:
                        key = (x, a, t)
                        # Check and fill the cache for round t itself so each
                        # round's loss is queried at most once.
                        if key not in self.loss and key not in self.unseen_loss:
                            self.unseen_loss[key] = 1 - self.Cmdp.get_reward(x, a, t)

                        loss = self.loss.get(key, self.unseen_loss.get(key))
                        total_reward[(x, a)] += (1 - loss)

        q_star = self.compute_OPT(total_reward)

        scale = 1.0
        if self.Cmdp.constraint_type != "stochastic":
            _rho, alpha = self.compute_rho()
            scale = alpha

        self.opt_term[0] = 0.0
        for t in range(self.T):
            for (x, a, _xp), q_val in q_star.items():
                loss = self.loss.get((x, a, t), self.unseen_loss.get((x, a, t)))
                self.opt_term[t] += (q_val * (1 - loss)) * scale
            self.opt_term[t + 1] = self.opt_term[t]

        self.regret = self.opt_term[self.T] - self.cumul_reward

    def step(self, q):
        """Play one round with occupancy ``q`` and return the updated occupancy."""
        policy = self.q_to_policy(q)
        traj = self.play_policy(policy)
        self.policies.append(policy)

        if self.adv_data is not None:
            self.adv_data.get_adversarial_data(policy)

        q_pi_curr = self.compute_true_q(policy)

        # Expected reward of the played policy under the true occupancy.
        incremental_reward = 0.0
        for (x, a, _xp), q_val in q_pi_curr.items():
            key = (x, a, self.t)
            if key not in self.loss and key not in self.unseen_loss:
                self.unseen_loss[key] = 1 - self.Cmdp.get_reward(x, a, self.t)
            loss = self.loss.get(key, self.unseen_loss.get(key))
            incremental_reward += q_val * (1 - loss)

        self.cumul_reward += incremental_reward
        self.cumul_reward_list.append(self.cumul_reward)

        self.compute_cumul_viol(q_pi_curr, traj)
        self.viol_list.append(np.max(self.violation))

        self.update_counters(traj)
        self.update_confidence_set()

        q_next = self.update_q(q)

        self.played_q_list.append(q)
        self.true_q_list.append(q_pi_curr)

        return q_next

    def run(self):
        """Run all ``T`` rounds from the uniform occupancy, then compute regret."""
        print(f"-----ALGORITHM 1-----")
        q = self.init_uniform_q()
        self.init_counters()
        self.init_confidence_set()
        self.setup_q_problem()

        if self.adv_data is not None:
            self.adv_data.get_adversarial_data(self.q_to_policy(q))

        for t in range(self.T):
            self.t = t
            print(f"round {t}")
            q = self.step(q)

        self.compute_regret()