# filename: codebase/tensor_decomposition.py
import torch
import numpy as np
import tensorly as tl
from tensorly.decomposition import tensor_train
from sklearn.linear_model import Ridge
from sklearn.multioutput import MultiOutputRegressor
from sklearn.metrics import mean_squared_error
import os
import collections  # Required for loading data from Step 1 if not implicitly handled by torch.load
from torch_geometric.data import Data  # Required for loading data from Step 1

# --- Configuration: paths ---
# Input produced by Step 2; read with torch.load in main().
FINAL_PROCESSED_DATA_PATH_INPUT = 'data/final_processed_data.pt'  # From Step 2
# Output of this step: QITT features, labels and metadata written by main().
QITT_OUTPUT_PATH = 'data/qitt_processed_data.pt'
# Created by main() if missing; should be the directory containing QITT_OUTPUT_PATH.
OUTPUT_DIR = 'data'

# --- TT decomposition configuration ---
# Each per-tree tensor of shape (max_N_sub, D_feat_combined) is reshaped to
# (max_N_sub, f1, f2) with f1 * f2 == D_feat_combined (validated in main()).
# D_feat_combined from Step 2 was 74 = 2 * 37.
D_FEAT_COMBINED_FACTORS = (2, 37)

# Rank candidates for the cross-validated grid search in main().
# Reshaped dims are (n0, n1, n2) = (60, 2, 37), where n0 = max_N_sub (60 per
# Step 2; confirm if Step 2 changes).
#   r1 links mode 0 (size 60) and mode 1 (size 2);
#   r2 links mode 1 (size 2) and mode 2 (size 37).
# A TT rank cannot exceed either side of the unfolding it factorises, so
#   r1 <= min(n0, n1 * n2) = 60   and   r2 <= min(r1 * n1, n2) = min(2 * r1, 37).
# For example, with r1 = 2 any r2 above 4 gives the same decomposition as r2 = 4.
# The grid is kept small because every pair re-decomposes the train and validation sets.
RANK_CANDIDATES_R1 = [2, 4, 6, 8]
RANK_CANDIDATES_R2 = [2, 4, 6, 8]


def reshape_tree_tensor(tensor_2d, max_n_sub, factors_d_feat):
    """
    Reshape a 2D tree tensor (max_N_sub, D_feat_combined) into a higher-order tensor.

    The feature axis is split into ``factors_d_feat`` (row-major order), so
    ``factors_d_feat = (f1, f2)`` gives shape (max_N_sub, f1, f2). Any number
    of factors is accepted; for example, (f1, f2, f3) yields a 4D tensor.

    Args:
        tensor_2d (np.ndarray): The 2D tensor of shape (max_N_sub, D_feat_combined).
        max_n_sub (int): Number of substructures (first dimension).
        factors_d_feat (tuple): Factors whose product equals D_feat_combined.
    Returns:
        np.ndarray: The reshaped tensor of shape (max_N_sub, *factors_d_feat).
    Raises:
        ValueError: If (max_n_sub, *factors_d_feat) does not match the number
            of elements in ``tensor_2d`` (raised by ``reshape``).
    """
    # Unpacking all factors, rather than indexing [0] and [1], lets callers
    # split the feature axis into any number of modes.
    return tensor_2d.reshape(max_n_sub, *factors_d_feat)


def perform_tt_decomposition_for_tensor(tensor_3d, tt_rank):
    """
    Decompose one tensor in Tensor-Train format and return its cores as a flat vector.

    Args:
        tensor_3d (np.ndarray): Tensor to decompose (3D in this pipeline).
        tt_rank (tuple): TT ranks in the form (1, r1, r2, 1).
    Returns:
        np.ndarray: 1-D vector with every TT core flattened in C order and
        concatenated core by core, first core first.

    NOTE(review): every call decomposes its tensor independently. SVD factors
    are only defined up to sign (and rotation for repeated singular values),
    so a given coordinate may not carry the same meaning across samples.
    Confirm this is acceptable for the downstream regressor.
    """
    tt_cores = tensor_train(tensor_3d, rank=tt_rank)
    # Iterating the result yields the individual cores.
    flat_cores = [np.ravel(core) for core in tt_cores]
    return np.concatenate(flat_cores)


def process_dataset_for_qitt(tensors_list, max_n_sub, factors_d_feat, tt_rank):
    """
    Processes a list of 2D tree tensors to generate QITT features.

    Each tensor is reshaped with ``reshape_tree_tensor``, decomposed with
    ``perform_tt_decomposition_for_tensor``, and the resulting flat core
    vectors are stacked row-wise.

    Args:
        tensors_list (list): PyTorch tensors (or array-likes), each of shape
            (max_N_sub, D_feat_combined).
        max_n_sub (int): Number of substructures.
        factors_d_feat (tuple): Factors of D_feat_combined.
        tt_rank (tuple): The TT ranks (1, r1, r2, 1).
    Returns:
        np.ndarray: Array of QITT features, shape (num_samples, qitt_feature_dim).
        An empty input returns an array of shape (0, 0).
    Raises:
        ValueError: If the per-sample feature vectors have different lengths.
    """
    qitt_features_list = []
    for tree_tensor in tensors_list:
        if isinstance(tree_tensor, torch.Tensor):
            # A bare .numpy() raises for tensors that require grad or live
            # on a GPU; detach and move to CPU first.
            tensor_np_2d = tree_tensor.detach().cpu().numpy()
        else:
            tensor_np_2d = np.asarray(tree_tensor)
        tensor_np_3d = reshape_tree_tensor(tensor_np_2d, max_n_sub, factors_d_feat)
        qitt_features_list.append(perform_tt_decomposition_for_tensor(tensor_np_3d, tt_rank))

    if not qitt_features_list:
        # np.array([]) would be 1-D; keep the documented 2-D contract.
        return np.empty((0, 0))
    # np.stack raises on ragged vectors instead of building an object array.
    return np.stack(qitt_features_list)


def _load_step2_data(path):
    """Load the Step 2 payload dict from ``path`` onto the CPU.

    The payload is a pickled dict written by an earlier pipeline step. It may
    hold objects other than plain tensors, and PyTorch >= 2.6 rejects those
    under its default ``weights_only=True``. ``weights_only=False`` runs the
    full unpickler, so only use this on files produced by this pipeline.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    try:
        return torch.load(path, map_location='cpu', weights_only=False)
    except TypeError:
        # PyTorch < 1.13 does not accept the ``weights_only`` keyword.
        return torch.load(path, map_location='cpu')


def _effective_tt_ranks(reshaped_dims, r1, r2):
    """Return the (r1, r2) that a TT-SVD can actually reach for a 3-way tensor.

    A TT rank cannot exceed either side of the unfolding it factorises:
    r1 <= min(n0, n1 * n2) and r2 <= min(r1 * n1, n2). TensorLy's
    ``tensor_train`` clips larger requests to these bounds, so candidates
    beyond them duplicate smaller ones.

    Args:
        reshaped_dims (tuple): (n0, n1, n2), the shape of the tensor to decompose.
        r1 (int): Requested rank between modes 0 and 1.
        r2 (int): Requested rank between modes 1 and 2.
    Returns:
        tuple: (effective_r1, effective_r2).
    """
    n0, n1, n2 = reshaped_dims
    eff_r1 = min(r1, n0, n1 * n2)
    eff_r2 = min(r2, eff_r1 * n1, n2)
    return eff_r1, eff_r2


def _select_tt_rank(train_tensors, val_tensors, train_labels, val_labels,
                    factors_d_feat, reshaped_dims):
    """Grid-search TT ranks (r1, r2) by the validation RMSE of a ridge regressor.

    For each pair in RANK_CANDIDATES_R1 x RANK_CANDIDATES_R2, the train and
    validation tensors are decomposed and a multi-output ridge model is fit on
    the flattened cores. The score is the sum of the validation RMSEs of label
    columns 0 (Omega_m) and 1 (sigma_8). A pair whose capped (effective) ranks
    were already evaluated is skipped because it would yield identical features.

    Returns:
        tuple: (optimal_r1, optimal_r2, best_rmse_sum), with the ranks given as
        effective ranks. The ranks are -1 and the score is ``inf`` when every
        candidate failed.
    """
    max_n_sub = reshaped_dims[0]
    best_rmse_sum = float('inf')
    optimal_r1 = -1
    optimal_r2 = -1
    evaluated_ranks = set()

    # Alpha (regularization strength) for Ridge can also be tuned, but fix for simplicity here.
    ridge_model = MultiOutputRegressor(Ridge(alpha=1.0, random_state=42))

    print("Candidate r1 ranks: " + str(RANK_CANDIDATES_R1))
    print("Candidate r2 ranks: " + str(RANK_CANDIDATES_R2))

    for r1_cand in RANK_CANDIDATES_R1:
        for r2_cand in RANK_CANDIDATES_R2:
            print("  Testing ranks (r1, r2): (" + str(r1_cand) + ", " + str(r2_cand) + ")")
            eff_r1, eff_r2 = _effective_tt_ranks(reshaped_dims, r1_cand, r2_cand)
            if (eff_r1, eff_r2) != (r1_cand, r2_cand):
                print("    Ranks capped by tensor dimensions to (" + str(eff_r1) + ", " + str(eff_r2) + ")")
            if (eff_r1, eff_r2) in evaluated_ranks:
                # Identical decomposition to an earlier candidate; re-running
                # it would only repeat the expensive per-sample TT-SVDs.
                print("    Skipping: these effective ranks were already evaluated.")
                continue
            evaluated_ranks.add((eff_r1, eff_r2))
            current_tt_rank = (1, eff_r1, eff_r2, 1)

            try:
                X_train_qitt = process_dataset_for_qitt(train_tensors, max_n_sub, factors_d_feat, current_tt_rank)
                X_val_qitt = process_dataset_for_qitt(val_tensors, max_n_sub, factors_d_feat, current_tt_rank)
                ridge_model.fit(X_train_qitt, train_labels)
                y_pred_val = ridge_model.predict(X_val_qitt)
            except Exception as e:
                # Best-effort search: a failing candidate (decomposition or
                # model fit) is reported and skipped rather than aborting CV.
                print("    Error during TT decomposition or processing for rank (" + str(r1_cand) + "," + str(r2_cand) + "): " + str(e))
                continue

            rmse_omega_m = np.sqrt(mean_squared_error(val_labels[:, 0], y_pred_val[:, 0]))
            rmse_sigma_8 = np.sqrt(mean_squared_error(val_labels[:, 1], y_pred_val[:, 1]))
            current_rmse_sum = rmse_omega_m + rmse_sigma_8

            print("    Validation RMSE (Omega_m): " + str(round(rmse_omega_m, 4)) + ", RMSE (sigma_8): " + str(round(rmse_sigma_8, 4)) + ", Sum RMSE: " + str(round(current_rmse_sum, 4)))

            if current_rmse_sum < best_rmse_sum:
                best_rmse_sum = current_rmse_sum
                optimal_r1 = eff_r1
                optimal_r2 = eff_r2
                print("    New best rank found: (" + str(optimal_r1) + ", " + str(optimal_r2) + ") with Sum RMSE: " + str(round(best_rmse_sum, 4)))

    return optimal_r1, optimal_r2, best_rmse_sum


def main():
    """
    Main function for Step 3: Tensor Construction and QITT Decomposition.

    Loads the Step 2 tensors and labels, then selects TT ranks by validation
    RMSE. It decomposes every split with the selected (effective) ranks and
    saves the features, labels and metadata to QITT_OUTPUT_PATH. It returns
    early, after printing an error, if the Step 2 file cannot be loaded.

    Raises:
        ValueError: If D_FEAT_COMBINED_FACTORS do not multiply to the loaded
            D_feat_combined.
    """
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    # Set TensorLy backend
    tl.set_backend("numpy")
    print("TensorLy backend set to: " + str(tl.get_backend()))

    print("Loading processed data from Step 2: " + str(FINAL_PROCESSED_DATA_PATH_INPUT))
    try:
        loaded_data = _load_step2_data(FINAL_PROCESSED_DATA_PATH_INPUT)
    except FileNotFoundError:
        print("Error: Processed data file not found: " + str(FINAL_PROCESSED_DATA_PATH_INPUT))
        return
    except Exception as e:
        print("Error loading data: " + str(e))
        import traceback
        traceback.print_exc()
        return

    train_tensors_pytorch = loaded_data['train_tensors']
    val_tensors_pytorch = loaded_data['val_tensors']
    test_tensors_pytorch = loaded_data['test_tensors']

    # Labels are stored as sequences of per-sample tensors. squeeze(1) drops
    # a singleton per-sample axis if present (assumed (1, n_targets); confirm
    # against Step 2).
    train_labels_pytorch = torch.stack(loaded_data['train_labels']).squeeze(1).numpy()
    val_labels_pytorch = torch.stack(loaded_data['val_labels']).squeeze(1).numpy()
    test_labels_pytorch = torch.stack(loaded_data['test_labels']).squeeze(1).numpy()

    max_n_sub = loaded_data['max_N_sub']
    d_feat_combined = loaded_data['D_feat_combined']

    actual_factors_d_feat = (D_FEAT_COMBINED_FACTORS[0], D_FEAT_COMBINED_FACTORS[1])
    if actual_factors_d_feat[0] * actual_factors_d_feat[1] != d_feat_combined:
        raise ValueError("D_FEAT_COMBINED_FACTORS " + str(actual_factors_d_feat) +
                         " do not multiply to D_feat_combined " + str(d_feat_combined))

    reshaped_dims = (max_n_sub, actual_factors_d_feat[0], actual_factors_d_feat[1])
    print("Original tensor shape per tree: (" + str(max_n_sub) + ", " + str(d_feat_combined) + ")")
    print("Reshaped tensor shape for TT decomposition: " + str(reshaped_dims))

    # --- TT-Rank Cross-Validation ---
    print("\nStarting TT-Rank cross-validation...")
    optimal_r1, optimal_r2, best_rmse_sum = _select_tt_rank(
        train_tensors_pytorch, val_tensors_pytorch,
        train_labels_pytorch, val_labels_pytorch,
        actual_factors_d_feat, reshaped_dims)

    if optimal_r1 == -1 or optimal_r2 == -1:
        print("Error: Optimal TT-ranks not found. CV might have failed for all candidates. Defaulting to smallest ranks.")
        optimal_r1, optimal_r2 = _effective_tt_ranks(
            reshaped_dims, RANK_CANDIDATES_R1[0], RANK_CANDIDATES_R2[0])

    optimal_tt_rank = (1, optimal_r1, optimal_r2, 1)
    print("\nOptimal TT-Rank (r1, r2) found: (" + str(optimal_r1) + ", " + str(optimal_r2) + ") with best Sum RMSE: " + str(round(best_rmse_sum, 4)))
    print("Final TT rank tuple for decomposition: " + str(optimal_tt_rank))

    # --- Generate final QITT features with optimal rank ---
    print("\nGenerating final QITT features using optimal rank...")
    train_qitt_features = process_dataset_for_qitt(train_tensors_pytorch, max_n_sub, actual_factors_d_feat, optimal_tt_rank)
    val_qitt_features = process_dataset_for_qitt(val_tensors_pytorch, max_n_sub, actual_factors_d_feat, optimal_tt_rank)
    test_qitt_features = process_dataset_for_qitt(test_tensors_pytorch, max_n_sub, actual_factors_d_feat, optimal_tt_rank)

    qitt_feature_dim = train_qitt_features.shape[1]
    print("QITT feature dimension: " + str(qitt_feature_dim))
    # Core shapes are (1, n0, r1), (r1, n1, r2) and (r2, n2, 1), so the
    # flattened size is n0*r1 + r1*n1*r2 + r2*n2. The ranks here are already
    # capped, so this matches the actual decomposition.
    theoretical_dim = reshaped_dims[0] * optimal_r1 + \
                      optimal_r1 * reshaped_dims[1] * optimal_r2 + \
                      optimal_r2 * reshaped_dims[2]
    print("Theoretical QITT feature dimension based on optimal ranks: " + str(theoretical_dim))
    if qitt_feature_dim != theoretical_dim:
        print("Warning: Actual QITT feature dimension (" + str(qitt_feature_dim) + ") does not match theoretical (" + str(theoretical_dim) + "). Check logic.")

    # --- Save QITT processed data ---
    qitt_data_to_save = {
        'train_qitt_features': train_qitt_features,
        'val_qitt_features': val_qitt_features,
        'test_qitt_features': test_qitt_features,
        'train_labels': train_labels_pytorch,
        'val_labels': val_labels_pytorch,
        'test_labels': test_labels_pytorch,
        'optimal_tt_rank_r1_r2': (optimal_r1, optimal_r2),
        'optimal_tt_rank_full': optimal_tt_rank,
        'reshaped_tensor_dimensions': reshaped_dims,
        'qitt_feature_dimension': qitt_feature_dim,
        'd_feat_combined_factors': actual_factors_d_feat,
        'rank_cv_metric_sum_rmse': best_rmse_sum
    }

    # Add original metadata for completeness
    qitt_data_to_save['feature_means_original_nodes'] = loaded_data['feature_means_original_nodes']
    qitt_data_to_save['feature_stds_original_nodes'] = loaded_data['feature_stds_original_nodes']
    qitt_data_to_save['max_N_sub'] = max_n_sub
    qitt_data_to_save['D_feat_combined'] = d_feat_combined
    qitt_data_to_save['gnn_config'] = loaded_data['gnn_config']

    # Saved with torch.save although the features are numpy arrays; readers
    # will likely need weights_only=False to load them.
    torch.save(qitt_data_to_save, QITT_OUTPUT_PATH)
    print("\nQITT processed data saved to: " + str(QITT_OUTPUT_PATH))
    print("Summary of saved QITT data contents:")
    for key, value in qitt_data_to_save.items():
        if isinstance(value, np.ndarray):
            print("  " + str(key) + ": numpy array of shape " + str(value.shape))
        elif isinstance(value, torch.Tensor):  # e.g. metadata carried over from Step 2
            print("  " + str(key) + ": torch tensor of shape " + str(value.shape))
        else:
            print("  " + str(key) + ": " + str(value))

    print("\nNote on regularization: Tensor Train decomposition via `tensorly.decomposition.tensor_train` " +
          "is typically a direct method (e.g., SVD-based). Explicit regularization terms like " +
          "Tucker regularization are not directly applied in this specific decomposition function. " +
          "Complexity control is primarily achieved through TT-rank selection.")


# Run Step 3 (TT-rank selection and QITT feature generation) only when this
# file is executed as a script, not when it is imported.
if __name__ == '__main__':
    main()