from multiprocessing import Pool, cpu_count
import os, pickle, ast
import numpy as np, random, torch, gymnasium as gym
import gym_envs  # ensure custom namespace 'gfn_challenges' is registered in all processes
from nns import ContinuousForwardPolicy, DiscreteForwardPolicy
from tqdm import tqdm
from fastdtw import fastdtw

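# step 1: for each environment, collect samples from every trained model
#         (100K samples for discrete environments, 10K for continuous ones)
# step 2: compute the performance metric provided by each environment (get_error)
# step 3: compute the diversity metric where applicable
# Note: results are saved after each (environment, method) pair, so a rerun skips finished pairs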

def seed_all(seed):
    """Seed Python's, NumPy's and PyTorch's global RNGs with ``seed``.

    cuDNN is also forced into deterministic mode with autotuning
    (``benchmark``) disabled, following the PyTorch reproducibility notes.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def _euclidean(a, b):
    a = np.asarray(a)
    b = np.asarray(b)
    return float(np.linalg.norm(a - b))

def dtw_distance(a, b):
    """
    Approximate Dynamic Time Warping distance between trajectories ``a`` and ``b``.

    Uses ``fastdtw`` (with the library's default radius) and ``_euclidean`` as
    the per-step cost. Only the accumulated distance is returned; the warping
    path is discarded.

    a, b: arrays of shape (T, d), or (T,) for 1-D sequences; the two lengths may differ.
    """
    return fastdtw(a, b, dist=_euclidean)[0]

def _compute_pair_dtw(args):
    """``Pool.map`` adapter: unpack a ``(seq1, seq2)`` pair and return its DTW distance."""
    first, second = args
    return dtw_distance(first, second)

def continuous_set_diversity(sequences, downsample_factor=10, n_random_samples=5):
    """
    Mean pairwise DTW distance over all unordered pairs.
    Returns 0.0 if fewer than 2 sequences.
    Performs random sampling of sequences (not within sequences) n_random_samples times 
    and averages the results to reduce stochasticity.
    """
    n = len(sequences)
    if n < 2:
        return 0.0
    
    if downsample_factor <= 1 or n <= downsample_factor:
        # No downsampling - use all sequences
        pairs_to_process = []
        for i in range(n):
            for j in range(i+1, n):
                pairs_to_process.append((sequences[i], sequences[j]))
        
        dists = []
        with Pool(processes=cpu_count()-1) as pool:
            dists = list(pool.map(_compute_pair_dtw, pairs_to_process))
        return np.mean(dists)
    
    # Perform random sampling of sequences multiple times and average the results
    diversity_scores = []
    
    # Calculate how many sequences to sample
    n_sample = max(2, n // downsample_factor)  # At least 2 sequences needed
    
    for _ in range(n_random_samples):
        # Randomly sample a subset of sequences
        sampled_indices = np.random.choice(n, size=n_sample, replace=False)
        sampled_sequences = [sequences[i] for i in sampled_indices]
        
        # Compute pairwise DTW distances for sampled sequences
        pairs_to_process = []
        for i in range(len(sampled_sequences)):
            for j in range(i+1, len(sampled_sequences)):
                pairs_to_process.append((sampled_sequences[i], sampled_sequences[j]))
        
        dists = []
        with Pool(processes=cpu_count()-1) as pool:
            dists = list(pool.map(_compute_pair_dtw, pairs_to_process))
        
        diversity_scores.append(np.mean(dists))

        # print(diversity_scores)
    
    # Return average diversity across all random samples
    return np.mean(diversity_scores)

def discrete_set_diversity(end_states):
    """
    Return how many distinct states appear in ``end_states`` (0 for an empty input).

    Iterable non-string states (lists, arrays, ...) are flattened with NumPy and
    compared as tuples. Two arrays with the same elements in the same flattened
    order therefore count once, whatever their shape. Other states (scalars,
    strings) are compared as-is.
    """
    if not len(end_states):
        return 0

    def _key(state):
        # Arrays/lists are unhashable, so turn them into a flat tuple first
        is_array_like = hasattr(state, '__iter__') and not isinstance(state, str)
        return tuple(np.asarray(state).flatten()) if is_array_like else state

    return len({_key(s) for s in end_states})

def _extract_nn_hidden_from_name(model_dir):
    """Lightweight parser to avoid importing analysis_utils in subprocesses."""
    base = os.path.basename(model_dir)
    parts = base.split('_')
    # method, total_iterations, nn_hidden_sizes, ...
    if len(parts) >= 3:
        return parts[2]
    return "[256, 256]"

def evaluate_seed(args):
    """
    Sample trajectories from one trained forward policy and score them.

    Unified for continuous ('gmm-hard', 'pusher-simple') and discrete
    environments. Rollouts are collected 20 at a time from a synchronous
    vectorized environment.

    Args:
        args: tuple ``(model_dir, seed_index, num_of_samples, env_name)``. It is
            packed into one argument so this function can be used with ``Pool.map``.
            - model_dir: directory holding ``forward_policy.pth``. Its basename
              also encodes the hidden layer sizes.
            - seed_index: the RNGs are seeded with ``42 + seed_index``.
            - num_of_samples: exact number of trajectories to collect (> 0).
            - env_name: environment id under the ``gfn_challenges`` namespace.

    Returns:
        ``(performances, divs)``.
        - ``performances`` is whatever ``env.unwrapped.get_error`` returns for
          all sampled end states.
        - ``divs`` is the set diversity of the samples that pass the reward/mode
          mask: mean pairwise DTW distance of action sequences for continuous
          envs, number of unique end states for discrete ones.

    Raises:
        ValueError: if ``num_of_samples`` is not positive.
    """
    from contextlib import closing

    model_dir, seed_index, num_of_samples, env_name = args
    if num_of_samples <= 0:
        raise ValueError(f"num_of_samples must be positive, got {num_of_samples}")
    continuous = env_name in ('gmm-hard', 'pusher-simple')

    # each worker seeds independently
    seed_all(42 + seed_index)
    # closing() releases the environment even if loading or sampling raises
    with closing(gym.make(f'gfn_challenges/{env_name}')) as env:
        print(model_dir)
        hidden = ast.literal_eval(_extract_nn_hidden_from_name(model_dir))
        # Run inference on CPU to avoid device mismatch with tensors created inside the policy
        device = torch.device('cpu')

        if continuous:
            policy = ContinuousForwardPolicy(env.observation_space.shape[0],
                                             env.action_space.shape[0] - 1,
                                             hidden,
                                             torch.nn.LeakyReLU).to(device)
        else:
            policy = DiscreteForwardPolicy(env.observation_space.shape[0], env.action_space.n,
                                           hidden, torch.nn.LeakyReLU).to(device)
        sd = torch.load(os.path.join(model_dir, 'forward_policy.pth'),
                        map_location=device, weights_only=True)
        policy.load_state_dict(sd)
        policy.eval()

        num_envs = 20
        # Ceil division: the last round may overshoot; extra samples are dropped below
        rounds = -(-num_of_samples // num_envs)
        total_samples = num_of_samples

        # Preallocate result containers
        act_seqs = [None] * total_samples
        rews = np.zeros(total_samples, dtype=np.float32)
        end_states = [None] * total_samples
        sample_idx = 0

        with closing(gym.make_vec(f'gfn_challenges/{env_name}', num_envs=num_envs,
                                  vectorization_mode='sync')) as data_env:
            for round_idx in range(rounds):
                print(round_idx)
                dones = [False] * num_envs
                obs, _ = data_env.reset()
                traj = [[] for _ in range(num_envs)]
                end_state = [None] * num_envs
                rew = [0] * num_envs
                ep_t = 0

                with torch.no_grad():
                    while not all(dones):
                        if continuous:
                            # 1.0 for still-running envs once ep_t reaches max_t - 1, else 0.0
                            # (presumably a "last step" signal to the policy — confirm in nns)
                            last_step = (np.array(ep_t >= env.unwrapped.max_t - 1)
                                         * np.ones(num_envs) * (1 - np.array(dones)))
                            actions = policy(obs, last_step,
                                             use_mask=(env_name == 'pusher-simple'),
                                             epsilon=0.).detach().cpu().numpy()
                        else:
                            actions = policy(obs, env.unwrapped.get_forward_action_masks(obs),
                                             epsilon=0.).detach().cpu().numpy()

                        # NOTE(review): truncation flags are ignored; only termination ends an
                        # episode here — confirm these envs never truncate.
                        obs, rewards, terminated, _, _ = data_env.step(actions)

                        for i in range(num_envs):
                            if dones[i]:
                                continue  # env already finished this round
                            # continuous trajectories store actions without their first component
                            traj[i].append(actions[i][1:] if continuous else int(actions[i]))
                            if terminated[i]:
                                dones[i] = True
                                end_state[i] = obs[i]
                                rew[i] = rewards[i]

                        ep_t += 1

                for i in range(num_envs):
                    if sample_idx >= total_samples:
                        break  # overshoot of the final round
                    if continuous:
                        act_seqs[sample_idx] = np.stack(traj[i]) if traj[i] else np.array([])
                    else:
                        act_seqs[sample_idx] = traj[i]
                    rews[sample_idx] = rew[i]
                    end_states[sample_idx] = end_state[i]
                    sample_idx += 1

        print("Data collection done")

        # Masks select the samples that count towards diversity
        if env_name == "rna14":
            # using the is_mode function in the env
            masks = [env.unwrapped.is_mode(env.unwrapped.get_state(s)) for s in end_states]
        elif env_name == 'gmm-hard':
            masks = rews > 1e-3
        else:
            masks = rews > 1e-10

        if continuous:
            divs = continuous_set_diversity([s for s, ok in zip(act_seqs, masks) if ok])
        else:
            # Use end_states for discrete environments to count unique final states
            divs = discrete_set_diversity([s for s, ok in zip(end_states, masks) if ok])

        # Compute performance using environment's get_error function
        samples = env.unwrapped.get_state(np.array(end_states))
        performances = env.unwrapped.get_error(samples)

    return (performances, divs)

def main(input_dir, mode = "full"):
    # Lazy import heavy utilities in the main process only to avoid Windows spawn issues
    from notebooks.analysis_utils import dir_to_names, get_model_dict
    seed_all(42)

    if mode == "full":
        names = ["Ours", "GFN", "GAFN", "PBP-GFN", "GFN-RP", "SubTB", "Teacher"]
    else:
        names = ["BF + MP + TD", "BF + MP", "MP + TD", "MP", "TD", "Vanilla",
                   "BF + RP + TD", "BF + RP", "RP + TD", "BF + TD", "BF", "RP"]
        
    results = {}
    if os.path.exists(f"performance_{mode}.pkl"):
        with open(f"performance_{mode}.pkl","rb") as f:
            results = pickle.load(f)

    for env_name in ['hypergrid-mild', 'gmm-hard', 'molecular-generation', 'pusher-simple', 'hypergrid-hard']: # rna14
        model_list = dir_to_names(f"../{input_dir}/{env_name}")

        if mode == "full":
            model_dict, _ = get_model_dict(model_list)
        else:
            _, model_dict = get_model_dict(model_list)

        print(model_dict)

        if env_name in ['gmm-hard', 'pusher-simple']:
            num_samples = 10000
        else:
            num_samples = 100000
        
        # for each combination of (env_name, name), spin up a pool to eval its 5 seeds in parallel
        for name in names:
            print("Processing", env_name, name)
            # skip if results already exists
            if (env_name, name) in results:
                print(f"Results {env_name}, {name} already exists, skipping!")
                continue

            paths = model_dict[name]

            if(len(paths) != 5):
                # raise a warning and stop
                print(f"Warning: {env_name}, {name} does not have 5 seeds, please check!")
                return
            
            args = [(paths[i], i, num_samples, env_name) for i in range(len(paths))]
            
            # Use unified evaluation function for all environments
            out = []
            if env_name in ['gmm-hard', 'pusher-simple']:
                # For continuous environments, run sequentially to avoid memory issues
                for arg in args:
                    out.append(evaluate_seed(arg))
            else:
                # For discrete environments, can run in parallel
                with Pool(processes=5) as pool:
                    out = pool.map(evaluate_seed, args)

            # out is a list of (performances, div) for each seed
            performances_list, divs_list = zip(*out)
            performances = np.array(performances_list)
            divs = np.array(divs_list)

            results[(env_name, name)] = {
            'performances_mean' : performances.mean(axis=0),
            'performances_std' : performances.std(axis=0),
            'div_mean' : divs.mean(axis=0),
            'div_std'  : divs.std(axis=0),
            }

            print(name, results[(env_name,name)])
    

            with open(f"performance_{mode}.pkl","wb") as f:
                pickle.dump(results, f)



if __name__ == "__main__":
    # Command-line entry point: parse the input directory and evaluation mode, then run main().
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--input_dir",
        type=str,
        default="output_full",
        help="Input directory containing model results",
    )
    cli.add_argument(
        "--mode",
        type=str,
        default="full",
        choices=["full", "ablation"],
        help="Mode: 'full' for main results, 'ablation' for ablation study",
    )
    parsed = cli.parse_args()
    main(parsed.input_dir, parsed.mode)
