#!/usr/bin/env python
# coding: utf-8

# In[1]:


from scipy.linalg import solve_continuous_are
import numpy as np


# In[3]:


class TSCE_MF:
    """Posterior-sampling learner for one player of an N-player linear-quadratic game.

    The simulated player (index ``id``) does not know the drift parameter ``A``
    (stored vectorized, with ``d**2`` entries). It keeps a Gaussian posterior
    over ``A``, re-samples a value from it whenever ``should_update`` ends an
    episode, acts on that sample, and accumulates its regret against the
    stationary cost ``llambda``. The other players enter only through their
    stationary means ``eta^j`` and matrices ``Upsilon^j``; ``f_i`` uses
    ``(Upsilon^j)^{-1}`` as player j's covariance.
    """

    def __init__(self, id, N, d, mu_0, Sigma_0, delta_t, sigma_list, x_bar_list, Q_list, R_list, A, X_0_id):
        """
        Parameters:
        - id: index (0 .. N-1) of the player simulated by this object
        - N: number of players
        - d: state dimension of each player; the vectorized parameter A_vec has d**2 entries
        - mu_0: prior mean of A_vec (column vector with d**2 entries)
        - Sigma_0: prior covariance of A_vec (shape [d**2, d**2])
        - delta_t: time step size for the discrete update
        - sigma_list: per-player diffusion matrices sigma^i (shape [d, d] each)
        - x_bar_list: per-player target vectors bar{x}^i (shape [N*d, 1] each)
        - Q_list: per-player state-cost matrices Q^i (shape [N*d, N*d] each)
        - R_list: per-player control-cost matrices R^i (shape [d, d] each)
        - A: true drift parameter A_vec (d**2 entries; row-major vectorization
          of the d x d drift matrix, matching compute_G)
        - X_0_id: initial state of player id (column vector with d entries)
        """
        self.id = id
        self.N = N
        self.d = d

        self.A = A  # true parameter (vectorized); unknown to the learner
        self.A_k = np.zeros((d ** 2, 1))  # parameter sampled for the current episode (vectorized)
        self.Q = [np.array(Q_list[i]) for i in range(N)]
        self.R = [np.array(R_list[i]) for i in range(N)]
        self.x_bar = [np.array(x_bar_list[i]) for i in range(N)]

        self.sigma = [np.array(sigma_list[i]) for i in range(N)]  # sigma^i for each player
        self.varsig = [0.5 * s @ s.T for s in self.sigma]  # varsigma^i = sigma^i (sigma^i)^T / 2

        # eta solves B eta = p; reshaped to one (d, 1) block per player.
        self.B = self.compute_global_B()
        self.p = self.compute_global_p()
        eta_flat = np.linalg.solve(self.B, self.p).reshape(-1, 1)
        self.eta = eta_flat.reshape(N, d, 1)

        # Upsilon^i solves the CARE with a = 0, b = varsig^i, r = 2 (R^i)^{-1},
        # i.e.  Upsilon varsig^i (R^i / 2) varsig^i Upsilon = C_i.
        A_mat = A.reshape(d, d)
        self.Upsilon = []
        for i in range(N):
            Q_ii = self.Q[i][i * d:(i + 1) * d, i * d:(i + 1) * d]
            C_i = 0.5 * A_mat.T @ self.R[i] @ A_mat + Q_ii
            Upsilon_i = solve_continuous_are(np.zeros((d, d)), self.varsig[i], C_i, 2 * np.linalg.inv(self.R[i]))
            self.Upsilon.append(Upsilon_i)

        self.mu_0 = mu_0.copy()  # prior mean of A_vec
        self.Sigma_0 = Sigma_0.copy()  # prior covariance of A_vec

        self.delta_t = delta_t  # time step
        self.T_k = 0  # length (in time units) of the last completed episode

        # Step index (not time t_k) at which the last posterior update happened.
        self.last_update_step = 0
        self.Sigma_k = Sigma_0.copy()  # posterior covariance at the start of the current episode
        self.mu_k = mu_0.copy()  # posterior mean at the start of the current episode
        self.k = 0  # number of posterior updates (completed episodes) so far

        self.integrand_cov = np.zeros((d ** 2, d ** 2))  # information accumulated since step 0
        self.integrand_cov_k = np.zeros((d ** 2, d ** 2))  # information accumulated in the current episode
        self.integrand_mean_k = np.zeros((d ** 2, 1))  # mean integral accumulated in the current episode

        self.X0 = X_0_id.copy()

        self.llambda = 0  # stationary cost lambda^i; set by run()
        self.regret = 0  # running regret
        self.regret_traj = []  # regret after every step of run()

    def compute_global_B(self):
        """
        Compute the global B matrix shared by all players.

        Block (i, j) is -Q^i_{ij}; diagonal blocks additionally subtract
        0.5 * A^T R^i A, where A is the true d x d drift matrix.

        Returns:
        - B: Nd x Nd matrix
        """
        d = self.d
        A_mat = self.A.reshape(d, d)
        B = np.zeros((self.N * d, self.N * d))
        for i in range(self.N):
            Q_i = self.Q[i]  # Q^i is specific to player i
            for j in range(self.N):
                Q_ij = Q_i[i * d:(i + 1) * d, j * d:(j + 1) * d]
                if i == j:  # diagonal block
                    minus = 0.5 * (A_mat.T @ self.R[i] @ A_mat)
                    B[i * d:(i + 1) * d, j * d:(j + 1) * d] += -Q_ij - minus
                else:  # off-diagonal block
                    B[i * d:(i + 1) * d, j * d:(j + 1) * d] += -Q_ij
        return B

    def compute_global_p(self):
        """
        Compute the global p vector shared by all players.

        Block i is -sum_j Q^i_{ij} bar{x}^{i,j}.

        Returns:
        - p: Nd x 1 vector
        """
        d = self.d
        p = np.zeros((self.N * d, 1))
        for i in range(self.N):
            Q_i = self.Q[i]
            x_bar_i = self.x_bar[i]
            for j in range(self.N):
                Q_ij = Q_i[i * d:(i + 1) * d, j * d:(j + 1) * d]
                x_bar_ij = x_bar_i[j * d:(j + 1) * d]  # j-th block of bar{x}^i
                p[i * d:(i + 1) * d] += -Q_ij @ x_bar_ij
        return p

    def compute_Gi(self, X_t, A_k):
        """
        Compute the control part of player id's drift under the sampled parameter A_k.

        Algebraically this equals -update_alpha(X_t, A_k).
        """
        i = self.id
        return -self.varsig[i] @ self.Upsilon[i] @ X_t + self.varsig[i] @ self.Upsilon[i] @ self.eta[i] - self.compute_G(X_t) @ A_k

    def compute_G(self, X_t):
        """
        Build the d x d**2 matrix G with G @ A_vec == A_mat @ X_t.

        Row r holds X_t^T in columns r*d .. (r+1)*d-1, i.e. A_vec is the
        row-major vectorization of A_mat.
        """
        G = np.zeros((self.d, self.d ** 2))
        for row in range(self.d):
            G[row, row * self.d:(row + 1) * self.d] = X_t.flatten()
        return G

    def compute_X_next(self, X_t, A_k):
        """
        Advance player id's state by one Euler-Maruyama step.

        Parameters:
        - X_t: current state of the player (d x 1)
        - A_k: sampled parameter used by the controller
        Returns:
        - X_next: next state, X_t + (Gi + G A) dt + N(0, dt sigma sigma^T)
        """
        i = self.id
        Gi = self.compute_Gi(X_t, A_k)
        G = self.compute_G(X_t)
        noise = np.random.multivariate_normal(np.zeros(self.d), self.delta_t * (self.sigma[i]) @ (self.sigma[i]).T).reshape(-1, 1)
        X_next = X_t + (Gi + G @ self.A) * self.delta_t + noise
        return X_next

    def update_alpha(self, X_t, A_k):
        """
        Return the control alpha_t^i = varsig^i Upsilon^i (X_t - eta^i) + G(X_t) A_k.
        """
        i = self.id
        alpha = self.varsig[i] @ self.Upsilon[i] @ (X_t - self.eta[i]) + self.compute_G(X_t) @ A_k
        return alpha

    def should_update(self, step, X_t, X_next, A_k):
        """
        Accumulate posterior statistics for one step and end the episode if due.

        The running covariance Sigma_t uses the information integral since step 0.
        The mean is only refreshed when an episode ends. An episode ends once it
        has lasted at least 1 time unit and either (a) it is at least one time
        unit longer than the previous episode, or (b) det(Sigma_t) has dropped
        below half of det(Sigma_k).

        Returns:
        - A new sample from the posterior if the episode ended, otherwise A_k.
        """
        i = self.id
        G = self.compute_G(X_t)
        Gi = self.compute_Gi(X_t, A_k)
        noise_cov_inv = np.linalg.inv(self.sigma[i] @ self.sigma[i].T)

        # Information gained over this step; feeds both the global and the per-episode integral.
        info_increment = G.T @ noise_cov_inv @ G * self.delta_t

        self.integrand_cov += info_increment
        Sigma_t = np.linalg.inv(np.eye(self.d ** 2) + self.Sigma_0 @ self.integrand_cov) @ self.Sigma_0
        # Symmetric in exact arithmetic; drop round-off asymmetry because it may
        # become the sampling covariance in sample_action.
        Sigma_t = 0.5 * (Sigma_t + Sigma_t.T)

        self.integrand_cov_k += info_increment
        residual = X_next - X_t - Gi * self.delta_t  # = G A dt + noise under the true dynamics
        self.integrand_mean_k += G.T @ noise_cov_inv @ residual

        elapsed = (step - self.last_update_step) * self.delta_t
        det_Sigma_t = np.linalg.det(Sigma_t)
        if elapsed >= 1 and (elapsed >= self.T_k + 1 or det_Sigma_t < 0.5 * np.linalg.det(self.Sigma_k)):
            self.mu_k_update = np.linalg.inv(np.eye(self.d ** 2) + self.Sigma_k @ self.integrand_cov_k) @ (
                self.mu_k + self.Sigma_k @ self.integrand_mean_k
            )
            self.mu_k = self.mu_k_update
            self.Sigma_k = Sigma_t

            self.T_k = elapsed
            self.last_update_step = step
            self.k += 1

            # Start a fresh episode: per-episode integrals back to zero arrays.
            self.integrand_cov_k = np.zeros((self.d ** 2, self.d ** 2))
            self.integrand_mean_k = np.zeros((self.d ** 2, 1))

            A_k = self.sample_action()
        return A_k

    def sample_action(self):
        """
        Sample A_vec from the current posterior N(mu_k, Sigma_k).

        A sample is accepted only if every eigenvalue of
        A - A_hat - varsig^i Upsilon^i has negative real part.

        Raises:
        - ValueError if no accepted sample is found within 100 attempts.
        """
        i = self.id
        max_attempts = 100
        for _ in range(max_attempts):
            A_vec = np.random.multivariate_normal(self.mu_k.flatten(), self.Sigma_k).reshape(-1, 1)
            A_hat = A_vec.reshape(self.d, self.d)  # vector to matrix

            M = self.A.reshape(self.d, self.d) - A_hat - self.varsig[i] @ self.Upsilon[i]
            eigvals = np.linalg.eigvals(M)

            if np.all(np.real(eigvals) < 0):
                return A_vec  # valid sample

        raise ValueError("Failed to sample a stable A_hat after max_attempts.")

    def f_i(self, X_t, alpha_t):
        """
        Calculate F^i(X_t, alpha_t, m^{-i}) for player i. X_t is only player i's state;
        m^{-i} are the other players' stationary distributions.

        Closed form:
            eta_prime = [..., eta^{i-1}, X_t^i, eta^{i+1}, ...]
            Upsilon_prime = blockdiag(..., (Upsilon^{i-1})^{-1}, 0, (Upsilon^{i+1})^{-1}, ...)
                (the i-th block is 0 because X_t^i is fixed)
            tilde{f}^i = tr(Q^i Upsilon_prime) + (eta_prime - bar{x}^i)^T Q^i (eta_prime - bar{x}^i)
        The result is tilde{f}^i + 0.5 * alpha_t^T R^i alpha_t.
        """
        i = self.id
        N = self.N
        d = self.d

        eta_copy = self.eta.copy()
        eta_copy[i] = X_t
        eta_prime = eta_copy.reshape(-1, 1)

        Upsilon_inverse = [
            np.zeros((d, d)) if j == i else np.linalg.inv(self.Upsilon[j])
            for j in range(N)
        ]
        Upsilon_prime = np.block([
            [Upsilon_inverse[j] if j == k else np.zeros((d, d)) for k in range(N)]
            for j in range(N)
        ])

        diff = eta_prime - self.x_bar[i]  # (Nd x 1)
        f_tilde = (diff.T @ self.Q[i] @ diff).item() + np.trace(self.Q[i] @ Upsilon_prime)

        control_cost = 0.5 * (alpha_t.T @ self.R[i] @ alpha_t).item()

        return f_tilde + control_cost

    def compute_F0(self):
        """
        Return F^i_0, the part of tilde{f}^i (see f_i) that does not depend on X^i.

        Because tilde{f}^i is quadratic in X^i, this equals f_i(0, 0). Terms:
          1. (x_bar_i^i)^T Q_ii^i x_bar_i^i
          2. - (x_bar_i^i)^T sum_{j != i} Q_ij^i (eta^j - x_bar_i^j)
          3. - sum_{j != i} (eta^j - x_bar_i^j)^T Q_ji^i x_bar_i^i
          4. + sum_{j != i, k != i, j != k} (eta^j - x_bar_i^j)^T Q_jk^i (eta^k - x_bar_i^k)
          5. + sum_{j != i} tr(Q_jj^i (Upsilon^j)^{-1}) + (eta^j - x_bar_i^j)^T Q_jj^i (eta^j - x_bar_i^j)
        """
        i, N, d = self.id, self.N, self.d

        Q_i = self.Q[i]  # Nd x Nd
        x_bar = self.x_bar[i]  # Nd x 1
        eta = self.eta  # (N, d, 1)

        Q_blocks = [[Q_i[j * d:(j + 1) * d, k * d:(k + 1) * d] for k in range(N)] for j in range(N)]
        x_blocks = [x_bar[j * d:(j + 1) * d] for j in range(N)]
        # Deviations eta^j - x_bar_i^j; the i-th entry is never used.
        dev = [eta[j] - x_blocks[j] for j in range(N)]
        others = [j for j in range(N) if j != i]

        # Term 1
        F0 = (x_blocks[i].T @ Q_blocks[i][i] @ x_blocks[i]).item()

        # Term 2 (array accumulator so N == 1 does not hit ndarray @ int)
        term2 = np.zeros((d, 1))
        for j in others:
            term2 = term2 + Q_blocks[i][j] @ dev[j]
        F0 -= (x_blocks[i].T @ term2).item()

        # Term 3
        term3 = np.zeros((1, d))
        for j in others:
            term3 = term3 + dev[j].T @ Q_blocks[j][i]
        F0 -= (term3 @ x_blocks[i]).item()

        # Term 4: cross terms between distinct other players (independent, so no covariance term)
        term4 = 0.0
        for j in others:
            for k in others:
                if j != k:
                    term4 += (dev[j].T @ Q_blocks[j][k] @ dev[k]).item()
        F0 += term4

        # Term 5: player j's own covariance is (Upsilon^j)^{-1}, matching f_i
        term5 = 0.0
        for j in others:
            term5 += np.trace(Q_blocks[j][j] @ np.linalg.inv(self.Upsilon[j]))
            term5 += (dev[j].T @ Q_blocks[j][j] @ dev[j]).item()
        F0 += term5

        return float(F0)

    def compute_lambda(self):
        """
        Compute lambda^i (or J_theta^i) for player i:
            F^i_0 - 0.5 eta_i^T Upsilon_i varsig_i R_i varsig_i Upsilon_i eta_i
                  + tr(varsig_i R_i varsig_i Upsilon_i) + tr(varsig_i R_i A)
        """
        i = self.id

        eta_i = self.eta[i]
        Upsilon_i = self.Upsilon[i]

        varsig_i = self.varsig[i]  # (d, d)
        R_i = self.R[i]  # (d, d)
        A = self.A.reshape(self.d, self.d)

        F_i_0 = self.compute_F0()

        middle_term = Upsilon_i @ varsig_i @ R_i @ varsig_i @ Upsilon_i  # (d, d)
        quadratic_term = 0.5 * (eta_i.T @ middle_term @ eta_i).item()

        trace_1 = np.trace(varsig_i @ R_i @ varsig_i @ Upsilon_i)
        trace_2 = np.trace(varsig_i @ R_i @ A)

        lambda_i = F_i_0 - quadratic_term + trace_1 + trace_2
        return lambda_i

    def update_regret(self, X_t, alpha_t, delta_t):
        """
        Add (F^i(X_t, alpha_t) - lambda^i) * delta_t to the regret and record it.
        """
        f_i_value = self.f_i(X_t, alpha_t)
        self.regret += f_i_value * delta_t - delta_t * self.llambda  # discrete approximation
        self.regret_traj.append(self.regret)

    def run(self, max_step):
        """
        Run the TSDE-MF algorithm for max_step discrete steps.

        Samples an initial parameter and computes lambda^i. Then, at every step,
        it advances the state, lets should_update decide whether to resample,
        and appends the running regret to self.regret_traj. The regret is
        evaluated at the post-step state.
        """
        self.A_k = self.sample_action()
        self.llambda = self.compute_lambda()
        X_t = self.X0

        for step in range(1, max_step + 1):
            X_next = self.compute_X_next(X_t, self.A_k)

            # Accumulate posterior statistics; resample A_k if the episode ends.
            self.A_k = self.should_update(step, X_t, X_next, self.A_k)

            X_t = X_next
            alpha_t = self.update_alpha(X_t, self.A_k)
            self.update_regret(X_t, alpha_t, self.delta_t)


# In[4]:


import matplotlib.pyplot as plt
import matplotlib


def run_experiment(num_runs, max_step, delta_t, para_list, std_scale=0.2):
    """
    Run TSCE-MF for the player para_list[2] across independent runs and aggregate the regret.

    Parameters:
    - num_runs: number of independent runs (must be >= 1)
    - max_step: steps per run
    - delta_t: time step
    - para_list: sequence whose first 11 entries are
      [N, d, id, mu_0, Sigma_0, sigma_list, x_bar_list, Q_list, R_list, A_true, X_0_id]
    - std_scale: factor applied to the across-run standard deviation before it
      is returned (default 0.2, the value previously hard-coded; it shrinks
      the shaded band in the plots)

    Returns:
    - T_values: time grid (delta_t, 2*delta_t, ..., max_step*delta_t)
    - regret_mean: mean regret across runs
    - regret_std: std_scale times the std of regret across runs

    Raises:
    - ValueError if num_runs < 1 (the statistics would be undefined).
    """
    if num_runs < 1:
        raise ValueError(f"num_runs must be at least 1, got {num_runs}")

    (N, d, player_id, mu_0, Sigma_0, sigma_list, x_bar_list,
     Q_list, R_list, A_true, X_0_id) = para_list[:11]

    regret_matrix = []
    for run_idx in range(num_runs):
        model = TSCE_MF(
            id=player_id,
            N=N,
            d=d,
            mu_0=mu_0,
            Sigma_0=Sigma_0,
            delta_t=delta_t,
            sigma_list=sigma_list,
            x_bar_list=x_bar_list,
            Q_list=Q_list,
            R_list=R_list,
            A=A_true,
            X_0_id=X_0_id,
        )
        print(f"\nRunning for run = {run_idx}")
        model.run(max_step)
        regret_matrix.append(model.regret_traj)

    regret_matrix = np.array(regret_matrix)  # (num_runs, max_step)
    regret_mean = regret_matrix.mean(axis=0)
    regret_std = std_scale * regret_matrix.std(axis=0)
    T_values = np.arange(1, max_step + 1) * delta_t

    return T_values, regret_mean, regret_std


def plot_multiple_curves(runs_list, max_step, delta_t, d, para_list):
    """
    Plot R(T) and R(T)/sqrt(T log T) vs T for several numbers of runs.

    Parameters:
    - runs_list: list of n (number of runs) to try, e.g., [10, 100, 200]
    - max_step: number of time steps
    - delta_t: time increment
    - d: kept for backward compatibility; not used by the sqrt(T log T) normalization
    - para_list: problem parameters forwarded to run_experiment

    Panel (b) is drawn only for T > 1: sqrt(T log T) is zero at T = 1 and
    undefined for T < 1, which previously produced inf/NaN points and
    RuntimeWarnings.
    """
    matplotlib.rcParams.update({'font.size': 10})
    fig, axs = plt.subplots(1, 2, figsize=(10, 3))

    for num_runs in runs_list:
        T_values, regret_mean, regret_std = run_experiment(num_runs, max_step, delta_t, para_list)

        label = f"$n = {num_runs}$"
        axs[0].plot(T_values, regret_mean, label=label)
        axs[0].fill_between(T_values, regret_mean - regret_std, regret_mean + regret_std, alpha=0.2)

        valid = T_values > 1
        T_scaled = T_values[valid]
        denom = np.sqrt(T_scaled * np.log(T_scaled))
        scaled_regret = regret_mean[valid] / denom
        scaled_std = regret_std[valid] / denom
        axs[1].plot(T_scaled, scaled_regret, label=label)
        axs[1].fill_between(T_scaled, scaled_regret - scaled_std, scaled_regret + scaled_std, alpha=0.2)

    axs[0].set_title(r"(a) $R(T)$ vs $T$")
    axs[0].set_xlabel(r"$T$")
    axs[0].set_ylabel(r"$R(T)$")
    axs[0].grid()
    axs[0].legend()

    axs[1].set_title(r"(b) $\frac{R(T)}{\sqrt{T \ \log(T)}}$ vs $T$")
    axs[1].set_xlabel(r"$T$")
    axs[1].set_ylabel(r"$\frac{R(T)}{\sqrt{T \ \log(T)}}$")
    axs[1].grid()
    axs[1].legend()

    plt.tight_layout()
    plt.show()


# In[9]:


N = 10  # number of players
d = 2  # state dimension per player
player_id = 3  # index of the simulated player (not named `id`, which would shadow the builtin)

mu_0 = np.zeros((d ** 2, 1))  # prior mean of the vectorized drift parameter
Sigma_0 = 0.5 * np.eye(d ** 2)  # prior covariance of the vectorized drift parameter

# NOTE: no RNG seed is set, so each execution draws a different problem instance.
sigma_list = [0.5 * np.eye(d) + 0.05 * np.random.randn(d, d) for _ in range(N)]
x_bar_list = [np.random.randn(N * d, 1) for _ in range(N)]

# Cost matrices: identity plus a small symmetric random perturbation of size epsilon.
epsilon = 0.05
Q_list = []
R_list = []
for i in range(N):
    Q_base = np.eye(N * d)
    Q_noise = epsilon * np.random.randn(N * d, N * d)
    Q_noise = (Q_noise + Q_noise.T) / 2
    Q_list.append(Q_base + Q_noise)

    R_base = np.eye(d)
    R_noise = epsilon * np.random.randn(d, d)
    R_noise = (R_noise + R_noise.T) / 2
    R_list.append(R_base + R_noise)

A_true = np.array([[0.5], [0], [0], [0.5]])  # true drift 0.5 * I, vectorized
X_0_id = np.array([[0], [0.5]])  # initial state of the simulated player

# Order must match the positional unpacking in run_experiment.
para_list = [N, d, player_id, mu_0, Sigma_0, sigma_list, x_bar_list, Q_list, R_list, A_true, X_0_id]


# In[11]:


# Coarse time grid: 5000 steps of 0.05 give a horizon of T = 250.
plot_multiple_curves(
    runs_list=[10, 50, 100],
    max_step=5000,
    delta_t=0.05,
    d=d,
    para_list=para_list,
)


# In[12]:


# Fine time grid: 50000 steps of 0.005 give the same horizon, T = 250.
plot_multiple_curves(
    runs_list=[10, 20, 50],
    max_step=50000,
    delta_t=0.005,
    d=d,
    para_list=para_list,
)


# In[ ]:




