# semi_offline_grpo_correct.py
# Proper Semi-Offline GRPO: K samples per GPU → Aggregate → Train → Update model → Repeat

import os
import numpy as np
import torch
import torch.nn.functional as F
import torch.multiprocessing as mp
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, logging
from transformers import LogitsProcessor

import shutil
import pickle
import random
from datetime import datetime
import argparse
from harness_creator import Harness
import warnings 

# Silence the FutureWarning about the deprecated `torch.cuda.amp.autocast` API
# (matched on the beginning of the warning message).
warnings.filterwarnings("ignore", message="`torch.cuda.amp.autocast", category=FutureWarning)

# ---------------- CONFIGURATION (ALL CAPS) ----------------
# Only show errors from the transformers library logger.
logging.set_verbosity_error()

# PATHS
# SFT checkpoint location.
MODEL_PATH = "./models/search_sft_output/checkpoint-125000" 
# Parent directory of the per-experiment directories (experiment_<name>/).
BASE_CHECKPOINT_DIR = "./models/grpo_checkpoints"

# TRAINING HYPERPARAMETERS
NUM_GEN = 4                                    # Generations sampled per turn; also the group size for advantage normalization
MAX_TURNS = 4                                  # Search turns rolled out per sample
TOP_K = 5                                      # Top-K search results evaluated and fed back into the context
EPS_CLIP = 0.2                                # PPO-style clipping range for the probability ratio
KL_BETA = 0.1                                 # Weight of the KL penalty against the reference model
LR = 1e-6                                     # AdamW learning rate for training
MAX_NEW_TOK = 512                              # Max new tokens per generate() call (one think or one search segment)
MAX_CHECKPOINTS = 10                          # NOTE(review): not referenced anywhere in this file; confirm it is still needed

# OFFLINE GRPO SETTINGS
TRAIN_AFTER_EVERY_K_SAMPLES_PER_GPU = 3       # Samples each rollout worker collects per global step before training
MAX_GLOBAL_STEPS_OF_OFFLINE_GRPO = 10         # Number of collect-then-train global steps
MAX_EXPERIMENT_DIRS = 3                       # Used as ExperimentManager.max_models: newest updated_models/ checkpoints kept by cleanup_old_models (not a limit on experiment directories)


def generate_segment(model, tokenizer, prompt, _stop_token, device, max_new_tokens=None):
    """Sample one continuation of ``prompt`` and return it cut at ``_stop_token``.

    Generation uses temperature-0.7 sampling. The decoded continuation (prompt
    tokens excluded, special tokens skipped) is truncated just before the first
    occurrence of ``_stop_token``. Without this cut, a segment requested for
    ``<think>`` can run on into ``</think>\\n<search_query>...``, and the caller
    then wraps that spill-over in a second pair of tags.

    Args:
        model: causal LM exposing ``generate``.
        tokenizer: tokenizer used for encoding and decoding. If it has no pad
            token, its ``pad_token_id`` is set to ``eos_token_id`` (this mutates
            the tokenizer).
        prompt: text to continue.
        _stop_token: literal closing tag. Everything from its first occurrence
            on is dropped. A falsy value disables truncation.
            NOTE: if the tag is registered as a special token,
            ``skip_special_tokens=True`` removes it before this check, so no
            truncation happens.
        device: device the input tensors are moved to.
        max_new_tokens: generation budget; ``None`` means ``MAX_NEW_TOK``.

    Returns:
        The continuation text without the stop token.
    """
    if max_new_tokens is None:
        max_new_tokens = MAX_NEW_TOK

    # Tokenize once; pass fields explicitly (no **inputs) to avoid duplicates
    tok_out = tokenizer(prompt, return_tensors="pt")
    input_ids = tok_out.input_ids.to(device)
    attention_mask = tok_out.attention_mask.to(device) if "attention_mask" in tok_out else None

    # Ensure PAD/EOS are valid and consistent
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id
    eos_id = tokenizer.eos_token_id

    with torch.no_grad(), torch.amp.autocast(device_type="cuda", enabled=False):
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
            eos_token_id=eos_id,
            pad_token_id=tokenizer.pad_token_id,
            # NOTE(review): use_cache=False recomputes the whole prefix at every
            # decoding step; confirm it is really needed here.
            use_cache=False
        )

    # Decode only the newly generated tokens.
    text = tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)
    if _stop_token:
        # partition() returns the full string as [0] when the tag is absent.
        text = text.partition(_stop_token)[0]
    return text


# ---------------- Experiment Directory Management ----------------
def create_experiment_directory(experiment_name, base_dir=None):
    """Create (if needed) and return the directory for one named experiment.

    The path is ``<base_dir>/experiment_<experiment_name>``. No timestamp is
    added, so runs that reuse a name share (and may overwrite files in) the
    same directory. Nothing is deleted here.

    Args:
        experiment_name: suffix appended to ``experiment_``.
        base_dir: parent directory; ``None`` means ``BASE_CHECKPOINT_DIR``.

    Returns:
        The experiment directory path.
    """
    if base_dir is None:
        base_dir = BASE_CHECKPOINT_DIR
    experiment_dir = os.path.join(base_dir, f"experiment_{experiment_name}")
    # makedirs also creates base_dir, so no separate call is needed.
    os.makedirs(experiment_dir, exist_ok=True)
    return experiment_dir


def _search_reward(eval_result, corpus_size):
    """Reward one search as 0.5 * similarity term + 0.5 * rank term.

    Similarity term: ``best_score`` (default 0.0) clamped below at 0.
    Rank term: ``1 - best_rank / corpus_size`` when a rank is reported
    (``best_rank != -1``) and the corpus is non-empty, else 0.0.
    """
    best_score = float(eval_result.get('best_score', 0.0))
    # Clamp instead of rescaling only negative scores: mapping just the
    # negatives through (s + 1) / 2 would score s = -0.01 (~0.5) above s = 0.3.
    score_norm = max(best_score, 0.0)
    best_rank = int(eval_result.get('best_rank', -1))
    if best_rank != -1 and corpus_size > 0:
        rank_norm = 1.0 - (best_rank / corpus_size)
    else:
        rank_norm = 0.0
    return 0.5 * score_norm + 0.5 * rank_norm


def multi_turn_inference_single_sample_harness(sample, model, tokenizer, harness, device):
    """Roll out MAX_TURNS search turns for one query, NUM_GEN candidates per turn.

    Each turn samples NUM_GEN (think, search query) pairs from the current
    context and scores each query with ``harness.evaluate_search``. The
    context is then extended with the highest-reward candidate and its
    top-K results.

    Args:
        sample: dict with 'query_text' and 'target_docs'; 'query_id' is optional.
        model, tokenizer, device: forwarded to ``generate_segment``.
        harness: object exposing ``corpus`` (sized) and ``evaluate_search``.

    Returns:
        dict with per-turn 'rewards' / 'generations' / 'contexts' lists
        (MAX_TURNS lists of NUM_GEN entries each), plus:
        - 'target_found': whether any candidate hit the top-K;
        - 'success': whether a chosen best candidate hit the top-K;
        - metadata, session info, per-turn history and the turn count.
    """
    query_text = sample['query_text']
    target_docs = sample['target_docs']

    context = f"<user_query>{query_text}</user_query>"
    # The corpus does not change during a rollout; read its size once.
    corpus_size = len(harness.corpus)

    all_rewards = []
    all_generations = []
    all_contexts = []
    sample_found_target = False
    harness_history = []
    harness_session = {
        'query_id': sample.get('query_id', ''),
        'query_text': query_text,
        'target_docs': [str(t) for t in target_docs]
    }

    for turn in range(MAX_TURNS):
        turn_rewards = []
        turn_generations = []
        turn_contexts = []
        turn_thinks = []
        turn_searches = []
        turn_evals = []

        for _ in range(NUM_GEN):
            think = generate_segment(model, tokenizer, context + "\n<think>", "</think>", device)
            search = generate_segment(
                model, tokenizer,
                context + f"\n<think>{think}</think>\n<search_query>",
                "</search_query>", device,
            )
            eval_result = harness.evaluate_search(search, target_docs, k=TOP_K)
            if eval_result['found_in_top_k']:
                sample_found_target = True

            turn_rewards.append(_search_reward(eval_result, corpus_size))
            turn_generations.append(f"<think>{think}</think>\n<search_query>{search}</search_query>")
            # Record the context this candidate was generated from (before extension).
            turn_contexts.append(context)
            turn_thinks.append(think)
            turn_searches.append(search)
            turn_evals.append(eval_result)

        best_idx = int(np.argmax(turn_rewards))
        best_think = turn_thinks[best_idx]
        best_search = turn_searches[best_idx]
        eval_result_best = turn_evals[best_idx]

        topk_txt = f"Here are the Top-{TOP_K} responses by the system:\n" + "\n".join(
            f"{i+1}. \"\"\"Doc: {d['doc_id']}\n{d['text']}\"\"\""
            for i, d in enumerate(eval_result_best.get('results', [])[:TOP_K])
        )

        context += (
            f"\n<think>{best_think}</think>"
            f"\n<search_query>{best_search}</search_query>"
            f"\n<top_k_response>{topk_txt}</top_k_response>"
        )

        harness_history.append({
            'turn': turn + 1,
            'query': best_search,
            'eval_result': eval_result_best
        })

        all_rewards.append(turn_rewards)
        all_generations.append(turn_generations)
        all_contexts.append(turn_contexts)

    return {
        'rewards': all_rewards,
        'generations': all_generations,
        'contexts': all_contexts,
        'target_found': sample_found_target,
        'metadata': {
            'query': query_text
        },
        'session': harness_session,
        'history': harness_history,
        'success': any(t['eval_result']['found_in_top_k'] for t in harness_history),
        'total_turns': len(harness_history)
    }

def gpu_worker_collect_k_samples_harness(gpu_id, samples, k_samples, model_path, results_queue, position_offset, harness_map):
    """Worker-process entry point: roll out random samples and report back once.

    Steps:
    1. Load the tokenizer and a bf16 model from ``model_path`` onto
       ``cuda:<gpu_id>`` (CPU when no GPU is visible).
    2. Draw ``min(k_samples, len(samples))`` samples without replacement.
    3. Run the multi-turn rollout on each, using the harness registered for
       the sample's 'dataset' (falling back to the first harness in
       ``harness_map``).
    4. Put exactly one ``(gpu_id, results, success_count)`` tuple on
       ``results_queue``.

    The put happens in ``finally``. If the worker crashes (OOM, bad
    checkpoint, ...) it still reports whatever it gathered, so the parent is
    not left blocked on ``results_queue.get()``. The exception then
    propagates, so this process prints the traceback and exits non-zero.
    """
    gpu_results = []
    success_count = 0
    try:
        use_cuda = torch.cuda.device_count() > 0
        device = torch.device(f"cuda:{gpu_id}") if use_cuda else torch.device("cpu")
        if use_cuda:
            torch.cuda.set_device(device)
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16).to(device)
        model.eval()

        # Sample without replacement; never ask for more samples than exist.
        selected_samples = random.sample(samples, min(max(k_samples, 0), len(samples)))
        fallback_harness = next(iter(harness_map.values()), None)

        pbar = tqdm(total=len(selected_samples), desc=f"GPU {gpu_id} (Harness)", position=position_offset + gpu_id, leave=True)
        for done, sample in enumerate(selected_samples, start=1):
            harness = harness_map.get(sample.get('dataset'), fallback_harness)
            with torch.no_grad():
                result = multi_turn_inference_single_sample_harness(sample, model, tokenizer, harness, device)
            gpu_results.append(result)
            if result['target_found']:
                success_count += 1
            success_rate = (success_count / done) * 100
            pbar.set_description(f"GPU {gpu_id} (Harness): Success: {success_rate:.0f}% ({success_count}/{done})")
            pbar.update(1)
        pbar.close()
    finally:
        results_queue.put((gpu_id, gpu_results, success_count))

def collect_k_samples_per_gpu_harness(samples, num_gpus, k_samples, current_model_path, harness_map):
    """Run one rollout worker per GPU (at least one) and gather their results.

    Workers are started with the 'spawn' context. Each one reports a single
    ``(gpu_id, results, success_count)`` tuple. The queue is polled with a
    timeout, so a worker that dies without reporting raises here instead of
    blocking this process forever.

    Returns:
        (all rollout results, inference success rate in percent of the
        expected sample count).

    Raises:
        RuntimeError: a worker exited without reporting, or the returned
            sample / trajectory counts differ from what was requested.
    """
    import queue  # for queue.Empty, raised by Queue.get(timeout=...)

    poll_seconds = 10.0
    num_workers = max(1, num_gpus)
    ctx = mp.get_context('spawn')
    results_queue = ctx.Queue()
    position_offset = 1

    processes = {}
    for gpu_id in range(num_workers):
        p = ctx.Process(
            target=gpu_worker_collect_k_samples_harness,
            args=(gpu_id, samples, k_samples, current_model_path, results_queue, position_offset, harness_map)
        )
        p.start()
        processes[gpu_id] = p

    all_gpu_results = []
    total_success_count = 0
    pending = set(processes)
    # Drain results before join(): a child that has put data on a queue may not
    # exit until that data is consumed, so joining first could deadlock.
    while pending:
        try:
            item = results_queue.get(timeout=poll_seconds)
        except queue.Empty:
            dead = sorted(g for g in pending if not processes[g].is_alive())
            if not dead:
                continue  # workers still busy
            # A worker is gone without having reported. Poll once more in case
            # its result was flushed just before it exited, then give up.
            try:
                item = results_queue.get(timeout=poll_seconds)
            except queue.Empty:
                for p in processes.values():
                    if p.is_alive():
                        p.terminate()
                exit_codes = {g: processes[g].exitcode for g in dead}
                raise RuntimeError(
                    f"Rollout worker(s) exited without reporting results (exit codes: {exit_codes})"
                ) from None
        gpu_id, gpu_results, success_count = item
        pending.discard(gpu_id)
        all_gpu_results.extend(gpu_results)
        total_success_count += success_count

    for p in processes.values():
        p.join()

    expected_samples = num_workers * k_samples
    if len(all_gpu_results) != expected_samples:
        raise RuntimeError(
            f"Sample count mismatch! Expected: {expected_samples}, Got: {len(all_gpu_results)}")
    # Count trajectories from the data itself (turns x generations per sample).
    expected_trajectories = expected_samples * MAX_TURNS * NUM_GEN
    actual_trajectories = sum(len(turn) for r in all_gpu_results for turn in r['rewards'])
    if actual_trajectories != expected_trajectories:
        raise RuntimeError(
            f"Trajectory count mismatch! Expected: {expected_trajectories}, Got: {actual_trajectories}")

    inference_success_rate = (total_success_count / expected_samples) * 100 if expected_samples else 0.0
    return all_gpu_results, inference_success_rate


def logprob(model, prefix_ids, completion_ids):
    """Return the summed log-probability ``model`` assigns to ``completion_ids`` after ``prefix_ids``.

    The id tensors are concatenated along dim 1 and scored in one forward
    pass. The result is a scalar summed over every completion position (and
    over the batch dimension). Autocast is turned off and logits are upcast
    to float32 before the softmax.
    """
    n_prefix = prefix_ids.size(1)
    full_sequence = torch.cat((prefix_ids, completion_ids), dim=1)
    with torch.amp.autocast(device_type="cuda", enabled=False):
        all_logits = model(full_sequence).logits.float()
        # Logits at position t predict token t+1, so keep the rows that
        # predict the completion tokens.
        predicting_logits = all_logits[:, n_prefix - 1:-1, :]
    token_log_probs = F.log_softmax(predicting_logits, dim=-1)
    picked = torch.gather(token_log_probs, 2, completion_ids.unsqueeze(-1))
    return picked.squeeze(-1).sum()

def _group_trajectories(collected_samples):
    """Flatten rollouts into per-generation trajectories with group-normalized advantages.

    The generations of one turn form one group:
    advantage = (reward - group mean) / (group std + 1e-8), where std is
    torch's default (unbiased) estimator. Groups with fewer than two entries
    get zero advantage, because their std is undefined (NaN).

    Returns:
        (trajectories, advantages, rewards) as parallel lists. Each trajectory
        is a dict with 'context', 'generation' and 'reward'.
    """
    trajectories, advantages, rewards_flat = [], [], []
    for sample_result in collected_samples:
        for turn_rewards, turn_generations, turn_contexts in zip(
            sample_result['rewards'], sample_result['generations'], sample_result['contexts']
        ):
            group = torch.tensor(turn_rewards, dtype=torch.float32)
            if group.numel() > 1:
                group_adv = (group - group.mean()) / (group.std() + 1e-8)
            else:
                group_adv = torch.zeros_like(group)
            for reward, generation, context, adv in zip(
                turn_rewards, turn_generations, turn_contexts, group_adv.tolist()
            ):
                trajectories.append({
                    'context': context,
                    'generation': generation,
                    'reward': reward
                })
                advantages.append(adv)
                rewards_flat.append(reward)
    return trajectories, advantages, rewards_flat


def _grpo_sequence_loss(current_logprob, old_logprob, advantage):
    """Return the clipped-ratio policy loss plus KL penalty for one sequence.

    Both log-probs are summed over the sequence, so ratio and KL are
    sequence-level. The log-ratio d = cur - old is clamped to [-20, 20].
    - Policy term: -min(r * A, clip(r, 1 - EPS_CLIP, 1 + EPS_CLIP) * A), with
      r = exp(d).
    - KL term: KL_BETA * (exp(-d) + d - 1), the non-negative "k3" estimator
      used by GRPO. It is 0 when the policy matches the reference and grows as
      they drift apart. A linear term KL_BETA * (old - cur) would instead be
      unbounded below and raise every sampled sequence's likelihood
      regardless of its advantage.
    """
    log_ratio = (current_logprob - old_logprob).clamp(-20, 20)
    ratio = torch.exp(log_ratio)
    clipped_ratio = torch.clamp(ratio, 1 - EPS_CLIP, 1 + EPS_CLIP)
    policy_loss = -torch.min(ratio * advantage, clipped_ratio * advantage)
    kl_penalty = KL_BETA * (torch.exp(-log_ratio) + log_ratio - 1)
    return policy_loss + kl_penalty


def _save_model_and_tokenizer(model, tokenizer, output_model_path):
    """Save the model and tokenizer after making the tokenizer config JSON-serializable.

    Drops 'torch_dtype' from ``tokenizer.init_kwargs``, converts numpy
    scalars (recursively, inside lists/tuples/dicts) to Python values, and
    casts numpy eos/pad ids to int.
    """
    parent = os.path.dirname(output_model_path)
    if parent:  # os.makedirs("") would raise
        os.makedirs(parent, exist_ok=True)
    model.save_pretrained(output_model_path)

    def to_py(v):
        if isinstance(v, np.generic):
            return v.item()
        if isinstance(v, (list, tuple)):
            return [to_py(x) for x in v]
        if isinstance(v, dict):
            return {k: to_py(x) for k, x in v.items()}
        return v

    init_kwargs = getattr(tokenizer, "init_kwargs", None)
    if isinstance(init_kwargs, dict):
        init_kwargs.pop("torch_dtype", None)
        tokenizer.init_kwargs = to_py(init_kwargs)

    # Ensure pad/eos ids are plain ints
    for attr in ("eos_token_id", "pad_token_id"):
        val = getattr(tokenizer, attr, None)
        if isinstance(val, np.generic):
            setattr(tokenizer, attr, int(val))

    tokenizer.save_pretrained(output_model_path)


def train_one_epoch_on_collected_data(collected_samples, current_model_path, output_model_path, device="cuda:0"):
    """Train exactly one epoch on collected rollouts and save the updated model.

    Steps:
    1. Load the policy and a frozen reference copy from ``current_model_path``.
       The reference also serves as the "old" policy for the ratio.
    2. Flatten the rollouts into trajectories with per-turn group-normalized
       advantages.
    3. Take one AdamW step per trajectory, with gradient norm clipped to 1.0.
    4. Save model and tokenizer to ``output_model_path``, then release the
       models' GPU memory.

    Args:
        collected_samples: rollout dicts whose 'rewards', 'generations' and
            'contexts' are nested per turn, then per generation.
        current_model_path: checkpoint to start from.
        output_model_path: directory the updated model is saved into.
        device: torch device string used for training.

    Returns:
        dict with 'avg_loss', 'success_rate' (% of trajectories with
        reward > 0.5), 'total_trajectories', 'total_samples' and 'avg_reward'.

    Raises:
        RuntimeError: the trajectory count differs from
            len(collected_samples) * MAX_TURNS * NUM_GEN.
    """
    # Load models FRESH from the provided path
    tokenizer = AutoTokenizer.from_pretrained(current_model_path)
    model = AutoModelForCausalLM.from_pretrained(current_model_path, torch_dtype=torch.bfloat16).to(device)
    ref_model = AutoModelForCausalLM.from_pretrained(current_model_path, torch_dtype=torch.bfloat16).to(device)
    model.config.use_cache = False
    model.gradient_checkpointing_enable()
    # The reference model is only ever run under no_grad: freeze it.
    ref_model.config.use_cache = False
    ref_model.requires_grad_(False)
    ref_model.eval()

    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

    all_trajectories, all_advantages, global_rewards = _group_trajectories(collected_samples)
    total_trajectories = len(all_trajectories)
    expected_trajectories = len(collected_samples) * MAX_TURNS * NUM_GEN
    if total_trajectories != expected_trajectories:
        raise RuntimeError(
            f"Training trajectory count mismatch! Expected: {expected_trajectories}, Got: {total_trajectories}")

    # Training loop - ONE EPOCH
    model.train()
    total_loss = 0.0
    num_batches = 0
    train_pbar = tqdm(total=total_trajectories, desc="Training", position=0, leave=True)

    for trajectory, advantage in zip(all_trajectories, all_advantages):
        context = trajectory['context']
        full_text = context + "\n" + trajectory['generation']
        n_context = tokenizer(context, return_tensors="pt").input_ids.size(1)
        full_ids = tokenizer(full_text, return_tensors="pt").input_ids.to(device)
        # Split ONE tokenization so prefix + completion is exactly full_ids;
        # tokenizing the context on its own can differ at the boundary.
        context_ids = full_ids[:, :n_context]
        generation_ids = full_ids[:, n_context:]

        # Skip if there is nothing to score (logprob needs a non-empty prefix)
        if generation_ids.size(1) == 0 or context_ids.size(1) == 0:
            train_pbar.update(1)
            continue

        current_logprob = logprob(model, context_ids, generation_ids)
        with torch.no_grad():
            old_logprob = logprob(ref_model, context_ids, generation_ids)

        loss = _grpo_sequence_loss(current_logprob, old_logprob, advantage)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        loss_value = loss.item()
        total_loss += loss_value
        num_batches += 1
        train_pbar.set_description(f"Training: Loss: {loss_value:.4f}")
        train_pbar.update(1)

    train_pbar.close()

    avg_total_loss = total_loss / num_batches if num_batches > 0 else 0.0

    _save_model_and_tokenizer(model, tokenizer, output_model_path)

    # Free this process's GPU memory before the next rollout phase starts
    # workers on the same device(s).
    del model, ref_model, optimizer
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    num_rewards = len(global_rewards)
    success_count = sum(1 for r in global_rewards if r > 0.5)
    if num_rewards:
        success_rate = success_count / num_rewards * 100
        avg_reward = torch.tensor(global_rewards, dtype=torch.float32).mean().item()
    else:
        success_rate = 0.0
        avg_reward = 0.0

    return {
        'avg_loss': avg_total_loss,
        'success_rate': success_rate,
        'total_trajectories': total_trajectories,
        'total_samples': len(collected_samples),
        'avg_reward': avg_reward
    }

# ---------------- Checkpoint Management ----------------
class ExperimentManager:
    """Own the on-disk layout of one experiment directory.

    Layout::

        <experiment_dir>/collected_data/collected_global_step_<N>.pkl
        <experiment_dir>/updated_models/grpo_updated_step_<N>/

    Both subdirectories are created when the manager is constructed.
    """

    def __init__(self, experiment_dir, max_models=3):
        self.experiment_dir = experiment_dir
        # How many of the newest grpo_updated_step_* directories cleanup_old_models keeps.
        self.max_models = max_models
        self.pkl_dir = os.path.join(experiment_dir, "collected_data")
        self.models_dir = os.path.join(experiment_dir, "updated_models")
        for subdir in (self.pkl_dir, self.models_dir):
            os.makedirs(subdir, exist_ok=True)

    def get_pkl_path(self, global_step):
        """Return the pickle path for the data collected at ``global_step``."""
        filename = f"collected_global_step_{global_step}.pkl"
        return os.path.join(self.pkl_dir, filename)

    def get_updated_model_path(self, global_step):
        """Return the directory the model trained at ``global_step`` is saved to."""
        dirname = f"grpo_updated_step_{global_step}"
        return os.path.join(self.models_dir, dirname)

    def cleanup_old_models(self):
        """Delete all but the ``self.max_models`` highest-numbered model directories.

        Only directories named ``grpo_updated_step_<int>`` are candidates.
        Files and other names in ``models_dir`` are left alone. Does nothing
        if ``models_dir`` does not exist.
        """
        if not os.path.exists(self.models_dir):
            return

        prefix = "grpo_updated_step_"
        candidates = []
        for name in os.listdir(self.models_dir):
            path = os.path.join(self.models_dir, name)
            if not (os.path.isdir(path) and name.startswith(prefix)):
                continue
            try:
                step = int(name.replace(prefix, ""))
            except ValueError:
                continue
            candidates.append((step, path))

        # Newest (highest step) first.
        candidates.sort(key=lambda entry: entry[0], reverse=True)
        if len(candidates) <= self.max_models:
            return
        for _, path in candidates[self.max_models:]:
            if os.path.exists(path):
                shutil.rmtree(path)

# ---------------- Main Training Loop ----------------
def _build_samples_for(h, limit):
    """Build shuffled training samples (one per query) from a Harness's qrels.

    qrels are grouped by their first column (query id). The second column of
    each group provides the target doc ids. Query text is looked up in
    ``h.queries`` by string id, then by integer id; queries found under
    neither are skipped.

    Returns at most ``limit`` samples when ``limit`` is a positive int, else all.
    """
    qrels = h.qrels
    queries = h.queries
    samples = []
    for qid, group in qrels.groupby(qrels.columns[0]):
        qid_str = str(qid)
        if qid_str in queries.index:
            qtext = queries.loc[qid_str, 'text']
        else:
            # Some datasets index queries by integer id.
            try:
                qtext = queries.loc[int(qid_str), 'text']
            except (KeyError, ValueError, TypeError):
                continue
        samples.append({
            'query_id': qid_str,
            'query_text': qtext,
            'target_docs': group.iloc[:, 1].astype(str).tolist(),
            'dataset': h.dataset
        })
    random.shuffle(samples)
    return samples[:limit] if (limit and limit > 0) else samples


def main():
    """Main semi-offline GRPO training loop.

    Each global step:
    1. Collect rollouts with the current model in worker processes.
    2. Pickle the rollouts into the experiment directory.
    3. Train one epoch on them.
    4. Continue the next step from the newly saved model.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset", type=str, default="scifact", help="Single dataset name (used if --datasets not set)")
    parser.add_argument("--datasets", type=str, default=None, help="Comma-separated list of datasets to combine from split")
    parser.add_argument("--per_dataset_samples", type=int, default=1000, help="Max samples to draw per dataset")
    parser.add_argument("--split", type=str, default="train", help="Split to use in Harness mode (train/val/test)")
    parser.add_argument("--retriever_model", type=str, default="all-MiniLM-L6-v2", help="SentenceTransformer model for Harness")
    parser.add_argument("--sft_model_path", type=str, default=None, help=f"Path to sft model to use (default: {MODEL_PATH})")
    parser.add_argument("--experiment-name", type=str, required=True, help="Name of the experiment")

    args = parser.parse_args()

    experiment_dir = create_experiment_directory(args.experiment_name)

    num_gpus = torch.cuda.device_count()
    # Rollout workers fall back to CPU when no GPU is visible; train likewise.
    train_device = "cuda:0" if num_gpus > 0 else "cpu"

    # Build combined samples from one or multiple datasets
    if args.datasets:
        ds_list = [d.strip() for d in args.datasets.split(',') if d.strip()]
    else:
        ds_list = [args.dataset]

    harness_samples = []
    harness_map = {}
    for ds in ds_list:
        h = Harness(ds, args.split, args.retriever_model)
        # Key by h.dataset: that is the value stored in each sample's 'dataset'
        # field, which the rollout workers use to find the harness again.
        harness_map[h.dataset] = h
        harness_samples.extend(_build_samples_for(h, args.per_dataset_samples))

    if not harness_samples:
        raise RuntimeError(
            f"No training samples could be built for dataset(s) {ds_list} (split={args.split!r})")

    experiment_manager = ExperimentManager(experiment_dir, MAX_EXPERIMENT_DIRS)

    # Start from the given SFT checkpoint (else MODEL_PATH); each later step
    # continues from the model saved by the previous one.
    current_model_path = args.sft_model_path or MODEL_PATH

    main_pbar = tqdm(
        total=MAX_GLOBAL_STEPS_OF_OFFLINE_GRPO,
        desc="Global GRPO Steps",
        position=0,
        leave=True
    )

    for global_step in range(1, MAX_GLOBAL_STEPS_OF_OFFLINE_GRPO + 1):

        # Phase 1: Collect K samples per worker
        collected_samples, inference_success_rate = collect_k_samples_per_gpu_harness(
            harness_samples, num_gpus, TRAIN_AFTER_EVERY_K_SAMPLES_PER_GPU, current_model_path, harness_map
        )

        # Save collected data in experiment directory
        collected_path = experiment_manager.get_pkl_path(global_step)
        with open(collected_path, 'wb') as f:
            pickle.dump(collected_samples, f)

        # Phase 2: Train one epoch on collected data
        updated_model_path = experiment_manager.get_updated_model_path(global_step)
        metrics = train_one_epoch_on_collected_data(
            collected_samples, current_model_path, updated_model_path, device=train_device
        )

        current_model_path = updated_model_path

        # NOTE: pruning of old models is disabled; call
        # experiment_manager.cleanup_old_models() here to keep only the newest
        # MAX_EXPERIMENT_DIRS saved models.

        result_dict = {
            "global_step_idx": global_step,
            "inference_success_rate": round(inference_success_rate, 1),
            "training_loss": round(metrics['avg_loss'], 4)
        }
        print(result_dict)

        main_pbar.update(1)

    main_pbar.close()

# Required: rollout workers use the 'spawn' start method, which re-imports this
# module in every child process; the guard keeps the children from re-running main().
if __name__ == "__main__":
    main()