import wandb
import pandas as pd
from tqdm import tqdm
import numpy as np
import json
import os

# Best-configuration selection, looked up below as
# best_configs[benchmark][baseline] to decide which runs to export.
best_configs_path = 'processed_results/best_configs.json'
with open(best_configs_path) as config_file:
    best_configs = json.load(config_file)

# Authenticate with Weights & Biases before any API queries.
wandb.login()

# W&B entity/project whose runs are exported.
entity = "reznov3395-simon-fraser-university"
project = "normal_train"

# Output root for the exported matrices; per-benchmark subdirectories are
# created beneath it while processing runs.
base_save_dir = './processed_results/bwt_matrices'
os.makedirs(base_save_dir, exist_ok=True)

# Public-API handle and the collection of runs in the project.
api = wandb.Api()
project_path = f"{entity}/{project}"
runs = api.runs(project_path)

# Process each run: if its (benchmark, baseline, optimizer) matches the
# best-config selection and it logged a "bt_train_matrix" summary table,
# download that table and save it as <base_save_dir>/<benchmark>/<baseline>_seed<seed>.npy.
for wandb_run in tqdm(runs, desc="Processing runs"):
    print(f"Run Name: {wandb_run.name}")
    print(f"Run ID: {wandb_run.id}")
    print(f"Run URL: {wandb_run.url}")
    print(f"Run State: {wandb_run.state}")

    # Extract run configuration details. A run lacking any of these keys is
    # skipped rather than aborting the whole export with a KeyError.
    try:
        benchmark = wandb_run.config['benchmark']
        baseline = wandb_run.config['agent_type'].split('Agent')[0]
        optimizer = wandb_run.config['optimizer']
        seed = wandb_run.config['seed']
    except KeyError as e:
        print(f"Skipping run {wandb_run.id} - missing config key {e}")
        continue

    # 'continual_imagenet' runs are ignored; 'new_continual_imagenet' runs are
    # exported under the 'continual_imagenet' name instead.
    if benchmark == 'continual_imagenet':
        continue
    elif benchmark == 'new_continual_imagenet':
        benchmark = 'continual_imagenet'

    # The DeepFourier agent is stored under the short name 'DeepF'.
    if baseline == 'DeepFourier':
        baseline = 'DeepF'

    # Create the directory for this benchmark if it doesn't exist.
    benchmark_dir = os.path.join(base_save_dir, benchmark)
    os.makedirs(benchmark_dir, exist_ok=True)

    # Check if this run matches a best configuration. Benchmarks or baselines
    # absent from best_configs yield None (skip) instead of raising KeyError.
    best_config = best_configs.get(benchmark, {}).get(baseline, None)

    if best_config is not None and optimizer in best_config and "bt_train_matrix" in wandb_run.summary:
        try:
            # Summary entry describing the logged table; its "path" is the
            # run-relative file name of the table JSON.
            matrix_info = wandb_run.summary["bt_train_matrix"]
            file_path = matrix_info["path"]

            # Download the file into the working directory. download() hands
            # back an open handle to the local copy; close it so a file
            # descriptor is not leaked for every processed run.
            downloaded = wandb_run.file(file_path).download(replace=True)
            if downloaded is not None:
                downloaded.close()

            # Load the downloaded table JSON ("columns" + row-major "data").
            print(f"Processing matrix from {file_path}")
            with open(file_path, 'r') as f:
                table_data = json.load(f)

            df = pd.DataFrame(table_data['data'], columns=table_data['columns'])
            np_array = df.to_numpy()

            # Define the save path with meaningful name
            save_filename = f"{baseline}_seed{seed}.npy"
            save_path = os.path.join(benchmark_dir, save_filename)

            # Several NeuroSync runs can map to the same file name; keep the
            # matrix whose last row has the highest mean.
            should_save = True
            if baseline == 'NeuroSync' and os.path.exists(save_path):
                current_last_row_mean = np.mean(np_array[-1, :])
                existing_matrix = np.load(save_path)
                existing_last_row_mean = np.mean(existing_matrix[-1, :])

                if current_last_row_mean > existing_last_row_mean:
                    print(f"Found better NeuroSync matrix: {current_last_row_mean:.4f} > {existing_last_row_mean:.4f}")
                    should_save = True
                else:
                    print(f"Skipping NeuroSync matrix as existing has better score: {existing_last_row_mean:.4f} >= {current_last_row_mean:.4f}")
                    should_save = False

            # Save as numpy array if appropriate
            if should_save:
                np.save(save_path, np_array)
                print(f"Successfully saved matrix to {save_path}")

        except Exception as e:
            # Per-run boundary: report and continue with the remaining runs.
            print(f"Error processing run {wandb_run.id}: {e}")
    else:
        print(f"Skipping run {wandb_run.id} - not matching best config or no BWT matrix")