import wandb
import pandas as pd
from tqdm import tqdm
import numpy as np
import json
import os


# Log in to Weights & Biases before any API call is issued.
wandb.login()

# Location of the runs on W&B, addressed as "<entity>/<project>".
entity = "reznov3395-simon-fraser-university"
project = "new_train"

# Handle on the public API and the collection of runs in the project.
api = wandb.Api()
runs = api.runs(f"{entity}/{project}")

# Root folder that receives the exported matrices (created on demand).
base_save_dir = './processed_results/bwt_matrices_backward'
os.makedirs(base_save_dir, exist_ok=True)

# Bookkeeping for the summary printed at the end of the script.
total_runs = len(runs)
runs_with_train_backward = runs_with_test_backward = runs_with_both = 0
processed_agents = set()

# Process each run
import shutil  # imported once here instead of repeatedly inside the loops


def _artifact_version_number(artifact):
    """Return the numeric part of an artifact's version, or -1 if unparseable.

    W&B documents artifact versions as strings of the form "v<N>", on which a
    bare ``int()`` raises ValueError. Plain digit strings are accepted too.
    """
    try:
        return int(str(artifact.version).lstrip("vV"))
    except ValueError:
        return -1


def _latest_backward_artifacts(artifacts):
    """Pick the highest-versioned train_backward / test_backward artifact.

    Returns a dict with the keys "train_backward" and "test_backward"; a value
    is None when no artifact whose name contains that key was seen.
    """
    latest = {
        "train_backward": None,
        "test_backward": None
    }
    for artifact in artifacts:
        for kind in latest:
            # Substring match; "train_backward" is tested first, and an
            # artifact is only ever counted for one kind.
            if kind not in artifact.name:
                continue
            current = latest[kind]
            if (current is None or
                    _artifact_version_number(artifact) > _artifact_version_number(current)):
                latest[kind] = artifact
            break
    return latest


for wandb_run in tqdm(runs, desc="Processing runs"):
    print(f"Run Name: {wandb_run.name}")
    print(f"Run ID: {wandb_run.id}")
    print(f"Run URL: {wandb_run.url}")
    print(f"Run State: {wandb_run.state}")

    # Get agent type and benchmark from config. A run that lacks either key
    # is skipped so that one incomplete run does not abort the whole export.
    agent_type_name = wandb_run.config.get('agent_type')
    benchmark = wandb_run.config.get('benchmark')
    if not agent_type_name or not benchmark:
        print(f"Skipping run {wandb_run.name}: 'agent_type' or 'benchmark' missing from config")
        print("-" * 50)
        continue

    agent = agent_type_name.split('Agent')[0]
    print(f"Agent type: {agent}")
    processed_agents.add(agent)

    # NOTE: the former debugging stub here (breakpoint() for imagenet runs,
    # `continue` for everything else) made the rest of this loop unreachable
    # and has been removed so that every run is actually processed.

    # Latest version of each artifact kind logged to this run
    latest_artifacts = _latest_backward_artifacts(wandb_run.logged_artifacts())

    # Track which artifact types were found for this run
    has_train_backward = latest_artifacts["train_backward"] is not None
    has_test_backward = latest_artifacts["test_backward"] is not None

    if has_train_backward:
        runs_with_train_backward += 1

    if has_test_backward:
        runs_with_test_backward += 1

    if has_train_backward and has_test_backward:
        runs_with_both += 1

    # Process the latest artifacts
    for artifact_type, artifact in latest_artifacts.items():
        if artifact is None:
            print(f"No {artifact_type} artifact found for run {wandb_run.name}")
            continue

        print(f"Processing {artifact_type} artifact: {artifact.name} v{artifact.version}")

        # Destination: <base>/<benchmark>/<agent>/<train|test>
        split_name = "train" if artifact_type == "train_backward" else "test"
        save_dir = os.path.join(base_save_dir, benchmark, agent, split_name)
        os.makedirs(save_dir, exist_ok=True)

        temp_dir = os.path.join("temp_artifacts", f"{wandb_run.id}_{artifact_type}")
        try:
            # Download artifact to temporary location
            os.makedirs(temp_dir, exist_ok=True)
            artifact_dir = artifact.download(root=temp_dir)

            # Find and save .npy files
            # NOTE(review): every .npy is written to the same "<agent>.npy"
            # path, so with several files in one artifact (or several runs
            # sharing benchmark and agent) the last copy wins - confirm that
            # this is intended.
            npy_files_found = False
            dest_path = os.path.join(save_dir, f"{agent}.npy")
            for root, _dirs, files in os.walk(artifact_dir):
                for file in files:
                    if not file.endswith('.npy'):
                        continue
                    npy_files_found = True
                    shutil.copy2(os.path.join(root, file), dest_path)
                    print(f"Saved {file} to {dest_path}")

            if not npy_files_found:
                print(f"No .npy files found in artifact {artifact.name}")

        except Exception as e:
            # Deliberately best-effort: report the failure and move on to the
            # next artifact instead of aborting the whole export.
            print(f"Error processing artifact {artifact.name}: {e}")
        finally:
            # Clean up the temporary directory even when the download failed
            shutil.rmtree(temp_dir, ignore_errors=True)

    print("-" * 50)

# Clean up the overall temp directory
import shutil
shutil.rmtree("temp_artifacts", ignore_errors=True)


def _percent_of_runs(count):
    """Return *count* as a percentage of total_runs (0.0 when there are no runs).

    The guard avoids a ZeroDivisionError when the project contains no runs.
    """
    return count / total_runs * 100 if total_runs else 0.0


# Print summary statistics
print("\n===== SUMMARY STATISTICS =====")
print(f"Total runs processed: {total_runs}")
print(f"Runs with train_backward artifacts: {runs_with_train_backward} ({_percent_of_runs(runs_with_train_backward):.1f}%)")
print(f"Runs with test_backward artifacts: {runs_with_test_backward} ({_percent_of_runs(runs_with_test_backward):.1f}%)")
print(f"Runs with both artifact types: {runs_with_both} ({_percent_of_runs(runs_with_both):.1f}%)")
print(f"Unique agent types processed: {len(processed_agents)}")
print(f"Agent types: {', '.join(sorted(processed_agents))}")
print("Processing complete!")