import torch
import random
from src.utils.utils import vector_to_state_dict, state_dict_to_vector
from src.utils.variables_and_paths import ALL_DATASETS

# Dataset groups used when evaluating in-distribution (ID) and
# out-of-distribution (OOD) behaviour of merged models.

# Candidate pool of task datasets (presumably the ones task vectors are sampled from).
DATASETS_TO_SAMPLE = [
    "MNIST",
    "FashionMNIST",
    "EMNIST",
    "SVHN",
    "RESISC45",
    "CIFAR10",
    "STL10",
    "Cars",
    "Food101",
    "Flowers102",
    "DTD",
    "GTSRB",
    "RenderedSST2",
    "OxfordIIITPet",
]

# In-distribution evaluation datasets.
ID_DATASETS = [
    "KMNIST",
    "CIFAR100",
    "EuroSAT",
]

# Out-of-distribution evaluation datasets.
OOD_TEST_DATASETS = [
    "PCAM",
    "SUN397",
    "FER2013",
]

def ho_gsvd_subspace_boosting(sigmas: torch.Tensor, k=1.0, boost_nonzero=True):
    """
    Boost the largest singular values by setting them to the maximum singular value.

    The ``int(k * len(sigmas))`` largest entries are replaced with
    ``sigmas.max()``; every other entry is returned unchanged. The input tensor
    is not modified.

    Parameters:
        sigmas: singular values; assumed to be a 1-D tensor.
        k: fraction (not a percentage) of entries to boost; 1.0 boosts all.
        boost_nonzero: if True, entries equal to zero are never boosted, even
            when they fall inside the top-k selection.
    Returns:
        A new tensor with the same shape as ``sigmas``.
    """
    boosted_sigmas = sigmas.clone()
    # Nothing to boost; also avoids calling max() on an empty tensor (which raises).
    if sigmas.numel() == 0:
        return boosted_sigmas

    top_k_values = int(k * sigmas.shape[0])
    indices = torch.sort(sigmas, descending=True)[1][:top_k_values]
    max_sigma = sigmas.max()

    if boost_nonzero:
        mask_top_k = torch.zeros_like(sigmas, dtype=torch.bool)
        mask_top_k[indices] = True
        # Boost only non-zero singular values.
        boosted_sigmas[mask_top_k & (sigmas != 0)] = max_sigma
    else:
        # Subspace boosting with beta = 0.0: every selected entry becomes the max.
        boosted_sigmas[indices] = max_sigma

    return boosted_sigmas

def truncate_singular_values(sigmas: "list[torch.Tensor]", svd_thresh=0.2, shared_fraction=0.5, *,
                             min_shared_count=7, max_exclusive_count=2):
    """
    Zero out singular-value columns that are low-energy or interfering across tasks.

    For each task, the shortest prefix of its singular values (sorted by
    descending energy sigma**2) whose cumulative energy share is strictly
    greater than ``svd_thresh`` is marked as important. Column j is kept for
    all tasks only if it is important for at least ``min_shared_count`` tasks
    (broadly shared) or for at most ``max_exclusive_count`` tasks (little
    interference). Every other column is set to zero.

    Parameters:
        sigmas: per-task generalized singular values, a list/tuple of equally
            sized 1-D tensors (stacked along dim 0).
        svd_thresh: cumulative energy threshold for the per-task selection.
        shared_fraction: currently unused; kept for interface compatibility.
            Sharing is decided by the absolute count ``min_shared_count``.
        min_shared_count: keep columns important for at least this many tasks
            (default 7 matches the previous hard-coded value).
        max_exclusive_count: keep columns important for at most this many tasks
            (default 2 matches the previous hard-coded value).
    Returns:
        Tensor of shape (num_tasks, num_values): the stacked sigmas with the
        discarded columns zeroed.
    """
    stacked = torch.stack(sigmas, dim=0)
    important = torch.zeros_like(stacked, dtype=torch.bool)

    # Per task: mark the top-energy components needed to exceed svd_thresh.
    for i, sigma in enumerate(stacked):
        energy = sigma ** 2
        sorted_energy, order = torch.sort(energy, descending=True)
        cumulative_ratio = torch.cumsum(sorted_energy, dim=0) / torch.sum(energy)
        # Count of prefixes with share <= thresh, plus the first one exceeding it.
        num_keep = torch.searchsorted(cumulative_ratio, svd_thresh, right=True) + 1
        important[i, order[:num_keep]] = True

    counts = torch.sum(important, dim=0)
    # Keep only very important, shared subspaces or low-interfering subspaces.
    # TODO: evaluate whether columns with count 0 should be kept or not.
    # Built from `counts`, so the mask lives on the same device as the input.
    keep_cols = (counts >= min_shared_count) | (counts <= max_exclusive_count)

    return stacked * keep_cols.unsqueeze(0)

def randomized_truncation(sigmas: "list[torch.Tensor]", svd_thresh, init_seed: int):
    """
    Randomly keep a fixed fraction of each task's singular values, zeroing the rest.

    Task i keeps ``int(svd_thresh * len)`` positions drawn without replacement
    by ``random.Random(init_seed + i)``, so results are reproducible per task.
    A private generator is used so the global ``random`` state is left
    untouched; the draws are the same as seeding the global generator with
    the same value.

    Parameters:
        sigmas: per-task singular values, a list/tuple of equally sized 1-D tensors.
        svd_thresh: fraction of entries to keep per task (``random.sample``
            raises ValueError if the resulting count is negative or exceeds
            the length).
        init_seed: seed of task 0; task i uses ``init_seed + i``.
    Returns:
        Tensor of shape (num_tasks, len): the stacked sigmas with the
        non-sampled entries zeroed.
    """
    truncated_sigmas = torch.stack(sigmas, dim=0)
    num_tasks, len_sigmas = len(sigmas), sigmas[0].shape[0]
    sample_size = int(svd_thresh * len_sigmas)
    population = list(range(len_sigmas))

    for i in range(num_tasks):
        # New seed per task, in [init_seed, init_seed + num_tasks - 1].
        rng = random.Random(init_seed + i)
        indices_to_keep = rng.sample(population, sample_size)

        mask = torch.zeros_like(sigmas[i])
        mask[indices_to_keep] = 1.0
        truncated_sigmas[i] = sigmas[i] * mask

    return truncated_sigmas

def abs_log_ratios(sigmas, epsilon = 1e-12):
    """
    Pairwise mean absolute log-ratio between the rows of ``sigmas``.

    Entry (i, j) is ``mean_k |log((sigmas[i, k] + eps) / (sigmas[j, k] + eps))|``.
    Epsilon is added to both numerator and denominator. A zero entry therefore
    gives a large but finite value instead of ``log(0) = -inf``, and the result
    is exactly symmetric with a zero diagonal.

    Parameters:
        sigmas: (num_rows, num_values) tensor of non-negative values (a 1-D
            tensor is treated as one value per row).
        epsilon: small constant guarding against zeros.
    Returns:
        (num_rows, num_rows) dissimilarity matrix on the same device/dtype as
        ``sigmas``.
    """
    if sigmas.dim() == 1:
        sigmas = sigmas.unsqueeze(-1)
    log_sigmas = torch.log(sigmas + epsilon)
    # log(a/b) == log(a) - log(b): broadcast all row pairs at once.
    diffs = log_sigmas.unsqueeze(1) - log_sigmas.unsqueeze(0)
    return diffs.abs().mean(dim=-1)

def find_optimal_set(alignment_matrix, k):
    """
    Greedily select ``k`` mutually dissimilar items from a dissimilarity matrix.

    The set is seeded with the most dissimilar off-diagonal pair. It then grows
    one item at a time by max-min selection: the next item added is the
    candidate whose smallest dissimilarity to the already selected items is
    largest (ties go to the lowest index).

    Parameters:
        alignment_matrix: (N, N) dissimilarity matrix; larger means more dissimilar.
        k: number of items to select.
    Returns:
        Sorted list of at most ``k`` indices. The list is shorter if N < k or
        if every remaining candidate scores -inf/NaN. ``k <= 0`` or ``N == 0``
        gives ``[]``. ``k == 1`` gives the lower index of the most dissimilar pair.
    """
    N = alignment_matrix.shape[0]
    if k <= 0 or N == 0:
        return []

    # 1. Seed with the most dissimilar pair; the diagonal is excluded.
    temp_alignment_matrix = alignment_matrix.clone()
    temp_alignment_matrix.fill_diagonal_(-float('inf'))
    row_idx, col_idx = divmod(torch.argmax(temp_alignment_matrix).item(), N)
    # N == 1 collapses the pair to [0]; k == 1 keeps only one of the pair.
    selected_indices = sorted({row_idx, col_idx})[:k]

    candidate_indices = [i for i in range(N) if i not in selected_indices]

    # 2. Iteratively add the candidate farthest from the current set.
    while len(selected_indices) < k and candidate_indices:
        current_selected_tensor = torch.tensor(selected_indices)
        best_score, best_idx = -float('inf'), -1

        for cand_idx in candidate_indices:
            score = torch.min(alignment_matrix[cand_idx, current_selected_tensor]).item()
            # Strict '>' keeps the first (lowest-index) candidate on ties and skips NaN.
            if score > best_score:
                best_score, best_idx = score, cand_idx

        if best_idx == -1:
            # No suitable candidate (all remaining score -inf or NaN).
            break
        selected_indices.append(best_idx)
        candidate_indices.remove(best_idx)

    return sorted(selected_indices)

def find_optimal_set_from_predefined(
    alignment_matrix: torch.Tensor,
    k_total: int,  # Total number of items desired in the final set
    predefined_indices: list,  # Indices of models to include
    predefined_indices_to_select: list  # Indices of models to select from
) -> list:
    """
    Select k_total models, starting from predefined_indices and greedily
    adding candidates that maximize the minimum dissimilarity to the set.

    Args:
        alignment_matrix (torch.Tensor): An (N, N) dissimilarity matrix.
        k_total (int): The total number of models in the final set.
        predefined_indices (list): Indices that MUST be included; they seed the
            selection. Must be non-empty if more models need to be added.
        predefined_indices_to_select (list): Pool of indices that may be added.
            Indices already in ``predefined_indices`` and duplicates are ignored.
    Returns:
        list: Sorted, duplicate-free indices: the predefined ones plus up to
              ``k_total - len(predefined)`` greedily chosen candidates (fewer
              if the pool runs out or every candidate scores -inf/NaN).
    """
    selected_indices = sorted(set(predefined_indices))

    # Exclude already-selected indices (and duplicates) from the pool. Otherwise
    # an index could be picked again and appear twice in the result.
    candidate_indices = sorted(set(predefined_indices_to_select) - set(selected_indices))

    while len(selected_indices) < k_total and candidate_indices:
        max_min_diss_to_selected_set = -float('inf')
        best_next_candidate_idx = -1

        current_selected_tensor = torch.tensor(selected_indices)

        for cand_idx in candidate_indices:
            diss_values_to_selected = alignment_matrix[cand_idx, current_selected_tensor]
            min_diss_for_this_candidate = torch.min(diss_values_to_selected).item()

            # Strict '>' keeps the lowest-index candidate on ties and skips NaN.
            if min_diss_for_this_candidate > max_min_diss_to_selected_set:
                max_min_diss_to_selected_set = min_diss_for_this_candidate
                best_next_candidate_idx = cand_idx

        if best_next_candidate_idx == -1:
            # No suitable candidate (all remaining score -inf or NaN).
            break
        selected_indices.append(best_next_candidate_idx)
        candidate_indices.remove(best_next_candidate_idx)

    return sorted(selected_indices)

def map_datasets_to_index(datasets):
    """
    Return the position of each dataset name within ``ALL_DATASETS``, in input order.

    A name missing from ``ALL_DATASETS`` raises ValueError (from ``list.index``).
    """
    return [ALL_DATASETS.index(name) for name in datasets]

def ho_gsvd(tv_flat_checks, ptm_check, config):
    """
    Merge task vectors with a higher-order generalized SVD (HO-GSVD) per weight matrix.

    For every parameter whose name contains one of ``keys_to_eval``:
    1. The per-task matrices W_i get a shared right basis V, taken from the
       eigenvectors of the HO-GSVD matrix S.
    2. V and the mean left basis are orthonormalized.
    3. The merged matrix is rebuilt as
       sum_i U_ortho @ diag(boost(sigma_i)) @ V_ortho.T.

    All other parameters keep the plain sum of the task vectors.

    Parameters:
        tv_flat_checks: flattened task vectors, one per task along dim 0
            (iterated per task and summed over dim 0).
        ptm_check: reference state dict passed to ``vector_to_state_dict``.
        config: reads ``config.method.pi`` (regularization weight of A^T A in
            D_i) and ``config.method.k`` (boosting fraction).
    Returns:
        The merged model as a flat vector (``state_dict_to_vector``).
    Raises:
        ValueError: if a matrix is processed with fewer than two task vectors.
            S is an average over task pairs, so a single task would divide by zero.
    """
    model_state_dicts = [vector_to_state_dict(x, ptm_check, remove_keys=[]) for x in tv_flat_checks]
    merged_flat_vector = tv_flat_checks.sum(dim=0)

    merged_state_dict = vector_to_state_dict(merged_flat_vector, ptm_check, remove_keys=[])

    keys_to_eval = [
        "attn.in_proj_weight",  # Attention weight matrices
        "attn.out_proj.weight",
        "mlp.c_fc.weight",
        "mlp.c_proj.weight",
    ]

    alignment_matrices = []

    # Only existing keys are reassigned below, so mutating during items() is safe.
    for key, param in merged_state_dict.items():
        if not (any(i in key for i in keys_to_eval) and isinstance(param, torch.Tensor)):
            continue
        print(f"Processing {key}")

        model_weights = [x[key] for x in model_state_dicts]
        num_tasks = len(model_weights)
        if num_tasks < 2:
            raise ValueError("HO-GSVD needs at least two task vectors.")
        mat_shape = model_weights[0].shape
        # Allocate buffers alongside the weights (CPU-only float32 buffers break on GPU / float64).
        device, dtype = model_weights[0].device, model_weights[0].dtype

        # 1. Compute S.
        A = [x.T @ x for x in model_weights]

        # A_stacked = [x_1.T, x_2.T, ..., x_n.T].T. Refers to A from Kempf et al. (2022) paper.
        A_stacked = torch.cat(model_weights, dim=0)
        A_stacked_T_A = A_stacked.T @ A_stacked

        # Regularization to prevent rank deficiency. Each D_i is inverted once
        # rather than once per pair.
        D = [a + config.method.pi * A_stacked_T_A for a in A]
        D_inv = [torch.linalg.inv(d) for d in D]

        S = torch.zeros(mat_shape[1], mat_shape[1], device=device, dtype=dtype)
        for i in range(num_tasks):
            for j in range(i + 1, num_tasks):
                S += D[i] @ D_inv[j] + D[j] @ D_inv[i]
        S = S / (num_tasks * (num_tasks - 1))

        # 2. Eigendecomposition of S. The eigenvectors are real in theory, so
        #    only the real part is kept.
        _, V_shared = torch.linalg.eig(S)
        V_shared = V_shared.real.to(dtype)

        # 3. B matrices: V @ B.T = W.T, solved by least squares.
        B = [torch.linalg.lstsq(V_shared, x.T).solution.T for x in model_weights]

        # 4. Generalized singular values are the column norms of B_i; U_i is B_i
        #    with unit-norm columns. Zero-norm columns are all-zero, so dividing
        #    them by 1 keeps them 0. A 0/0 NaN could not be removed by the masks
        #    below (NaN * 0 == NaN).
        sigmas = [torch.linalg.norm(b, dim=0) for b in B]
        U = [b / torch.where(s == 0, torch.ones_like(s), s) for b, s in zip(B, sigmas)]

        # 4b. Alignment matrix between tasks.
        stacked_sigmas = torch.stack(sigmas, dim=0)
        alignment_matrices.append(abs_log_ratios(stacked_sigmas))

        # 5. [Optional] Truncate singular values. Example, not used by default:
        # truncated_sigmas = truncate_singular_values(sigmas, svd_thresh=0.125, shared_fraction=0.5)
        truncated_sigmas = stacked_sigmas

        # Zero out basis columns whose singular values vanish (for all tasks in
        # V, per task in U).
        nonzero = truncated_sigmas.abs() > 1e-7
        V_truncated = V_shared * nonzero.any(dim=0).to(dtype).unsqueeze(0)
        for i in range(num_tasks):
            U[i] = U[i] * nonzero[i].to(dtype).unsqueeze(0)

        # 6. Orthonormalize via Procrustes: replace each matrix with the P @ Rh
        #    factor of its thin SVD. For U, this is applied to the mean of the
        #    per-task U_i (generalized Procrustes).
        P_v, _, R_vh = torch.linalg.svd(V_truncated, full_matrices=False)
        V_ortho = P_v @ R_vh

        U_mean = torch.stack(U, dim=0).mean(dim=0)
        P_u, _, R_uh = torch.linalg.svd(U_mean, full_matrices=False)
        U_ortho = P_u @ R_uh

        # 7. Weight reconstruction with boosted per-task singular values.
        reconstructed_model_weights = torch.zeros(mat_shape, device=device, dtype=dtype)
        for sigma in sigmas:
            boosted_sigmas = ho_gsvd_subspace_boosting(sigma, k=config.method.k, boost_nonzero=False)
            reconstructed_model_weights += (U_ortho * boosted_sigmas) @ V_ortho.T

        merged_state_dict[key] = reconstructed_model_weights

    # Mean alignment matrix across layers. It is currently unused (TODO:
    # post-process). The guard avoids torch.stack([]) raising when no key matched.
    if alignment_matrices:
        mean_alignment_matrix = torch.mean(torch.stack(alignment_matrices, dim=0), dim=0)

    return state_dict_to_vector(merged_state_dict)