import numpy as np
import torch
import random
from agent import *
from envs import *

def set_seed(seed):
    """Seed every random number generator used by this script.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (plus
    the CUDA generators when CUDA is available), then switches cuDNN to
    deterministic mode with the benchmark auto-tuner disabled.

    Args:
        seed: Seed value passed unchanged to each generator.
    """
    # Host-side generators, seeded in the same order every time.
    for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
        seed_rng(seed)

    # GPU generators are only touched when a CUDA device is present.
    if torch.cuda.is_available():
        for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            seed_cuda(seed)

    # Deterministic cuDNN kernels; no benchmark-driven algorithm selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def _component_actions(ac, obs, component_indices):
    """Return the deterministic actions of ``ac.pi`` for the given component indices."""
    # ac.pi returns a 7-tuple; only its first element (the action) is needed.
    actions, _, _, _, _, _, _ = ac.pi(
        obs,
        manual_indices=component_indices,
        deterministic=True,
    )
    return actions


def collect_states_and_actions(ac, env, n_steps=10000, n_components=5):
    """Gather rollout states, then evaluate every policy component on them.

    Phase 1 steps ``env`` for ``n_steps`` transitions. At each step a
    component index is drawn uniformly at random (via Python's ``random``
    module) and the policy acts deterministically with that component; the
    observation seen *before* the step is recorded.

    Phase 2 queries each of the ``n_components`` components, in one batch per
    component, on all recorded observations.

    Relies on a module-level ``device`` for tensor placement.

    Args:
        ac: Object whose ``pi`` callable accepts
            ``(obs, manual_indices=..., deterministic=...)`` and returns a
            7-tuple with the action tensor first; ``ac.pi.act_dim`` gives the
            action dimensionality.
        env: Environment whose ``reset()`` returns a sequence with the
            observation first and whose ``step(a)`` returns a 5-tuple,
            assumed to follow the Gymnasium layout
            ``(obs, reward, terminated, truncated, info)``.
        n_steps: Number of environment steps (and recorded states).
        n_components: Number of policy components to sample from and evaluate.

    Returns:
        A tuple ``(states, actions_by_component)`` where ``states`` stacks the
        recorded observations (one per step) and ``actions_by_component`` is a
        float array of shape ``(n_components, n_steps, ac.pi.act_dim)``.
    """
    states = []

    # Phase 1: roll out the environment with a random component per step.
    with torch.no_grad():
        o, d = env.reset()[0], False

        for _ in range(n_steps):
            # Start a fresh episode once the previous one has ended.
            if d:
                o, d = env.reset()[0], False

            k = random.randint(0, n_components - 1)
            obs_batch = torch.as_tensor(
                np.expand_dims(o, axis=0), dtype=torch.float32
            ).to(device)
            component = torch.tensor([k], device=device)
            a = _component_actions(ac, obs_batch, component).cpu().numpy()[0]

            # Record the state the action was computed for.
            states.append(o.copy())

            o, _, terminated, truncated, _ = env.step(a)
            # An episode ends on termination OR truncation (e.g. a time
            # limit); ignoring truncation would keep stepping a finished
            # episode instead of resetting.
            d = terminated or truncated

    states = np.array(states)  # one row per collected step
    actions_by_component = np.zeros((n_components, n_steps, ac.pi.act_dim))

    # Phase 2: evaluate each component on the full batch of states.
    with torch.no_grad():
        # Convert states to a tensor once and reuse it for every component.
        states_tensor = torch.as_tensor(states, dtype=torch.float32).to(device)

        for k in range(n_components):
            component = torch.full(
                (n_steps,), k, dtype=torch.long, device=device
            )
            actions_by_component[k] = (
                _component_actions(ac, states_tensor, component).cpu().numpy()
            )

    return states, actions_by_component

def compute_pairwise_distances(actions_by_component):
    """Compute mean Euclidean action distances between every pair of components.

    For each pair ``(i, j)`` the Euclidean norm of the action difference is
    taken along axis 1 of ``actions_by_component[i] - actions_by_component[j]``
    and averaged over the remaining (state) axis.

    Args:
        actions_by_component: Array indexed by component along its first
            axis; each entry is a 2-D array of actions, one row per state.

    Returns:
        A tuple ``(pairwise_distances, mean_pairwise_distance)``:
        ``pairwise_distances`` is a symmetric ``(n_components, n_components)``
        matrix with a zero diagonal, and ``mean_pairwise_distance`` is the
        mean over its strict upper triangle (each unordered pair counted
        once). With fewer than two components there are no pairs and the
        mean is reported as ``0.0`` rather than NaN.
    """
    n_components = actions_by_component.shape[0]
    pairwise_distances = np.zeros((n_components, n_components))

    # The distance is symmetric, so compute each unordered pair once and
    # mirror it; the diagonal stays at its initial zero.
    for i in range(n_components):
        for j in range(i + 1, n_components):
            distances = np.linalg.norm(
                actions_by_component[i] - actions_by_component[j],
                axis=1,
            )
            pair_mean = np.mean(distances)
            pairwise_distances[i, j] = pair_mean
            pairwise_distances[j, i] = pair_mean

    # No pairs exist below two components; averaging an empty selection
    # would yield NaN together with a RuntimeWarning.
    if n_components < 2:
        return pairwise_distances, 0.0

    # Compute mean of upper triangle (avoid double counting)
    upper_triangle_mask = np.triu(
        np.ones((n_components, n_components)), k=1
    ).astype(bool)
    mean_pairwise_distance = np.mean(pairwise_distances[upper_triangle_mask])

    return pairwise_distances, mean_pairwise_distance

def compute_zero_baseline_distances(actions_by_component):
    """Measure how far each component's actions lie from the zero action.

    For every component the Euclidean norm of its actions is taken along
    axis 1 and averaged over the remaining (state) axis.

    Args:
        actions_by_component: Array indexed by component along its first
            axis; each entry is a 2-D array of actions, one row per state.

    Returns:
        A tuple ``(zero_distances, mean_zero_distance)``: a float64 array
        with one mean norm per component, and the mean of that array.
    """
    component_count = actions_by_component.shape[0]

    # One mean action norm per component, stored as float64.
    per_component_norms = (
        np.mean(np.linalg.norm(actions_by_component[idx], axis=1))
        for idx in range(component_count)
    )
    zero_distances = np.fromiter(
        per_component_norms, dtype=np.float64, count=component_count
    )

    return zero_distances, np.mean(zero_distances)

def main():
    """Load a saved policy and report action-diversity metrics for its components.

    Reads command-line arguments, seeds all RNGs, loads the model from
    ``args.path``, collects 10000 environment steps of states, and prints
    the mean pairwise distance between component actions together with the
    mean distance of the actions from zero.

    Relies on ``read_args``, ``setup_logging``, ``build_env`` and ``device``
    being provided by the module's imports.
    """
    # Parse arguments and setup
    args = read_args()
    setup_logging(args.log_level)
    set_seed(args.seed)

    print(f"Loading model from: {args.path}")
    print(f"Testing diversity with {args.n_components} components")

    # NOTE(review): torch.load unpickles arbitrary Python objects, so only
    # load checkpoints from a trusted source. Recent PyTorch releases default
    # to weights_only=True, which rejects fully pickled model objects; this
    # call may then need an explicit weights_only=False -- confirm against
    # the PyTorch version in use.
    ac = torch.load(args.path, map_location=device)
    env = build_env(args, render_mode='rgb_array')()

    # Collect data; always release the environment, even if collection fails.
    try:
        _, actions_by_component = collect_states_and_actions(
            ac, env, n_steps=10000, n_components=args.n_components
        )
    finally:
        env.close()

    # Compute diversity metrics
    print("\nComputing diversity metrics...")

    _, mean_pairwise_distance = compute_pairwise_distances(actions_by_component)
    _, mean_zero_distance = compute_zero_baseline_distances(actions_by_component)

    print("\nPairwise Euclidean Distance (Static Diversity):")
    print(f"Mean pairwise distance: {mean_pairwise_distance:.4f}")

    print("\nDistance from Zero Baseline:")
    print(f"Mean distance from zero: {mean_zero_distance:.4f}")
    

# Run the diversity evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
    
    