import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizer, AutoConfig

from trl import ModelConfig, get_kbit_device_map, get_quantization_config

from grpo_configs_v2 import GRPOConfig, SFTConfig

import os
import ast
from typing import Any

import pandas as pd


########################
# LOGGING #################
########################

def init_wandb_training(training_args):
    """
    Export the Weights & Biases settings found on ``training_args`` as
    environment variables.

    Each ``wandb_*`` attribute that is not ``None`` is copied to its
    matching ``WANDB_*`` variable; attributes set to ``None`` leave the
    environment untouched.
    """
    # (attribute on training_args, environment variable), checked in this order
    attr_to_env = (
        ("wandb_entity", "WANDB_ENTITY"),
        ("wandb_project", "WANDB_PROJECT"),
        ("wandb_run_group", "WANDB_RUN_GROUP"),
        ("wandb_run_name", "WANDB_NAME"),
    )
    for attr_name, env_name in attr_to_env:
        setting = getattr(training_args, attr_name)
        if setting is not None:
            os.environ[env_name] = setting
        
##############  
# DATA PROCESSING
##############

# grpo_utils.py (or wherever explode_qa_pairs lives)
import json, ast, math, logging

def _parse_qa_cell(cell):
    """
    Normalise one ``qa_pairs`` cell into a list of QA entries.

    Accepted inputs:
      - a list or tuple (already parsed); a list is returned as-is, a tuple
        is converted to a list
      - a single dict, which is wrapped in a one-element list
      - a JSON string
      - a Python-literal string (single quotes, None, etc.)
      - None / NaN / empty or whitespace-only string, which yield []

    A string is decoded first (JSON, then Python literal) and the decoded
    value goes through the same normalisation, so a string holding a single
    dict, ``null`` or ``NaN`` is handled like its already-parsed equivalent.
    Anything that cannot be turned into a list is logged as a warning and
    yields [] so callers can always iterate the result.
    """
    if isinstance(cell, str):
        text = cell.strip()
        if not text:
            return []
        try:
            # JSON first; json.JSONDecodeError is a ValueError subclass.
            cell = json.loads(text)
        except (ValueError, RecursionError):
            try:
                cell = ast.literal_eval(text)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as e:
                logging.warning("Could not parse qa_pairs cell: %s  (%s)", text[:120], e)
                return []

    # From here on `cell` is either the original non-string input or the
    # value decoded from the string.
    if isinstance(cell, list):
        return cell
    if isinstance(cell, tuple):
        return list(cell)
    if isinstance(cell, dict):
        # A lone QA dict: wrap it so iteration yields the dict, not its keys.
        return [cell]
    if cell is None or (isinstance(cell, float) and math.isnan(cell)):
        return []

    logging.warning("Unexpected qa_pairs type: %r", type(cell))
    return []

def _safe_json(obj: Any) -> str:
    """Serialise ``obj`` to a string suitable for a text column.

    Strings are returned unchanged (they are not validated as JSON),
    ``None`` and float NaN become ``""``, and every other value is encoded
    with ``json.dumps`` without ASCII escaping.
    """
    # Strings pass straight through, whatever they contain.
    if isinstance(obj, str):
        return obj

    is_missing = obj is None or (isinstance(obj, float) and math.isnan(obj))
    return "" if is_missing else json.dumps(obj, ensure_ascii=False)

# helper: explode *one* DataFrame into a new, row‑per‑QA DataFrame
def explode_pdf(df: pd.DataFrame) -> pd.DataFrame:
    """Expand ``df`` into a new DataFrame with one row per QA pair.

    For every input row the ``qa_pairs`` cell is parsed with
    ``_parse_qa_cell`` and each resulting dict becomes one output row that
    carries the parent ``identifier`` and ``past_timeline`` plus the QA
    fields. Fields missing from a QA dict default to ``""``. ``source`` and
    the full QA dict (``qa_pair``) are stored as strings via ``_safe_json``.

    Entries that are not dicts are skipped with a warning. The returned
    frame always has the same eight columns, even when no QA pair was found.
    """
    columns = [
        "identifier",
        "past_timeline",
        "question",
        "final_answer",
        "answer_reasoning",
        "action_space_category",
        "source",
        "qa_pair",
    ]
    rows = []
    # Track the positional row number: the index label is not guaranteed to
    # start at 0 (e.g. after filtering or with a non-default index).
    for position, (_, row) in enumerate(df.iterrows()):
        id_ = row["identifier"]
        timeline = row["past_timeline"]

        for qa in _parse_qa_cell(row["qa_pairs"]):
            if not isinstance(qa, dict):
                logging.warning(
                    "Skipping non-dict qa_pairs entry for identifier %r: %r", id_, type(qa)
                )
                continue
            if position == 0:
                # Sample the first row's QA pairs so the parsed shape is visible in logs.
                logging.info("QA pair found during processing: %s", qa)
            rows.append(
                {
                    "identifier": id_,
                    "past_timeline": timeline,
                    "question": qa.get("question", ""),
                    "final_answer": qa.get("final_answer", ""),
                    "answer_reasoning": qa.get("answer_reasoning", ""),
                    "action_space_category": qa.get("action_space_category", ""),
                    "source": _safe_json(qa.get("source", "")),
                    "qa_pair": _safe_json(qa),
                }
            )
    # Explicit columns keep the schema stable when `rows` is empty.
    return pd.DataFrame(rows, columns=columns)

# def explode_qa_pairs(example):
#     """
#     Turn a single row (with list‑of‑dicts `qa_pairs`) into one row PER QA pair.
#     We also copy the top‑level fields we still need (identifier, past_timeline …).

#     The function is batched, so `example[...]` are *lists*.
#     """
#     out = {
#         "identifier":   [],
#         "past_timeline":[],
#         "question":     [],
#         "final_answer": [],
#         "answer_reasoning": [],
#         "action_space_category": [],
#         "source":     [],
#         "qa_pair":      [],          # a full copy for judge‑based rewards
#     }

#     for id_, timeline, qa_cell in zip(
#             example["identifier"], example["past_timeline"], example["qa_pairs"]):
#         qa_list = _parse_qa_cell(qa_cell)
#         for qa in qa_list:                      # ← three iterations per row
#             out["identifier"].append(id_)
#             out["past_timeline"].append(timeline)
#             out["question"].append(qa["question"])
#             out["final_answer"].append(qa["final_answer"])
#             out["answer_reasoning"].append(qa["answer_reasoning"])
#             out["action_space_category"].append(qa.get("action_space_category", ""))
#             out["source"].append(_safe_json(qa.get("source", "")))
#             out["qa_pair"].append(_safe_json(qa))          # whole dict (may hold action_space, etc.), used for reward calculations

#     return out



###############
# MODEL LOADING
###############

# FOR LLAMA

# def get_tokenizer(
#     model_args: ModelConfig,
#     training_args: SFTConfig | GRPOConfig,
#     representation_mode: str | None = None,
# ) -> PreTrainedTokenizer:
#     tokenizer = AutoTokenizer.from_pretrained(
#         model_args.model_name_or_path,
#         revision=model_args.model_revision,
#         trust_remote_code=model_args.trust_remote_code,
#     )

#     if training_args.chat_template:
#         tokenizer.chat_template = training_args.chat_template

#     # add a pad token if the checkpoint doesn’t ship one
#     if tokenizer.pad_token is None:
#         tokenizer.add_special_tokens({"pad_token": "[PAD]"})

#     # tic‑tac‑toe move tokens
#     if representation_mode == "special":
#         special_tokens = [f"<move_{i}>" for i in range(1, 19)]
#         # tokenizer.add_tokens([tok for tok in special_tokens if tok not in tokenizer.get_vocab()])
#         tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})

#     return tokenizer


# # -------------------------------------------------
# # MODEL
# # -------------------------------------------------
# def get_model(
#     model_args: ModelConfig,
#     training_args: SFTConfig | GRPOConfig,
#     tokenizer: PreTrainedTokenizer,
# ) -> AutoModelForCausalLM:
#     torch_dtype = (
#         model_args.torch_dtype
#         if model_args.torch_dtype in {"auto", None}
#         else getattr(torch, model_args.torch_dtype)
#     )

#     quant_cfg = get_quantization_config(model_args)

#     model_kwargs = dict(
#         revision=model_args.model_revision,
#         trust_remote_code=model_args.trust_remote_code,
#         attn_implementation=model_args.attn_implementation,
#         torch_dtype=torch_dtype,
#         use_cache=not training_args.gradient_checkpointing,
#         device_map=get_kbit_device_map() if quant_cfg else None,
#         quantization_config=quant_cfg,
#         load_in_8bit=model_args.load_in_8bit,
#         load_in_4bit=model_args.load_in_4bit,
#     )

#     # -------------------------------------------------
#     # decide whether we’ll need ignore_mismatched_sizes
#     # -------------------------------------------------
#     config = AutoConfig.from_pretrained(
#         model_args.model_name_or_path,
#         revision=model_args.model_revision,
#         trust_remote_code=model_args.trust_remote_code,
#     )
#     ignore_mismatch = len(tokenizer) != config.vocab_size
    
#     logger = logging.getLogger(__name__)
#     if ignore_mismatch:
#         logger.warning(
#             "Tokenizer size (%d) does not match model config vocab size (%d). "
#             "This may be due to a tokenizer that was trained with a different vocabulary.",
#             len(tokenizer),
#             config.vocab_size,
#         )

#     model = AutoModelForCausalLM.from_pretrained(
#         model_args.model_name_or_path,
#         ignore_mismatched_sizes=ignore_mismatch,
#         **model_kwargs,
#     )
    

#     # resize if the tokenizer grew
#     # if len(tokenizer) != model.get_input_embeddings().weight.size(0):
#     model.resize_token_embeddings(len(tokenizer))
    
#     logger.info(
#         "Model loaded with %d tokens in the input embeddings.",
#         model.get_input_embeddings().weight.size(0),
#     )

#     return model

# ORIGINAL (active implementations; the LLaMA-specific variants above are commented out)
def get_tokenizer(model_args: ModelConfig, training_args: SFTConfig | GRPOConfig) -> PreTrainedTokenizer:
    """Load the tokenizer named by ``model_args``.

    The tokenizer is fetched with ``AutoTokenizer.from_pretrained`` using the
    model path, revision and ``trust_remote_code`` flag from ``model_args``.
    When ``training_args.chat_template`` is not ``None`` it replaces the
    tokenizer's own chat template.
    """
    load_options = {
        "revision": model_args.model_revision,
        "trust_remote_code": model_args.trust_remote_code,
    }
    tok = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **load_options)

    template_override = training_args.chat_template
    if template_override is not None:
        tok.chat_template = template_override

    return tok


def get_model(model_args: ModelConfig, training_args: SFTConfig | GRPOConfig) -> AutoModelForCausalLM:
    """Load the causal language model described by ``model_args``.

    Args:
        model_args: Supplies the model path, revision, ``trust_remote_code``,
            attention implementation, ``torch_dtype`` and the quantization
            flags read by ``get_quantization_config``.
        training_args: Only ``gradient_checkpointing`` is read; the KV cache
            is disabled when it is enabled.

    Returns:
        The model returned by ``AutoModelForCausalLM.from_pretrained``.
    """
    # "auto" and None are forwarded unchanged; any other string names a torch
    # dtype attribute (e.g. "bfloat16"). Non-string values are passed through
    # untouched instead of failing inside getattr.
    requested_dtype = model_args.torch_dtype
    if isinstance(requested_dtype, str) and requested_dtype != "auto":
        torch_dtype = getattr(torch, requested_dtype)
    else:
        torch_dtype = requested_dtype

    quantization_config = get_quantization_config(model_args)

    # The 4-bit / 8-bit choice is carried by `quantization_config` alone.
    # Forwarding `load_in_4bit` / `load_in_8bit` as well is rejected by
    # transformers when a quantization config is also supplied, so those
    # flags are intentionally not passed here.
    model_kwargs = dict(
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=torch_dtype,
        use_cache=not training_args.gradient_checkpointing,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )

    return AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        **model_kwargs,
    )