import numpy as np
import pdb
import sys
from ARMUL import ARMUL
from preprocessing import split

class experiment_synthetic:  # for synthetic data experiments
    """Synthetic multi-task linear-bandit experiment evaluated with ARMUL.

    ``m`` tasks each own a ``d``-dimensional parameter vector, stored column-wise
    in ``self.Theta`` (shape ``(d, m)``). ``getsamples`` draws the parameters and
    ``n`` offline observations per task; ``run`` fits ARMUL and records per-task
    estimation errors.
    """

    def __init__(self,  n = 10, m = 1000, d = 20):
        # n: samples per task, m: number of tasks, d: feature dimension.
        self.n, self.m, self.d = n, m, d

    def generate_items(self, num_items, d):
        """Return a ``(num_items, d)`` array of random item vectors.

        The first ``d - 1`` coordinates are a random Gaussian direction scaled to
        norm ``1/sqrt(2)``; the last coordinate is the constant ``1/sqrt(2)``.
        Every row therefore has unit Euclidean norm.
        """
        factor = 1 / np.sqrt(2)
        x = np.random.normal(0, 1, (num_items, d - 1))
        norms = np.linalg.norm(x, axis=1, keepdims=True)
        x = x / norms * factor
        ones_column = np.ones((num_items, 1)) * factor
        return np.hstack((x, ones_column))

    def generate_subgaussian_noise(self, shape, scale=0.05):
        """
        Generate sub-Gaussian noise of specified dimensions
        shape: Shape of noise, such as (n,) or (n, 1)
        scale: Standard deviation of noise
        """
        return np.random.normal(loc=0, scale=scale, size=shape)

    @staticmethod
    def _beta(d, N, t):
        """Confidence-radius multiplier for the LinUCB score after ``N`` samples at round ``t``."""
        return np.sqrt(d * np.log(1 + N / max(d, 1)) + 4 * np.log(max(t, 2)) + np.log(2)) + 1

    @staticmethod
    def _ucb_bonus(items, Sinv):
        """Return ``||x||_{Sinv} = sqrt(x^T Sinv x)`` for every row ``x`` of ``items``."""
        quad = (items @ Sinv * items).sum(axis=1)
        # Clip tiny negative round-off before the square root.
        return np.sqrt(np.maximum(quad, 0.0))

    def _build_theta(self, setting, K, signal_norm, policy):
        """Draw the ``(d, m)`` task-parameter matrix for ``setting``."""
        Theta = np.zeros((self.d, self.m))
        # Drawn for every setting (consuming RNG state even when unused) -- keep
        # this call so seeded runs reproduce the same later draws.
        Clustered_Theta = self.generate_items(K, self.d).T

        if setting == 'vanilla':
            Theta[0, :] = signal_norm
        elif setting == 'clustered':
            if policy == 'half':
                # Random cluster sizes: K clusters over m tasks, sizes drawn from a
                # multinomial with Dirichlet-distributed proportions.
                p0 = np.random.dirichlet(alpha = np.ones(K))
                counts = np.random.multinomial(self.m, p0)
                start = 0
                for k, cnt in enumerate(counts):
                    if cnt <= 0:
                        continue
                    # All tasks in a cluster share exactly the same parameters.
                    Theta[:, start:start + cnt] = Clustered_Theta[:, k:k+1]
                    start += cnt
            else:
                # Equal-size clusters; the last cluster also absorbs the remainder.
                r = self.m // K
                for k in range(K):
                    end = self.m if k == K - 1 else k * r + r
                    Theta[:, k * r:end] = Clustered_Theta[:, k:k+1]
        elif setting == 'lowrank':
            for j in range(self.m):
                z_j = np.random.randn(K)
                Theta[0:K, j] = signal_norm * z_j / np.linalg.norm(z_j)
        return Theta

    def _split_outliers(self, epsilon):
        """Return ``(inlier_set, outlier_set)`` of task indices; ``int(m * epsilon)`` outliers."""
        S_outliers = np.random.choice(self.m, size = int(self.m * epsilon), replace = False)
        S = np.delete(np.arange(self.m), S_outliers)
        return set(S), set(S_outliers)

    def _collect_random(self, theta_j):
        """Sample ``n`` random items and noisy linear rewards ``x @ theta_j + N(0, 0.05^2)``."""
        X_j = self.generate_items(self.n, self.d)
        y_j = X_j @ theta_j + self.generate_subgaussian_noise((self.n,))
        return X_j, y_j

    def _collect_linucb(self, theta_j):
        """Collect ``n`` samples for one task by running LinUCB (ridge lambda = 1).

        Each round draws ``n`` fresh candidate items and pulls the one maximizing
        ``x @ theta_hat + beta * ||x||_{S^{-1}}``.
        """
        S_j = np.eye(self.d)
        b_j = np.zeros(self.d)
        Sinv_j = np.eye(self.d)
        X_rows, y_vals = [], []
        for t in range(1, self.n + 1):
            items_t = self.generate_items(self.n, self.d)
            theta_hat = Sinv_j @ b_j
            scores = items_t @ theta_hat + self._beta(self.d, t - 1, t) * self._ucb_bonus(items_t, Sinv_j)
            x_t = items_t[np.argmax(scores)]
            y_t = float(x_t @ theta_j + np.random.normal(loc=0, scale=0.05))
            # Rank-one update of the regularized Gram matrix and response vector.
            S_j += np.outer(x_t, x_t)
            b_j += y_t * x_t
            Sinv_j = np.linalg.inv(S_j)
            X_rows.append(x_t)
            y_vals.append(y_t)
        # reshape keeps a well-formed (0, d) array when n == 0.
        return np.array(X_rows).reshape(-1, self.d), np.array(y_vals, dtype=float)

    def getsamples(self, setting, K = 3, signal_norm = 2, sigma = 1, delta = 0.1, epsilon = 0, norm_outliers = 2, seed = 1000, offline_method = 'random', policy = 'uniform'):
        """Draw task parameters and ``n`` offline samples for each of the ``m`` tasks.

        Args:
            setting: 'vanilla' (Theta[0] = signal_norm for every task, other
                coordinates 0), 'clustered' (tasks share one of K random unit
                vectors) or 'lowrank' (each task is a random vector of norm
                signal_norm supported on the first K coordinates).
            K: number of clusters / rank; must not exceed d.
            signal_norm: parameter scale for 'vanilla' and 'lowrank'.
            sigma, delta, norm_outliers: accepted for interface compatibility;
                not used here.
            epsilon: fraction of tasks drawn (without replacement) into the
                outlier index set ``self.Sc``; their parameters are not altered.
            seed: seed for the global ``np.random`` state.
            offline_method: 'linucb' collects data with LinUCB; any other value
                samples items uniformly at random.
            policy: for 'clustered', 'half' draws random cluster sizes; any other
                value uses equal-size clusters.

        Sets ``self.data = [data_X, data_y]`` (per-task ``(n, d)`` / ``(n,)``
        arrays), ``self.Theta`` ``(d, m)``, ``self.S`` (inlier task indices) and
        ``self.Sc`` (outlier task indices).
        """
        assert setting == 'vanilla' or setting == 'clustered' or setting == 'lowrank'
        assert K <= self.d

        self.setting, self.K, self.seed = setting, K, seed

        # Draw order (parameters -> outlier split -> per-task data) fixes the
        # random stream for a given seed; do not reorder.
        np.random.seed(seed)
        Theta = self._build_theta(setting, K, signal_norm, policy)
        S, S_outliers = self._split_outliers(epsilon)

        collect = self._collect_linucb if offline_method == 'linucb' else self._collect_random
        data_X, data_y = [], []
        for j in range(self.m):
            X_j, y_j = collect(Theta[:, j])
            data_X.append(X_j)
            data_y.append(y_j)

        self.data = [data_X, data_y]
        self.Theta = Theta
        self.S = S
        self.Sc = S_outliers

    def run(self, lbd_list, n_fold = 2, eta = 0.05, T = 100):
        """Fit ARMUL with cross-validation on ``self.data`` and evaluate it.

        Must be called after ``getsamples``. Stores per-task errors
        ``||theta_hat_j - theta_j||_2`` in ``self.err['ARMUL']`` (all tasks) and
        ``self.err_S['ARMUL']`` (tasks in ``self.S``), and returns the result of
        ``mtl.evaluate_suboptimal_gap``.
        """
        mtl = ARMUL(link = 'linear', n_class = 1, penalty = 'ridge')

        # NOTE(review): experiment_real.run also passes T_local = T; confirm that
        # relying on ARMUL's default T_local is intended here.
        mtl.cv(self.data, lbd_list = lbd_list, model = self.setting, n_fold = n_fold, K = self.K, eta_global = eta, eta_local = eta, eta_B = eta, eta_Z = eta, T_global = T, seed = self.seed, standardization = False, intercept = False)
        # Presumably models[setting] is (m, d, n_class); transposed to (d, m) -- confirm against ARMUL.
        Theta_hat = mtl.models[self.setting][:, :, 0].T

        # Compute the difference once instead of once per task.
        diff = Theta_hat - self.Theta
        err_ARMUL = [np.linalg.norm(diff[:, j]) for j in range(self.m)]
        self.err = {'ARMUL': err_ARMUL}
        self.err_S = {'ARMUL': [err_ARMUL[j] for j in self.S]}

        return mtl.evaluate_suboptimal_gap(samples = self.n, model = self.setting, true_theta = self.Theta)

class experiment_real:  # for real data experiments
    """Multi-task linear-bandit experiment whose task parameters come from real data.

    User parameter vectors are loaded from a ``.npy`` file (one row per user) and
    stored column-wise in ``self.Theta`` (shape ``(d, m)``). Rewards are then
    simulated from these parameters exactly as in the synthetic experiment.
    """

    # Default user-feature files per dataset name (relative to the working directory).
    _FEATURE_FILES = {
        'ml': '../NIPS/data_process/ml_1000user_d20.npy',
        'yelp': '../NIPS/data_process/yelp_1000user_d20.npy',
    }

    def __init__(self, n=10, m=1000, d=20, user_features_file=None, dataset_name='ml'):
        # n: samples per user, m: number of users (tasks), d: feature dimension.
        self.n, self.m, self.d = n, m, d
        self.user_features_file = user_features_file
        self.dataset_name = dataset_name
        # Explicit file supplied by the caller; when None, getsamples derives the
        # path from dataset_name.
        self._custom_features_file = user_features_file

    def generate_items(self, num_items, d):
        """Return a ``(num_items, d)`` array of random item vectors.

        The first ``d - 1`` coordinates are a random Gaussian direction scaled to
        norm ``1/sqrt(2)``; the last coordinate is the constant ``1/sqrt(2)``.
        Every row therefore has unit Euclidean norm.
        """
        factor = 1 / np.sqrt(2)
        x = np.random.normal(0, 1, (num_items, d - 1))
        norms = np.linalg.norm(x, axis=1, keepdims=True)
        x = x / norms * factor
        ones_column = np.ones((num_items, 1)) * factor
        return np.hstack((x, ones_column))

    def generate_subgaussian_noise(self, shape, scale=0.05):
        """
        Generate sub-Gaussian noise of specified dimensions
        shape: Shape of noise, such as (n,) or (n, 1)
        scale: Standard deviation of noise
        """
        return np.random.normal(loc=0, scale=scale, size=shape)

    @staticmethod
    def _beta(d, N, t):
        """Confidence-radius multiplier for the LinUCB score after ``N`` samples at round ``t``."""
        return np.sqrt(d * np.log(1 + N / max(d, 1)) + 4 * np.log(max(t, 2)) + np.log(2)) + 1

    @staticmethod
    def _ucb_bonus(items, Sinv):
        """Return ``||x||_{Sinv} = sqrt(x^T Sinv x)`` for every row ``x`` of ``items``."""
        quad = (items @ Sinv * items).sum(axis=1)
        # Clip tiny negative round-off before the square root.
        return np.sqrt(np.maximum(quad, 0.0))

    def _resolve_features_file(self):
        """Return the user-feature file: the constructor override, else the dataset default."""
        custom = getattr(self, '_custom_features_file', None)
        if custom is not None:
            return custom
        if self.dataset_name not in self._FEATURE_FILES:
            raise ValueError(f"Unknown dataset: {self.dataset_name}. Supported datasets: 'ml', 'yelp'")
        return self._FEATURE_FILES[self.dataset_name]

    def _load_theta(self, path):
        """Load user features from ``path`` as a ``(d, m)`` matrix (first m users)."""
        Theta = np.load(path).T
        if Theta.ndim != 2 or Theta.shape[0] != self.d or Theta.shape[1] < self.m:
            raise ValueError(
                f"User features in {path} have shape {Theta.T.shape}; "
                f"expected at least {self.m} users of dimension {self.d}")
        # Only the first m users are sampled, so keep Theta aligned with the data.
        return Theta[:, :self.m]

    def _split_outliers(self, epsilon):
        """Return ``(inlier_set, outlier_set)`` of task indices; ``int(m * epsilon)`` outliers."""
        S_outliers = np.random.choice(self.m, size = int(self.m * epsilon), replace = False)
        S = np.delete(np.arange(self.m), S_outliers)
        return set(S), set(S_outliers)

    def _collect_random(self, theta_j):
        """Sample ``n`` random items and noisy linear rewards ``x @ theta_j + N(0, 0.05^2)``."""
        X_j = self.generate_items(self.n, self.d)
        y_j = X_j @ theta_j + self.generate_subgaussian_noise((self.n,))
        return X_j, y_j

    def _collect_linucb(self, theta_j):
        """Collect ``n`` samples for one user by running LinUCB (ridge lambda = 1).

        Each round draws ``n`` fresh candidate items and pulls the one maximizing
        ``x @ theta_hat + beta * ||x||_{S^{-1}}``.
        """
        S_j = np.eye(self.d)
        b_j = np.zeros(self.d)
        Sinv_j = np.eye(self.d)
        X_rows, y_vals = [], []
        for t in range(1, self.n + 1):
            items_t = self.generate_items(self.n, self.d)
            theta_hat = Sinv_j @ b_j
            scores = items_t @ theta_hat + self._beta(self.d, t - 1, t) * self._ucb_bonus(items_t, Sinv_j)
            x_t = items_t[np.argmax(scores)]
            y_t = float(x_t @ theta_j + np.random.normal(loc=0, scale=0.05))
            # Rank-one update of the regularized Gram matrix and response vector.
            S_j += np.outer(x_t, x_t)
            b_j += y_t * x_t
            Sinv_j = np.linalg.inv(S_j)
            X_rows.append(x_t)
            y_vals.append(y_t)
        # reshape keeps a well-formed (0, d) array when n == 0.
        return np.array(X_rows).reshape(-1, self.d), np.array(y_vals, dtype=float)

    def getsamples(self, setting, K = 3, signal_norm = 2, sigma = 1, delta = 0.1, epsilon = 0, norm_outliers = 2, seed = 1000, dataset=None, offline_method = 'random', policy = 'uniform'):
        """Load real user parameters and simulate ``n`` offline samples per user.

        The feature file is ``user_features_file`` from the constructor when one
        was given, otherwise the default file for ``dataset_name``. If ``dataset``
        is not None it replaces ``dataset_name``. Only the first ``m`` users are
        kept.

        ``setting`` and ``K`` are validated and stored for ``run``. ``signal_norm``,
        ``sigma``, ``delta``, ``norm_outliers`` and ``policy`` are accepted for
        interface parity with experiment_synthetic and are not used.
        ``epsilon`` is the fraction of users drawn into the outlier index set
        ``self.Sc``. ``offline_method='linucb'`` collects data with LinUCB; any
        other value samples items uniformly at random.

        Raises:
            ValueError: unknown dataset name, or loaded features whose dimension
                differs from d or that contain fewer than m users.
        """
        assert setting == 'vanilla' or setting == 'clustered' or setting == 'lowrank'
        assert K <= self.d

        self.setting, self.K, self.seed = setting, K, seed

        if dataset is not None:
            self.dataset_name = dataset

        np.random.seed(seed)
        # The result is unused for real data, but the draw advances the seeded RNG
        # before the outlier split and sampling; keep it for reproducibility.
        self.generate_items(K, self.d)

        self.user_features_file = self._resolve_features_file()
        Theta = self._load_theta(self.user_features_file)

        S, S_outliers = self._split_outliers(epsilon)

        collect = self._collect_linucb if offline_method == 'linucb' else self._collect_random
        data_X, data_y = [], []
        for j in range(self.m):
            X_j, y_j = collect(Theta[:, j])
            data_X.append(X_j)
            data_y.append(y_j)

        self.data = [data_X, data_y]
        self.Theta = Theta
        self.S = S
        self.Sc = S_outliers

    def run(self, lbd_list, n_fold = 2, eta = 0.05, T = 100):
        """Fit ARMUL with cross-validation on ``self.data`` and evaluate it.

        Must be called after ``getsamples``. Stores per-user errors
        ``||theta_hat_j - theta_j||_2`` in ``self.err['ARMUL']`` (all users) and
        ``self.err_S['ARMUL']`` (users in ``self.S``), and returns the result of
        ``mtl.evaluate_suboptimal_gap``.
        """
        mtl = ARMUL(link = 'linear', n_class = 1, penalty = 'ridge')

        mtl.cv(self.data, lbd_list = lbd_list, model = self.setting, n_fold = n_fold, K = self.K, eta_global = eta, eta_local = eta, eta_B = eta, eta_Z = eta, T_global = T, T_local = T, seed = self.seed, standardization = False, intercept = False)
        # Presumably models[setting] is (m, d, n_class); transposed to (d, m) -- confirm against ARMUL.
        Theta_hat = mtl.models[self.setting][:, :, 0].T

        # Compute the difference once instead of once per user.
        diff = Theta_hat - self.Theta
        err_ARMUL = [np.linalg.norm(diff[:, j]) for j in range(self.m)]
        self.err = {'ARMUL': err_ARMUL}
        self.err_S = {'ARMUL': [err_ARMUL[j] for j in self.S]}

        return mtl.evaluate_suboptimal_gap(samples = self.n, model = self.setting, true_theta = self.Theta)

