import torch
import os
import pickle
import time
from pathlib import Path

import torch.distributed as dist
# Opt-in switches, read once from the environment at import time.
# Each one is None when the variable is unset, otherwise the raw string value.
# Tested for truthiness by _get_global_best_samples (an empty string counts as off).
DISTRL_DIST_FILESYSTEM_GATHER = os.getenv("DISTRL_DIST_FILESYSTEM_GATHER")
# Tested with `is not None`, so any value (even an empty string) switches it on.
DISTRL_DEBUG_LOG_BEST_SAMPLES = os.getenv("DISTRL_DEBUG_LOG_BEST_SAMPLES")
# Tested for truthiness before the redistribution/transfer plans are printed.
DISTRL_DEBUG_DETAIL_LOG = os.getenv("DISTRL_DEBUG_DETAIL_LOG")

def _get_global_best_samples(args, state_dict, accelerator, count, is_ddp=True):
    """Get the globally best samples across all machines based on rewards.

    This function:
    1. Each process identifies its local top rewards
    2. Synchronize to find global threshold based on top N rewards
    3. Filter local samples to keep only those above threshold
    4. Redistribute filtered samples evenly across processes

    Args:
        args: Command line arguments
        state_dict: Dictionary containing rollout data including rewards
        accelerator: Accelerator for distributed training
        count: Current training iteration count
        is_ddp: Whether using DistributedDataParallel (multi-GPU) or not

    Returns:
        Original state_dict modified in-place to contain only high-reward samples
    """
    # Extract rewards from state_dict
    rewards = state_dict.get("final_reward", None)

    if rewards is None or len(rewards) == 0:
        accelerator.print("No rewards found in state_dict. Skipping sample selection.")
        return state_dict

    # 1. Get local rewards tensor
    local_rewards_tensor = rewards

    # 2. Gather all rewards from all processes to find global threshold
    if is_ddp and accelerator.num_processes > 1:
        # Ensure tensor is on the correct device before gathering
        local_rewards_tensor = local_rewards_tensor.to(accelerator.device)
        gathered_rewards = accelerator.gather(local_rewards_tensor)
        # Move to CPU after gathering
        gathered_rewards = gathered_rewards.cpu()
    else:
        # In single GPU mode, just use local rewards
        gathered_rewards = local_rewards_tensor.cpu()

    # Get unique values from gathered rewards
    gathered_rewards_unique = torch.tensor(sorted(set(gathered_rewards.tolist())), device=gathered_rewards.device)

    # Find the top N rewards to create threshold
    num_to_keep = min(args.num_best_samples, len(gathered_rewards))
    sorted_rewards, _ = torch.sort(gathered_rewards_unique, descending=True)

    # Set threshold as the Nth highest reward
    if num_to_keep > 0 and len(sorted_rewards) > 0:
        threshold = sorted_rewards[min(num_to_keep - 1, len(sorted_rewards) - 1)].item()
    else:
        # If no samples to keep, set an impossibly high threshold
        threshold = float('inf')

    # Log threshold information
    accelerator.print(f"Global reward threshold: {threshold}")

    # 3. Filter local samples based on global threshold
    keep_indices = []
    # Move local rewards to CPU for filtering
    local_rewards_tensor_cpu = local_rewards_tensor.cpu()
    for i, reward in enumerate(local_rewards_tensor_cpu):
        if reward >= threshold:
            keep_indices.append(i)

    # Log how many samples this process is keeping
    print(f"Process {accelerator.process_index}: Keeping {len(keep_indices)} samples")

    # Store filtered state dict items temporarily
    filtered_state_dict = {}

    # Extract all keys from the state_dict and keep only filtered indices
    for key in state_dict:
        if isinstance(state_dict[key], list):
            filtered_state_dict[key] = [state_dict[key][i] for i in keep_indices]
        elif isinstance(state_dict[key], torch.Tensor) and state_dict[key].dim() > 0 and state_dict[key].size(0) > 0:
            filtered_state_dict[key] = state_dict[key][keep_indices]
        else:
            # For any other type, keep as is
            filtered_state_dict[key] = state_dict[key]

    # 4. Redistribute samples evenly across processes
    # For single GPU case, we can just update the state dict and return
    if not is_ddp or accelerator.num_processes == 1:
        # Update state_dict in-place with filtered values
        for key in filtered_state_dict:
            state_dict[key] = filtered_state_dict[key]
        return state_dict

    # For multi-GPU, check if we should use filesystem-based sync
    if DISTRL_DIST_FILESYSTEM_GATHER:
        # Create directory for storing counts and state dicts
        dist_dir = Path(args.output_dir) / "dist"
        dist_dir.mkdir(exist_ok=True, parents=True)

        # First, save the local count to a file
        local_count = len(keep_indices)
        count_file = dist_dir / f"count_proc_{accelerator.process_index}.txt"
        with open(count_file, "w") as f:
            f.write(str(local_count))

        # Wait for all processes to save their count files
        all_files_present = False
        max_wait_time = 60  # seconds
        start_time = time.time()

        while not all_files_present and (time.time() - start_time < max_wait_time):
            all_files_present = True
            for i in range(accelerator.num_processes):
                if not (dist_dir / f"count_proc_{i}.txt").exists():
                    all_files_present = False
                    time.sleep(0.1)
                    break

        if not all_files_present:
            print(f"Warning: Not all count files were created within {max_wait_time} seconds. Proceeding anyway.")

        # Read all count files to get all process counts
        all_counts = []
        for i in range(accelerator.num_processes):
            count_file = dist_dir / f"count_proc_{i}.txt"
            if count_file.exists():
                with open(count_file, "r") as f:
                    try:
                        count = int(f.read().strip())
                        all_counts.append(count)
                    except (ValueError, IOError) as e:
                        print(f"Error reading count from process {i}: {e}")
                        all_counts.append(0)
            else:
                all_counts.append(0)

        # Calculate total samples
        total_samples = sum(all_counts)
        current_count = local_count

        # Log total samples kept across all processes (only main process logs to wandb)
        print(f"Process {accelerator.process_index}: Total samples kept across all processes: {total_samples}")

        # If no samples kept, return empty state_dict
        if total_samples == 0:
            accelerator.print("No samples above threshold. Skipping rejection sampling.")
            for key in list(state_dict.keys()):
                if isinstance(state_dict[key], list):
                    state_dict[key] = []
                elif isinstance(state_dict[key], torch.Tensor) and state_dict[key].dim() > 0:
                    state_dict[key] = state_dict[key].new_empty((0,) + state_dict[key].shape[1:])
            return state_dict

        # Calculate target count per process for even distribution
        samples_per_process = total_samples // accelerator.num_processes
        remainder = total_samples % accelerator.num_processes

        # Calculate target count for current process (add 1 if in the first 'remainder' processes)
        target_count = samples_per_process + (1 if accelerator.process_index < remainder else 0)

        # Get source processes (those with too many samples) and destination processes (those with too few)
        source_procs = []
        dest_procs = []
        for i, count in enumerate(all_counts):
            proc_target = samples_per_process + (1 if i < remainder else 0)
            if count > proc_target:
                source_procs.append((i, count - proc_target))  # (process_id, samples to give)
            elif count < proc_target:
                dest_procs.append((i, proc_target - count))  # (process_id, samples to receive)

        # Log redistribution plan
        if accelerator.is_main_process and DISTRL_DEBUG_DETAIL_LOG:
            plan_str = "Redistribution plan:\n"
            for src, count in source_procs:
                plan_str += f"  Process {src} gives {count} samples\n"
            for dst, count in dest_procs:
                plan_str += f"  Process {dst} receives {count} samples\n"
            accelerator.print(plan_str)

        # Save filtered state_dict to file
        proc_file = dist_dir / f"samples_proc_{accelerator.process_index}.pkl"
        with open(proc_file, "wb") as f:
            # Create serializable version of filtered_state_dict
            serializable_dict = {}
            for key, value in filtered_state_dict.items():
                if isinstance(value, torch.Tensor):
                    # Convert tensors to CPU for pickling
                    serializable_dict[key] = value.cpu()
                else:
                    serializable_dict[key] = value

            # Save to file
            pickle.dump((serializable_dict, current_count), f)

        # Wait for all processes to save their files
        all_files_present = False
        start_time = time.time()

        while not all_files_present and (time.time() - start_time < max_wait_time):
            all_files_present = True
            for i in range(accelerator.num_processes):
                if not (dist_dir / f"samples_proc_{i}.pkl").exists():
                    all_files_present = False
                    time.sleep(0.1)
                    break

        if not all_files_present:
            print(f"Warning: Not all sample files were created within {max_wait_time} seconds. Proceeding anyway.")

        accelerator.wait_for_everyone()

        # Each process reads metadata from all files (this is redundant but we already have the counts)
        all_files_info = []
        for proc_idx in range(accelerator.num_processes):
            proc_file = dist_dir / f"samples_proc_{proc_idx}.pkl"
            if proc_file.exists():
                with open(proc_file, "rb") as f:
                    _, proc_sample_count = pickle.load(f)
                    all_files_info.append((proc_idx, proc_sample_count))

        # Calculate assignments (each process runs the same deterministic algorithm)
        assignments = []
        sample_idx = 0

        for proc_idx in range(accelerator.num_processes):
            proc_target = samples_per_process + (1 if proc_idx < remainder else 0)
            proc_assignments = []

            # Assign samples from each source process
            samples_needed = proc_target
            for src_proc, src_count in all_files_info:
                src_start = 0
                # Skip samples already assigned to previous processes
                for p in range(proc_idx):
                    prev_assignments = assignments[p] if p < len(assignments) else []
                    for _, _, _, src_samples in prev_assignments:
                        if src_samples[0] == src_proc:
                            src_start = max(src_start, src_samples[1] + src_samples[2])

                # How many samples can we take from this source
                available = src_count - src_start
                if available <= 0:
                    continue

                take_count = min(samples_needed, available)
                if take_count > 0:
                    proc_assignments.append((
                        sample_idx,           # Target index in final array
                        src_proc,             # Source process
                        take_count,           # Number of samples
                        (src_proc, src_start, take_count)  # Source details (proc, start_idx, count)
                    ))
                    sample_idx += take_count
                    samples_needed -= take_count

                if samples_needed <= 0:
                    break

            assignments.append(proc_assignments)

        # Get assignments for this process
        my_assignments = assignments[accelerator.process_index]

        # Create new state dict with the correct structure but empty
        redistributed_dict = {}
        for key in filtered_state_dict:
            if isinstance(filtered_state_dict[key], list):
                redistributed_dict[key] = []
            elif isinstance(filtered_state_dict[key], torch.Tensor) and filtered_state_dict[key].dim() > 0:
                redistributed_dict[key] = []
            else:
                redistributed_dict[key] = filtered_state_dict[key]

        # Fill the redistributed dict with assigned samples
        for _, src_proc, count, (_, src_start, _) in my_assignments:
            src_file = dist_dir / f"samples_proc_{src_proc}.pkl"
            with open(src_file, "rb") as f:
                src_dict, _ = pickle.load(f)

                # Extract samples from source dict
                for key in redistributed_dict:
                    if key in src_dict:
                        if isinstance(src_dict[key], list):
                            redistributed_dict.setdefault(key, []).extend(
                                src_dict[key][src_start:src_start+count]
                            )
                        elif isinstance(src_dict[key], torch.Tensor) and src_dict[key].dim() > 0:
                            tensor_samples = src_dict[key][src_start:src_start+count]
                            if key not in redistributed_dict or not isinstance(redistributed_dict[key], list):
                                redistributed_dict[key] = []
                            redistributed_dict[key].append(tensor_samples)

        # Convert lists of tensors back to single tensors
        for key in redistributed_dict:
            if isinstance(redistributed_dict[key], list) and all(isinstance(item, torch.Tensor) for item in redistributed_dict[key]):
                try:
                    redistributed_dict[key] = torch.cat(redistributed_dict[key], dim=0)
                except (RuntimeError, ValueError):
                    # If tensors can't be concatenated, keep as list
                    pass

        # Update state_dict with redistributed values
        for key in redistributed_dict:
            state_dict[key] = redistributed_dict[key]

        accelerator.wait_for_everyone()

        if accelerator.is_main_process:
            for proc_idx in range(accelerator.num_processes):
                try:
                    count_file = dist_dir / f"count_proc_{proc_idx}.txt"
                    if count_file.exists():
                        count_file.unlink()

                    proc_file = dist_dir / f"samples_proc_{proc_idx}.pkl"
                    if proc_file.exists():
                        proc_file.unlink()
                except:
                    pass  # Ignore deletion errors

        print(f"Process {accelerator.process_index}: get {len(state_dict['final_reward'])} samples")

        return state_dict

    if DISTRL_DEBUG_LOG_BEST_SAMPLES is not None and len(keep_indices) > 0:
        save_dir = os.path.join(args.output_dir, "best_samples", f"step_{count}")
        os.makedirs(save_dir, exist_ok=True)

        # Log the best samples with rewards and class indices
        with open(os.path.join(save_dir, f"{accelerator.process_index}.txt"), "w") as f:
            for idx in keep_indices:
                reward = state_dict["final_reward"][idx].item()
                class_idxes_str = ""
                if "class_idxes" in state_dict:
                    # Get class indices for this sample and convert to string
                    # Multiple time steps are merged into one line
                    if isinstance(state_dict["class_idxes"], list):
                        class_idxes_str = str(state_dict["class_idxes"][idx])
                    else:
                        class_idxes_str = str(state_dict["class_idxes"][idx].tolist())
                f.write(f"{reward}, {class_idxes_str}\n")

    # Original in-memory approach for non-filesystem gathering
    # First, get counts from all processes
    local_count = torch.tensor([len(keep_indices)], device=accelerator.device)

    # Handle the gathering of counts based on distributed or single GPU setup
    if is_ddp and accelerator.num_processes > 1:
        # For multi-GPU: Gather counts from all processes to main process
        gathered_counts = accelerator.gather(local_count)
        all_counts = gathered_counts.tolist()
    else:
        # For single GPU: Just use the local count
        all_counts = [local_count.item()]

    total_samples = sum(all_counts)

    # If no samples kept, return empty state_dict
    if total_samples == 0:
        accelerator.print("No samples above threshold. Skipping rejection sampling.")
        for key in list(state_dict.keys()):
            if isinstance(state_dict[key], list):
                state_dict[key] = []
            elif isinstance(state_dict[key], torch.Tensor) and state_dict[key].dim() > 0:
                state_dict[key] = state_dict[key].new_empty((0,) + state_dict[key].shape[1:])
        return state_dict

    # For multi-GPU case, calculate target count per process for even distribution
    samples_per_process = total_samples // accelerator.num_processes
    remainder = total_samples % accelerator.num_processes

    # Calculate target count for current process (add 1 if in the first 'remainder' processes)
    target_count = samples_per_process + (1 if accelerator.process_index < remainder else 0)
    current_count = len(keep_indices)

    # Get source processes (those with too many samples) and destination processes (those with too few)
    source_procs = []
    dest_procs = []
    for i, count in enumerate(all_counts):
        proc_target = samples_per_process + (1 if i < remainder else 0)
        if count > proc_target:
            source_procs.append((i, count - proc_target))  # (process_id, samples to give)
        elif count < proc_target:
            dest_procs.append((i, proc_target - count))  # (process_id, samples to receive)

    # Log redistribution plan if this is the main process
    if accelerator.is_main_process and DISTRL_DEBUG_DETAIL_LOG:
        plan_str = "Redistribution plan:\n"
        for src, count in source_procs:
            plan_str += f"  Process {src} gives {count} samples\n"
        for dst, count in dest_procs:
            plan_str += f"  Process {dst} receives {count} samples\n"
        accelerator.print(plan_str)

    # Wait for all processes to reach this point
    accelerator.wait_for_everyone()

    # If this process has more samples than target, prepare to send some samples
    send_samples = {}
    recv_samples = {}

    # Step 1: Process will identify which samples to send if it's a source
    if current_count > target_count:
        # Keep track of which indices to keep and which to send
        samples_to_keep = list(range(target_count))
        samples_to_send = list(range(target_count, current_count))

        # Update state dict in-place to keep only target samples
        for key in state_dict:
            if isinstance(state_dict[key], list) and key in filtered_state_dict:
                # Prepare samples to send
                send_samples[key] = [filtered_state_dict[key][i] for i in samples_to_send]
                # Keep only target samples
                state_dict[key] = [filtered_state_dict[key][i] for i in samples_to_keep]
            elif isinstance(state_dict[key], torch.Tensor) and key in filtered_state_dict and state_dict[key].dim() > 0 and state_dict[key].size(0) > 0:
                # Prepare samples to send
                send_samples[key] = filtered_state_dict[key][samples_to_send].cpu()
                # Keep only target samples
                state_dict[key] = filtered_state_dict[key][samples_to_keep]
            elif key in filtered_state_dict:
                state_dict[key] = filtered_state_dict[key]

        print(f"Process {accelerator.process_index}: Keeping {len(samples_to_keep)} samples, redistributing {current_count - target_count}")
    else:
        # Update state dict in-place with all filtered values
        for key in filtered_state_dict:
            state_dict[key] = filtered_state_dict[key]

    # Step 2: Define a deterministic mapping of which sources send to which destinations
    # Calculate assignment matrix: which source sends how many samples to which destination
    assignments = []  # List of (src_proc, dst_proc, num_samples)

    # Sort processes to ensure deterministic assignments
    source_procs = sorted(source_procs)
    dest_procs = sorted(dest_procs)

    # Create assignments
    src_idx = 0
    for dst_proc, dst_need in dest_procs:
        remaining_need = dst_need
        while remaining_need > 0 and src_idx < len(source_procs):
            src_proc, src_avail = source_procs[src_idx]
            samples_to_transfer = min(remaining_need, src_avail)

            if samples_to_transfer > 0:
                assignments.append((src_proc, dst_proc, samples_to_transfer))
                remaining_need -= samples_to_transfer
                source_procs[src_idx] = (src_proc, src_avail - samples_to_transfer)

            if source_procs[src_idx][1] == 0:
                src_idx += 1

    # Print assignment plan
    if accelerator.is_main_process and DISTRL_DEBUG_DETAIL_LOG:
        plan_str = "Sample transfer plan:\n"
        for src, dst, count in assignments:
            plan_str += f"  Process {src} sends {count} samples to Process {dst}\n"
        accelerator.print(plan_str)

    # Wait for all processes to compute the same assignment plan
    accelerator.wait_for_everyone()

    # Step 3: Execute transfers based on assignments
    for src_proc, dst_proc, num_samples in assignments:
        # If I'm the source process in this assignment
        if accelerator.process_index == src_proc:
            # Pack all send_samples into a single object for transfer
            # This assumes all processes have the same keys in their state_dict
            transfer_package = {}
            for key in send_samples:
                if isinstance(send_samples[key], list):
                    # Take only the number of samples we need to send
                    transfer_package[key] = send_samples[key][:num_samples]
                    # Remove sent samples from our send buffer
                    send_samples[key] = send_samples[key][num_samples:]
                elif isinstance(send_samples[key], torch.Tensor) and send_samples[key].dim() > 0:
                    # Take only the number of samples we need to send
                    transfer_package[key] = send_samples[key][:num_samples]
                    # Remove sent samples from our send buffer
                    send_samples[key] = send_samples[key][num_samples:]

            # Convert to bytes and get size
            transfer_bytes = pickle.dumps(transfer_package)
            size = torch.tensor([len(transfer_bytes)], device=accelerator.device)

            # Send size first
            dist.send(size, dst=dst_proc)

            # Send actual data
            transfer_tensor = torch.ByteTensor(list(transfer_bytes)).to(accelerator.device)
            dist.send(transfer_tensor, dst=dst_proc)

            print(f"Process {src_proc}: Sent {num_samples} samples to Process {dst_proc}")

        # If I'm the destination process in this assignment
        elif accelerator.process_index == dst_proc:
            # Receive size first
            size = torch.tensor([0], device=accelerator.device)
            dist.recv(size, src=src_proc)

            # Prepare buffer and receive data
            transfer_tensor = torch.empty(size.item(), dtype=torch.uint8, device=accelerator.device)
            dist.recv(transfer_tensor, src=src_proc)

            # Convert back to dictionary
            transfer_bytes = bytes(transfer_tensor.cpu().numpy().tolist())
            received_package = pickle.loads(transfer_bytes)

            # Store received samples
            for key, value in received_package.items():
                if key not in recv_samples:
                    if isinstance(value, list):
                        recv_samples[key] = []
                    elif isinstance(value, torch.Tensor):
                        recv_samples[key] = []

                if isinstance(value, list):
                    recv_samples[key].extend(value)
                elif isinstance(value, torch.Tensor):
                    recv_samples[key].append(value)

            print(f"Process {dst_proc}: Received {num_samples} samples from Process {src_proc}")

    # Step 4: Update state_dict with received samples
    if accelerator.process_index in [dst for dst, _ in dest_procs]:
        for key in recv_samples:
            if key in state_dict:
                if isinstance(state_dict[key], list):
                    # For lists, simply extend
                    state_dict[key].extend(recv_samples[key])
                elif isinstance(state_dict[key], torch.Tensor) and state_dict[key].dim() > 0:
                    # For tensors, concatenate
                    if isinstance(recv_samples[key], list) and all(isinstance(t, torch.Tensor) for t in recv_samples[key]):
                        if len(recv_samples[key]) > 0:
                            all_tensors = [state_dict[key]] + [t.to(state_dict[key].device) for t in recv_samples[key]]
                            state_dict[key] = torch.cat(all_tensors, dim=0)

    # Wait for all processes to finish redistributing
    accelerator.wait_for_everyone()

    print(f"Process {accelerator.process_index}: get {len(state_dict['final_reward'])} samples finally, min reward: {min(state_dict['final_reward'])}, max reward: {max(state_dict['final_reward'])}")

    # Return the modified state_dict
    return state_dict

def _get_local_best_samples(args, state_dict, accelerator, count, is_ddp=True):
    """Get the locally best samples based on rewards.

    This function:
    1. Identifies local top rewards
    2. Filters local samples to keep only those above threshold

    Args:
        args: Command line arguments
        state_dict: Dictionary containing rollout data including rewards
        accelerator: Accelerator for distributed training
        count: Current training iteration count
        is_ddp: Whether using DistributedDataParallel (multi-GPU) or not (not used in this function)

    Returns:
        Original state_dict modified in-place to contain only high-reward samples
    """
    # Extract rewards from state_dict
    rewards = state_dict.get("final_reward", None)

    if rewards is None or len(rewards) == 0:
        accelerator.print("No rewards found in state_dict. Skipping sample selection.")
        return state_dict

    # Get local rewards tensor and get unique values
    local_rewards_tensor = rewards
    local_rewards_unique = torch.tensor(sorted(set(local_rewards_tensor.tolist())), device=local_rewards_tensor.device)

    # Find the top N rewards to create threshold
    num_to_keep = min(args.num_best_samples, len(local_rewards_unique))

    if num_to_keep <= 0:
        # If no samples to keep, return empty state_dict
        accelerator.print("No samples to keep. Skipping sample selection.")
        for key in list(state_dict.keys()):
            if isinstance(state_dict[key], list):
                state_dict[key] = []
            elif isinstance(state_dict[key], torch.Tensor) and state_dict[key].dim() > 0:
                state_dict[key] = state_dict[key].new_empty((0,) + state_dict[key].shape[1:])
        return state_dict

    # Sort unique rewards to find threshold
    sorted_rewards, _ = torch.sort(local_rewards_unique, descending=True)

    # Set threshold as the Nth highest reward
    threshold = sorted_rewards[min(num_to_keep - 1, len(sorted_rewards) - 1)].item()

    # Log threshold information
    accelerator.print(f"Local reward threshold: {threshold}")

    # Get indices of samples to keep based on threshold
    keep_indices = []
    for i, reward in enumerate(local_rewards_tensor):
        if reward >= threshold:
            keep_indices.append(i)

    # Log how many samples are being kept
    accelerator.print(f"Process {accelerator.process_index}: Keeping {len(keep_indices)} samples")

    # Filter state dict items to keep only selected indices
    for key in list(state_dict.keys()):
        if isinstance(state_dict[key], list):
            state_dict[key] = [state_dict[key][i] for i in keep_indices]
        elif isinstance(state_dict[key], torch.Tensor) and state_dict[key].dim() > 0 and state_dict[key].size(0) > 0:
            state_dict[key] = state_dict[key][keep_indices]

    print(f"Process {accelerator.process_index}: get {len(state_dict['final_reward'])} samples finally")

    return state_dict
