import numpy as np
from scipy.stats import norm
from MTL import MTL, baselines, prediction, evaluation
from preprocessing import split_cv
import pdb 

class ARMUL:
    def __init__(self, link = 'linear', n_class = 2, penalty = 'ridge', ucb_lambda = 0.5, d = 20, nu = 100, delta = 0.01):
        """Store the model configuration and the constants used by the bandit helpers.

        Training data is handed to the fitting methods as ``data = [X, y]``:
          - X: a list of m feature matrices, the j-th of shape (n_j, d)
          - y: a list of m response vectors, the j-th of shape (n_j, 1)
        Task j has n_j samples and an empirical risk f_j (sample average).

        Args:
            link: 'linear' or 'logistic'.
            n_class: number of classes in logistic regression (class indices in
                y range over 0, 1, ..., n_class - 1); ignored for linear regression.
            penalty: 'new' (lbd_j * ||v|| on the j-th node) or
                'ridge' (lbd_j * ||v||^2).
            ucb_lambda: ridge constant; every per-cluster S matrix starts at
                ucb_lambda * I and lcb_beta uses it in its radius.
            d: feature dimension used for the per-cluster S matrices, for
                filtering item rows, and in lcb_beta / _beta.
            nu: upper bound on the number of users handled by the clustering
                post-processing; also enters the log(2 * nu / delta) term of lcb_beta.
            delta: confidence parameter inside the log(2 * nu / delta) term of lcb_beta.

        The four bandit constants default to the values that used to be
        hard-coded, so existing callers are unaffected.
        """
        # model configuration forwarded to MTL by the fitting methods
        self.link = link
        self.n_class = n_class
        self.penalty = penalty

        # bandit-side constants
        self.ucb_lambda = ucb_lambda
        self.d = d
        self.nu = nu
        self.delta = delta

        # fitted coefficients, keyed by model name ('vanilla', 'clustered_Z', ...)
        self.models = dict()
    
    def vanilla(self, data, lbd = None, eta_global = 0.01, eta_local = 0.01, T_global = 1000, T_local = 1, standardization = True, intercept = True):
        """Train vanilla ARMUL on ``data`` and keep the resulting models.

        lbd is a list of m penalty parameters; when omitted it defaults to
        0.1 * sqrt(d / n_j) with d taken from the first feature matrix and
        n_j from the solver's n_list. The remaining arguments are forwarded
        to MTL / MTL.vanilla_train unchanged.

        Stores the local models under models['vanilla'] and the global model
        under models['vanilla_gamma'], and copies the solver's standardization
        statistics onto this object for later prediction.
        """
        solver = MTL(data, link = self.link, intercept = intercept, n_class = self.n_class, penalty = self.penalty, standardization = standardization)

        # keep the statistics that predict() needs
        self.X_means = solver.X_means
        self.X_stds = solver.X_stds
        self.y_mean = solver.y_mean
        self.y_std = solver.y_std
        self.n_list = solver.n_list

        if lbd is None:
            n_features = data[0][0].shape[1]
            lbd = 0.1 * np.sqrt(n_features / self.n_list)

        solver.vanilla_train(lbd, eta_global, eta_local, T_global, T_local)

        # 'vanilla' holds the local models, 'vanilla_gamma' the global one
        for key in ('vanilla', 'vanilla_gamma'):
            self.models[key] = solver.models[key]


    def clustered(self, data, lbd = None, K = 10, eta_B = 0.01, eta_local = 0.01, T_global = 1000, T_B = 1, T_local = 1, standardization = False, intercept = True):
        """Train clustered ARMUL with K clusters and keep the resulting models.

        When lbd is omitted it defaults to 0.1 * sqrt(d / n_j), with d taken
        from the first feature matrix and n_j from the solver's n_list.

        Stores models['clustered'], models['clustered_B'] and
        models['clustered_Z'] as produced by MTL.clustered_train.

        NOTE(review): the `standardization` argument is accepted but not
        forwarded; MTL is always built with standardization = False here.
        Confirm whether that is intentional before relying on the argument.
        """
        solver = MTL(data, link = self.link, intercept = intercept, n_class = self.n_class, penalty = self.penalty, standardization = False)

        # keep the statistics that predict() needs
        self.X_means = solver.X_means
        self.X_stds = solver.X_stds
        self.y_mean = solver.y_mean
        self.y_std = solver.y_std
        self.n_list = solver.n_list

        if lbd is None:
            n_features = data[0][0].shape[1]
            lbd = 0.1 * np.sqrt(n_features / self.n_list)

        solver.clustered_train(lbd = lbd, K = K, eta_B = eta_B, eta_local = eta_local, T_global = T_global, T_B = T_B, T_local = T_local)

        for key in ('clustered', 'clustered_B', 'clustered_Z'):
            self.models[key] = solver.models[key]


    def lowrank(self, data, lbd = None, K = 1, eta_B = 0.01, eta_Z = 0.01, eta_local = 0.01, T_global = 100, T_B = 1, T_Z = 1, T_local = 1, standardization = True, intercept = True):
        """Train low-rank ARMUL with rank parameter K and keep the resulting models.

        When lbd is omitted it defaults to 0.1 * sqrt(d / n_j), with d taken
        from the first feature matrix and n_j from the solver's n_list.

        Stores models['lowrank'], models['lowrank_B'] and models['lowrank_Z']
        as produced by MTL.lowrank_train.
        """
        solver = MTL(data, link = self.link, intercept = intercept, n_class = self.n_class, penalty = self.penalty, standardization = standardization)

        # keep the statistics that predict() needs
        self.X_means = solver.X_means
        self.X_stds = solver.X_stds
        self.y_mean = solver.y_mean
        self.y_std = solver.y_std
        self.n_list = solver.n_list

        if lbd is None:
            n_features = data[0][0].shape[1]
            lbd = 0.1 * np.sqrt(n_features / self.n_list)

        solver.lowrank_train(lbd = lbd, K = K, eta_B = eta_B, eta_Z = eta_Z, eta_local = eta_local, T_global = T_global, T_B = T_B, T_Z = T_Z, T_local = T_local)

        for key in ('lowrank', 'lowrank_B', 'lowrank_Z'):
            self.models[key] = solver.models[key]


    def predict(self, X_test, model = 'vanilla'):
        """Predict for X_test with the stored model named ``model``.

        Delegates to ``prediction`` with the link and the standardization
        statistics recorded by the most recent fit.
        """
        fitted = self.models[model]
        return prediction(
            X_test,
            fitted,
            self.link,
            self.X_means,
            self.X_stds,
            self.y_mean,
            self.y_std,
        )


    def evaluate(self, data_test, model = 'vanilla'):
        """Score the stored model ``model`` on data_test = [X_test, y_test].

        Returns whatever ``evaluation`` returns for the predictions, the test
        responses and the configured link.
        """
        X_test, y_test = data_test[0], data_test[1]
        return evaluation(self.predict(X_test, model), y_test, self.link)
    
    def generate_items(self, num_items, d):
        """Draw num_items random item vectors of dimension d.

        The first d - 1 coordinates are a Gaussian direction rescaled to norm
        1/sqrt(2); the last coordinate is the constant 1/sqrt(2).
        Returns an array of shape (num_items, d).
        """
        scale = 1 / np.sqrt(2)
        gaussian = np.random.normal(0, 1, (num_items, d - 1))
        lengths = np.linalg.norm(gaussian, axis=1, keepdims=True)
        direction_part = gaussian / lengths * scale
        constant_part = np.full((num_items, 1), scale)
        return np.concatenate([direction_part, constant_part], axis=1)


    def generate_subgaussian_noise(self, shape, scale=0.05):
        """Return zero-mean Gaussian noise.

        shape: shape of the returned array, e.g. (n,) or (n, 1)
        scale: standard deviation of the noise
        """
        noise = np.random.normal(loc=0, scale=scale, size=shape)
        return noise
    
    def lcb_beta(self, N):
        """Confidence radius used by the LCB rule, as a function of the sample count N.

        sqrt(d * log(1 + N / (d * ucb_lambda)) + 2 * log(2 * nu / delta)) + sqrt(ucb_lambda)
        """
        growth_term = self.d * np.log(1 + N / (self.d*self.ucb_lambda))
        union_term = 2 * np.log(2*self.nu/self.delta)
        return np.sqrt(growth_term + union_term) + np.sqrt(self.ucb_lambda)
    
    def _beta(self, N, t):
        """Confidence radius used by the UCB rule for sample count N and round t.

        sqrt(d * log(1 + N / d) + 4 * log(t) + log(2)) + 1
        """
        inside_root = self.d * np.log(1 + N / self.d) + 4 * np.log(t) + np.log(2)
        return np.sqrt(inside_root) + 1

    def _select_item_ucb(self, S, Sinv, theta, items, N, t):
        """Return the index of the item with the largest UCB score.

        The score of item x is  x . theta + _beta(N, t) * (x^T Sinv x).
        S is accepted for signature symmetry but is not used here.
        """
        # point estimate of each item's reward, flattened when theta is a column
        mean_reward = np.dot(items, theta)
        mean_reward = mean_reward.flatten() if mean_reward.ndim > 1 else mean_reward

        # quadratic form x^T Sinv x, one value per item row
        quad_form = np.sum(np.matmul(items, Sinv) * items, axis=1)

        return np.argmax(mean_reward + self._beta(N, t) * quad_form)

    def test_recommend(self, i, theta, items, N, t, with_exploration=True):
        if with_exploration:
            return self._select_item_ucb(self.S[i], self.Sinv[i], theta, items, N, t)
        else:
            return np.argmax(np.dot(items, theta))

    def _select_item_lcb(self, S, Sinv, theta, items, N):
        """Return the index of the item with the largest LCB score.

        The score of item x is  x . theta - 0.1 * lcb_beta(N) * (x^T Sinv x).
        S is accepted for signature symmetry but is not used here.
        """
        # point estimate of each item's reward, flattened when theta is a column
        mean_reward = np.dot(items, theta)
        if mean_reward.ndim > 1:
            mean_reward = mean_reward.flatten()

        # quadratic form x^T Sinv x, one value per item row
        quad_form = (np.matmul(items, Sinv) * items).sum(axis=1)

        pessimism = 0.1*self.lcb_beta(N) * quad_form
        return np.argmax(mean_reward - pessimism)

    # def test_recommend(self, i, theta, items, samples, with_exploration=True):

    #     if with_exploration:
    #         return self._select_item_lcb(self.S[i], self.Sinv[i], theta, items, samples)
    #     else:
    #         return np.argmax(np.dot(items, theta))

    def compute_mean_and_ci(self, data: np.ndarray, confidence: float = 0.999):
        """Return the sample mean and the two-sided normal CI half-width.

        Args:
            data: sample values; any array-like is accepted (lists and tuples
                are coerced) and multi-dimensional input is flattened.
            confidence: two-sided coverage level of the interval.

        Returns:
            (mean, half_width) as Python floats. With fewer than two samples
            the standard deviation (ddof=1) is undefined, so half_width is 0.0;
            with no samples at all the mean is NaN.
        """
        values = np.asarray(data).reshape(-1)

        if values.size == 0:
            # same result np.mean would give, without the empty-slice warning
            return float('nan'), 0.0

        mean_value = float(np.mean(values))
        if values.size < 2:
            return mean_value, 0.0

        # standard error of the mean times the normal quantile at (1 + confidence) / 2
        se = float(np.std(values, ddof=1) / np.sqrt(values.size))
        ci_half_width = float(se * norm.ppf((1 + confidence) / 2))
        return mean_value, ci_half_width


            

    def evaluate_suboptimal_gap(self, samples, model='vanilla', true_theta=None, with_exploration=True):
        """Estimate the suboptimality gap of recommendations made with a fitted model.

        For each of the m tasks, `samples` random item sets (20 items each,
        from generate_items) are drawn. An item is chosen with the task's
        estimated theta, and the gap between the best true reward and the
        chosen item's true reward is recorded.

        Args:
            samples: number of item sets drawn per task.
            model: 'vanilla', 'clustered' or 'lowrank'; selects self.models[model].
            true_theta: array of shape (d, m) whose j-th column is the true
                parameter of task j. Required.
            with_exploration: when True the item is chosen by the UCB rule
                using the S / Sinv matrix and sample count of the task's
                cluster; when False it is the greedy argmax.

        Returns:
            dict with 'gaps' (array of shape (m, samples)), 'average_gap'
            (mean over all gaps) and 'err' (99.9% CI half-width of that mean).

        Raises:
            ValueError: unknown model, missing true_theta, or exploration
                requested before process_data has built the cluster matrices.
        """
        if model not in ('vanilla', 'clustered', 'lowrank'):
            raise ValueError(f"Unsupported model type: {model}")
        if true_theta is None:
            raise ValueError("true_theta is required to evaluate suboptimality gaps")
        if with_exploration and not hasattr(self, 'Sinv_clusters'):
            raise ValueError("Clustering model not initialized, please call process_data method first")

        estimated_theta = self.models[model]
        baseline_theta = np.asarray(true_theta)
        m = len(estimated_theta)
        d = estimated_theta[0].shape[0]
        gaps = np.zeros((m, samples))

        # Round index handed to _beta; kept >= 1 because log(0) would turn
        # every UCB score into NaN.
        t = max(1, m * samples // 2)

        for j in range(m):
            if with_exploration:
                # Use the matrices of the cluster task j belongs to. They are
                # read from the per-cluster arrays because self.S / self.Sinv
                # are indexed by user, not by cluster.
                cluster_idx = self.cluster_indices[j]
                S_j = self.S_clusters[cluster_idx]
                Sinv_j = self.Sinv_clusters[cluster_idx]
                N_j = self.Ns[cluster_idx]

            for i in range(samples):
                items = self.generate_items(20, d)
                # true reward of every candidate item
                item_rewards = np.dot(items, baseline_theta[:, j])

                if with_exploration:
                    chosen = self._select_item_ucb(S_j, Sinv_j, estimated_theta[j], items, N_j, t)
                else:
                    chosen = np.argmax(np.dot(items, estimated_theta[j]))

                # the gap is measured with the true rewards, not the estimates
                gaps[j, i] = np.max(item_rewards) - item_rewards[chosen]

        _, ci_half_width = self.compute_mean_and_ci(gaps.reshape(-1), 0.999)
        return {
            'gaps': gaps,
            'average_gap': np.mean(gaps),
            'err': ci_half_width,
        }

    def process_data(self, data, model='clustered'):
        """Build the per-cluster matrices from ``data`` and return ``data`` unchanged.

        The `model` argument is accepted but not consulted: the clustered
        post-processing is always the one that runs.
        """
        self._process_data_clustered(data)
        return data
    

    def _process_data_clustered(self, data):
        """Build the shared S matrix (and its inverse) of every cluster.

        Each user is assigned to the cluster with the largest entry in their
        row of models['clustered_Z']. For cluster k,
            S_k = ucb_lambda * I + sum of x x^T over the item rows x of its users,
        where only the first `nu` users of data[0] are read and rows whose
        length differs from `d` are skipped.

        Sets:
            cluster_indices: cluster index of every row of Z
            S_clusters, Sinv_clusters: arrays of shape (K, d, d)
            Ns: number of item rows accumulated per cluster (float array, length K)
            S, Sinv: per-user copies of the cluster matrices for the first
                min(nu, number of rows of Z) users

        Raises:
            ValueError: if no clustered model has been trained yet.
        """
        if 'clustered_Z' not in self.models:
            raise ValueError("Clustering model not trained, please call clustered method first")

        # Z holds the cluster assignment information, one row per user
        Z = self.models['clustered_Z']
        K = Z.shape[1]  # number of clusters

        # argmax over K columns, so every index already lies in [0, K)
        self.cluster_indices = np.argmax(Z, axis=1)

        # every cluster starts from the ridge term ucb_lambda * I
        self.S_clusters = np.repeat(self.ucb_lambda * np.eye(self.d)[np.newaxis, :, :], K, axis=0)
        self.Ns = np.zeros(K)

        for j in range(min(len(data[0]), self.nu)):
            cluster_idx = self.cluster_indices[j]
            # keep only item rows of the expected dimension
            rows = [item_features for item_features in data[0][j] if len(item_features) == self.d]
            if not rows:
                continue
            features = np.asarray(rows, dtype=float)
            # features^T features equals the sum of the rows' outer products
            self.S_clusters[cluster_idx] += features.T @ features
            self.Ns[cluster_idx] += len(rows)

        # np.linalg.inv inverts the whole (K, d, d) stack at once
        self.Sinv_clusters = np.linalg.inv(self.S_clusters)

        # Per-user view of the cluster matrices. Slicing the assignments keeps
        # this valid when Z has fewer than nu rows.
        user_clusters = self.cluster_indices[:self.nu]
        self.S = self.S_clusters[user_clusters]
        self.Sinv = self.Sinv_clusters[user_clusters]

        print(f"Clustering processing completed:")
        print(f"  - Number of clusters: {K}")
        print(f"  - Number of users: {len(self.S)}")
        print(f"  - Number of users per cluster: {np.bincount(self.cluster_indices)}")
        print(f"  - Number of items per cluster: {self.Ns}")

    def get_clustered_items_format(self):
        """
        Summarize the clustering state built by process_data.

        Returns:
            dict with the keys
                - 'cluster_indices': cluster index of each user
                - 'cluster_counts': number of users per cluster
                - 'S_clusters': S matrix of each cluster
                - 'Sinv_clusters': inverse S matrix of each cluster
                - 'K': number of cluster S matrices

        Raises:
            ValueError: if process_data has not been called yet.
        """
        if not hasattr(self, 'cluster_indices'):
            raise ValueError("Clustering model not initialized, please call process_data method first")

        assignments = self.cluster_indices
        summary = dict()
        summary['cluster_indices'] = assignments
        summary['cluster_counts'] = np.bincount(assignments)
        summary['S_clusters'] = self.S_clusters
        summary['Sinv_clusters'] = self.Sinv_clusters
        summary['K'] = len(self.S_clusters)
        return summary
    
    def get_cluster_value_by_index(self, user_index):
        """
        Look up the cluster a user belongs to, together with its matrices.

        Args:
            user_index: index of the user in cluster_indices

        Returns:
            dict with the keys
                - 'cluster_idx': cluster index of the user
                - 'S': S matrix of that cluster
                - 'Sinv': inverse S matrix of that cluster

        Raises:
            ValueError: if process_data has not been called yet, or if
                user_index is at or beyond the number of users.
        """
        if not hasattr(self, 'cluster_indices'):
            raise ValueError("Clustering model not initialized, please call process_data method first")

        assignments = self.cluster_indices
        if user_index >= len(assignments):
            raise ValueError(f"User index {user_index} out of range")

        k = assignments[user_index]
        return dict(cluster_idx=k, S=self.S_clusters[k], Sinv=self.Sinv_clusters[k])
    
    def get_users_by_cluster(self, cluster_idx):
        """
        List the users assigned to one cluster.

        Args:
            cluster_idx: cluster index

        Returns:
            list: indices of the users whose cluster index equals cluster_idx

        Raises:
            ValueError: if process_data has not been called yet.
        """
        if not hasattr(self, 'cluster_indices'):
            raise ValueError("Clustering model not initialized, please call process_data method first")

        in_cluster = self.cluster_indices == cluster_idx
        members = np.nonzero(in_cluster)[0]
        return members.tolist()
    


    def cv(self, data, lbd_list = None, model = 'vanilla', n_fold = 5, K = 2, eta_global = 0.01, eta_local = 0.01, eta_B = 0.01, eta_Z = 0.01, T_global = 1000, T_local = 100, T_B = 1, T_Z = 1, seed = 1000, standardization = True, intercept = True):
        """Choose lambda by n_fold cross validation, then refit on all of ``data``.

        Args:
            data: [X, y], lists with one feature matrix / response vector per task.
            lbd_list: non-empty list of lambda-configurations to compare.
            model: 'vanilla', 'clustered' or 'lowrank'.
            n_fold: number of folds requested from split_cv.
            seed: seeds numpy's global RNG and is passed to split_cv.
            Remaining arguments are forwarded to the chosen fitting method.

        Sets:
            results: (len(lbd_list), n_fold) array of 'average error' values
            errors_cv: mean error of every configuration across folds
            lbd_cv: the configuration with the smallest mean error
            lbd_list: the list that was searched

        After refitting, the clustering post-processing (process_data) is run
        whenever a 'clustered_Z' model is available. It is skipped otherwise,
        because it cannot run without one.

        Raises:
            ValueError: unknown model, or lbd_list missing or empty.
        """
        if model not in ('vanilla', 'clustered', 'lowrank'):
            raise ValueError(f"Unsupported model type: {model}")
        if lbd_list is None or len(lbd_list) == 0:
            raise ValueError("lbd_list must be a non-empty list of lambda configurations")

        def fit(train_data, lbd):
            # single place that maps the model name to its fitting method
            if model == 'vanilla':
                self.vanilla(train_data, lbd, eta_global, eta_local, T_global, T_local, standardization, intercept)
            elif model == 'clustered':
                self.clustered(train_data, lbd, K = K, eta_B = eta_B, eta_local = eta_local, T_global = T_global, T_B = T_B, T_local = T_local, standardization = standardization, intercept = intercept)
            else:
                self.lowrank(train_data, lbd, K = K, eta_B = eta_B, eta_Z = eta_Z, eta_local = eta_local, T_global = T_global, T_B = T_B, T_Z = T_Z, T_local = T_local, standardization = standardization, intercept = intercept)

        np.random.seed(seed)
        m = len(data[0]) # number of tasks
        n_list = np.array([len(X_j) for X_j in data[0]], dtype = int)
        splits = split_cv(n_list, n_fold, seed)

        L = len(lbd_list) # number of lambda configurations
        results = np.zeros((L, n_fold))
        for i in range(L):
            for k in range(n_fold):
                X_train, X_test = [], []
                y_train, y_test = [], []
                for j in range(m):
                    # fold k of task j is held out; the rest is used for training
                    idx_test = splits[j][k]
                    idx_train = np.delete(np.arange(n_list[j]), idx_test)
                    X_test.append(data[0][j][idx_test])
                    y_test.append(data[1][j][idx_test])
                    X_train.append(data[0][j][idx_train])
                    y_train.append(data[1][j][idx_train])

                fit([X_train, y_train], lbd_list[i])
                fold_metrics = self.evaluate([X_test, y_test], model = model)
                results[i, k] = fold_metrics['average error']

        self.results = results

        # cv errors
        cv_err = np.mean(results, axis = 1)
        self.errors_cv = cv_err

        # hyperparameter selection
        idx = np.argmin(cv_err)
        self.lbd_cv = lbd_list[idx] # optimal lambda
        self.lbd_list = lbd_list # list of all lambda's

        # refitting on the full data with the selected lambda
        fit(data, lbd_list[idx])

        # The cluster matrices can only be built from a clustered fit. Without
        # one (e.g. a fresh vanilla / lowrank run) the step is skipped instead
        # of raising after all the cross-validation work is done.
        if 'clustered_Z' in self.models:
            self.process_data(data, model=model)

