import inspect
import json
import os
from typing import List

import torch
from datasets import Dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOConfig, GRPOTrainer

import wandb
from reward_funcs import make_curriculum_reward_function

# ==================== Configuration ====================

# System message placed ahead of every user question in the generator's chat.
GENERATOR_SYSTEM_PROMPT = (
    "You are a helpful assistant that writes coherent and creative short stories."
)

# Model identifier passed to `from_pretrained` for both the model and tokenizer.
MODEL_NAME = "Qwen/Qwen3-4B"

# Rubric criterion names in a fixed order. The order matters: the training
# state selects its active criteria by slicing this list by position.
CRITERIA_ORDER = [
    "prompt_alignment",  # index 0
    "plot_coherence_and_progression",
    "character_depth_and_motivation",
    "narrative_tension_and_stakes",
    "show_dont_tell",  # index 4
    "emotional_resonance",
    "thematic_subtlety",
    "pacing_and_structure",
    "atmosphere_and_sensory_detail",  # index 8
    "dialogue_and_relationships",
    "opening_and_ending_effectiveness",
    "narrative_voice_and_style",
    "originality_and_unresolved_mystery",  # index 12
]


# ==================== Training State ====================
class TrainingState:
    """Mutable holder for the models, tokenizers and progress used during training.

    Every attribute starts out empty; the training entry point fills in the
    generator model/tokenizer, the loaded questions and (on resume) the step.
    """

    def __init__(self):
        # Not populated anywhere in this file.
        self.documents = []
        # Proposer-side model and tokenizer; not populated anywhere in this file.
        self.proposer_model = self.proposer_tokenizer = None
        # Generator-side model and tokenizer, assigned before dataset creation.
        self.generator_model = self.generator_tokenizer = None
        # Questions used as user prompts for the generator.
        self.generated_questions = []
        # Step counter; set to the checkpoint iteration when resuming.
        self.current_step = 0

    def get_active_criteria(self) -> List[str]:
        """Return the rubric criteria that are currently active.

        The result is a fixed one-element slice of ``CRITERIA_ORDER`` (the
        entry at index 8). It does not depend on ``current_step``; no
        step-based schedule is applied here.
        """
        first_active = 8
        return CRITERIA_ORDER[first_active : first_active + 1]


# Global state: a single module-level TrainingState instance. The dataset
# builder reads the generator tokenizer from it, and the training entry point
# fills it in and hands it to the reward-function factory.
state = TrainingState()


# ==================== Dataset Creation ====================
def create_generator_dataset(questions: List[str], num_steps: int = 200) -> Dataset:
    """Build the prompt dataset for generator training.

    Produces ``num_steps`` rows of the form ``{"prompt": <chat string>}``. Row
    ``i`` is built from ``questions[i % len(questions)]``, so the questions are
    cycled in order. Each prompt is the system prompt plus the question,
    rendered to text with the generator tokenizer's chat template
    (generation prompt appended, thinking disabled).

    Args:
        questions: User questions to cycle through.
        num_steps: Number of rows to produce; values <= 0 give an empty dataset.

    Returns:
        A ``Dataset`` with a single ``"prompt"`` column.

    Raises:
        ValueError: If ``questions`` is empty while ``num_steps`` is positive.
        RuntimeError: If ``state.generator_tokenizer`` has not been set yet.
    """
    if num_steps <= 0:
        # Nothing to generate.
        return Dataset.from_list([])

    if not questions:
        # Without this guard the modulo below fails with ZeroDivisionError.
        raise ValueError("questions must contain at least one entry")

    tokenizer = state.generator_tokenizer
    if tokenizer is None:
        raise RuntimeError(
            "state.generator_tokenizer must be initialised before building the dataset"
        )

    # Only the first `num_steps` questions can ever be reached, and each one
    # always renders to the same text, so apply the template once per question
    # instead of once per step.
    reachable = questions[:num_steps]
    rendered = [
        tokenizer.apply_chat_template(
            [
                {"role": "system", "content": GENERATOR_SYSTEM_PROMPT},
                {"role": "user", "content": question},
            ],
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )
        for question in reachable
    ]

    return Dataset.from_list(
        [{"prompt": rendered[step % len(rendered)]} for step in range(num_steps)]
    )


# ==================== Main Training Function ====================
def _log_code_artifact(run) -> None:
    """Upload the training scripts that exist on disk to wandb as a code artifact."""
    # Scripts to snapshot: this file and the module that defines the reward
    # function. Add further paths here if more files should be archived.
    try:
        reward_source = inspect.getsourcefile(make_curriculum_reward_function)
    except TypeError:
        # Raised for objects without a Python source file.
        reward_source = None
    candidates = [globals().get("__file__"), reward_source]

    script_paths = []
    for path in candidates:
        # wandb rejects paths that are not regular files, so filter them out.
        if path and os.path.isfile(path) and path not in script_paths:
            script_paths.append(path)

    if not script_paths:
        print("No training scripts found to save to wandb")
        return

    art = wandb.Artifact(f"code-{MODEL_NAME}-{run.id}".replace("/", "_"), type="code")
    for path in script_paths:
        art.add_file(path)
    run.log_artifact(art)
    print("Saved training scripts to wandb")


def train_generator_with_rubric(
    documents_path: str = "/home/t-smotwani/selfplay/newyorker/story*.txt",
    num_questions: int = 50,
    training_steps: int = 200,
    learning_rate: float = 5e-7,
    output_dir: str = "outputs/rubric-trained",
    resume_from_checkpoint_path: str | None = None,
    wandb_run_id: str | None = None,
    checkpoint_iteration_num: int | None = None,
):
    """
    Train generator model with progressive rubric evaluation.

    Loads pre-generated questions from ``generated_questions_qwen3.json``,
    builds the prompt dataset, and runs GRPO training with a LoRA adapter,
    logging to wandb. The final model is saved under ``<output_dir>/final``.

    Args:
        documents_path: Path pattern for story files. Currently unused, since
            question generation is skipped and questions are read from file.
        num_questions: Number of questions to generate upfront. Currently
            unused for the same reason.
        training_steps: Total training steps (also the dataset size).
        learning_rate: Learning rate for training
        output_dir: Base directory; checkpoints go to
            ``<output_dir>/generator-grpo`` and the final model to
            ``<output_dir>/final``.
        resume_from_checkpoint_path: Checkpoint directory to resume from. When
            given, training resumes from it and no code artifact is uploaded.
        wandb_run_id: Existing wandb run id to resume; a new run is started
            when omitted.
        checkpoint_iteration_num: Step number of the checkpoint being resumed;
            copied into ``state.current_step`` when provided.
    """

    # Initialize wandb
    if wandb_run_id:
        run = wandb.init(project="rubric-training", id=wandb_run_id, resume="must")
        print(f"Resumed wandb run: {wandb_run_id}")
    else:
        run = wandb.init(
            project="rubric-training",
            name="qwen3-4B-nothink-reward-hacking-pairwise-rewards-atmosphere-lr-constant",
        )

    # Snapshot the code only for fresh runs, not for resumed ones.
    if not resume_from_checkpoint_path:
        _log_code_artifact(run)

    print("Skipping PHASE 1: Generating Questions - using existing questions from file")
    with open("generated_questions_qwen3.json", "r", encoding="utf-8") as json_file:
        state.generated_questions = json.load(json_file)
    print(f"\nLoaded {len(state.generated_questions)} questions from file")

    # Initialize generator model for training
    print("\nInitializing generator model for training...")
    state.generator_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float16,
        attn_implementation="sdpa",
        device_map="auto",
    )
    state.generator_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # Reuse the end-of-sequence token for padding.
    state.generator_tokenizer.pad_token = state.generator_tokenizer.eos_token

    # Create dataset
    print("\nCreating training dataset...")
    train_dataset = create_generator_dataset(state.generated_questions, training_steps)

    # Configure training
    print("\n" + "=" * 60)
    print("PHASE 2: Training Generator with Progressive Rubric")
    print("=" * 60)

    config = GRPOConfig(
        # os.path.join keeps the path relative when output_dir is empty,
        # instead of producing "/generator-grpo" at the filesystem root.
        output_dir=os.path.join(output_dir, "generator-grpo"),
        run_name="generator-rubric-training",
        learning_rate=learning_rate,
        lr_scheduler_type="constant",
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        fp16=True,
        top_p=0.8,
        top_k=20,
        num_generations=8,
        max_prompt_length=768,
        max_completion_length=1024,
        max_steps=training_steps,
        logging_steps=1,
        save_steps=10,
        save_total_limit=25,
        report_to="wandb",
        loss_type="dr_grpo",
        eval_strategy="no",
        save_strategy="steps",
        beta=0.0001,
    )

    peft_config = LoraConfig(
        r=32,
        lora_alpha=32,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "up_proj",
            "down_proj",
            "gate_proj",
        ],
        task_type="CAUSAL_LM",
        lora_dropout=0.05,
        bias="none",
    )

    trainer = GRPOTrainer(
        model=state.generator_model,
        processing_class=state.generator_tokenizer,
        reward_funcs=[
            make_curriculum_reward_function(state, MODEL_NAME),
        ],
        args=config,
        train_dataset=train_dataset,
        peft_config=peft_config,
    )

    # Train
    if resume_from_checkpoint_path:
        # Resume whenever a checkpoint path is supplied; a missing (or zero)
        # iteration number must not silently restart training from scratch.
        print(f"\nResuming training from checkpoint: {resume_from_checkpoint_path}")
        if checkpoint_iteration_num is not None:
            state.current_step = checkpoint_iteration_num
        else:
            print(
                "Warning: checkpoint_iteration_num not provided; "
                "state.current_step left unchanged"
            )
        trainer.train(resume_from_checkpoint=resume_from_checkpoint_path)
    else:
        trainer.train()

    # Save final model
    print("\n" + "=" * 60)
    print("Saving final model...")
    print("=" * 60)
    final_model_dir = os.path.join(output_dir, "final")
    trainer.save_model(final_model_dir)
    print(f"Saved final model to {final_model_dir}")

    print("\nTraining complete!")
    print(f"Trained for {training_steps} steps with progressive rubric evaluation")
    print(f"Final criteria count: {len(state.get_active_criteria())}")


# ==================== Main Entry Point ====================
if __name__ == "__main__":
    # output_dir is deliberately not overridden with an empty string, so the
    # function's default applies and checkpoints land under a named directory.
    train_generator_with_rubric(
        documents_path="",
        num_questions=50,
        training_steps=250,
        learning_rate=1e-5,
    )
