# train.py: Script for training a model using the dataset generated by the previous script.
# This script takes arguments to specify the dataset and other configurations.
#
# example launch command:
#     torchrun --nproc_per_node=NUM_GPUS axbench/scripts/train.py --config axbench/demo/sweep/train.yaml
import os
import argparse
import yaml
import json
import glob
import pickle
import torch
import shutil
import requests
import pandas as pd
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import hf_hub_download
from pathlib import Path
from args.training_args import TrainingArgs
from args.dataset_args import DatasetArgs
from axbench.utils.constants import * 
from axbench.utils.model_utils import get_prefix_length, get_suffix_length
from transformers import set_seed
import torch.distributed as dist
import sys
from torch.utils.data import DataLoader
from axbench.models.sae import save_pruned_sae
import wandb
import datetime

# all supported methods
import axbench

import logging

# Module-level logger; its level, formatter and console handler are attached per rank inside main().
logger = logging.getLogger(__name__)

# File names written under the training dump directory.
CONFIG_FILE = "config.json"  # run configuration (model name, layer, component) written by rank 0
STATE_FILE = "train_state.pkl"  # resume checkpoint; `_rank_{rank}` is appended to this name on disk
METADATA_FILE = "metadata.jsonl"  # concept metadata; per-rank copies are named `rank_{rank}_` + this


def data_generator(data_dir, use_dpo_loss=False):
    """
    Lazily walk the training parquet shards in `data_dir` and emit one
    DataFrame slice per concept.

    Shards are visited in a fixed order: the un-numbered base file first
    (`train_data.parquet`, or `dpo_train_data.parquet` when `use_dpo_loss`
    is set), then the numbered shards `<prefix>_X.parquet` by ascending X.
    Files whose name contains "combined" are ignored.

    Args:
        data_dir (str): Directory holding the parquet shards.
        use_dpo_loss (bool): Read the `dpo_train_data*` shards instead of the
            `train_data*` ones.

    Yields:
        (concept_id, df_subset): The concept id and the rows of the current
        shard carrying that id. Within a shard, ids are emitted in ascending
        order and negative ids are skipped.
    """
    prefix = 'dpo_train_data' if use_dpo_loss else 'train_data'
    base_name = prefix + '.parquet'

    def shard_rank(path):
        # The un-numbered base file sorts ahead of every numbered shard.
        name = os.path.basename(path)
        if name == base_name:
            return -1
        # Numbered shards look like '<prefix>_X.parquet'; pull out X.
        tail = name.split('_')[-1]
        return int(tail.split('.')[0])

    shard_paths = sorted(
        (
            os.path.join(data_dir, entry)
            for entry in os.listdir(data_dir)
            if entry.startswith(prefix) and entry.endswith('.parquet') and "combined" not in entry
        ),
        key=shard_rank,
    )

    for shard_path in shard_paths:
        shard = pd.read_parquet(shard_path)
        ids = shard['concept_id'].unique()
        ids.sort()
        for concept_id in ids:
            if concept_id >= 0:
                yield (concept_id, shard[shard['concept_id'] == concept_id])


def load_metadata(metadata_path):
    """
    Load metadata from a JSON lines file.

    Args:
        metadata_path (str): Path to a file holding one JSON document per line.

    Returns:
        list: The decoded documents in file order. Blank (whitespace-only)
        lines, such as a trailing empty line, are skipped rather than passed
        to the JSON decoder.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(metadata_path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]


def prepare_df(
    original_df, negative_df, concept, metadata, tokenizer, 
    binarize, train_on_negative, is_chat_model, output_length, model_name, 
    max_num_of_examples=None, use_dpo_loss=False, steering_prompt_type="prepend",
    keep_orig_axbench_format=False):
    """
    Build the training DataFrame for one concept.

    Positive rows for `concept` are selected from `original_df`; negative rows
    of the concept's genre are selected from `negative_df`. For chat models the
    prompt columns are rendered through the tokenizer's chat template and,
    unless `keep_orig_axbench_format` is set, the suffix string reported by
    `get_suffix_length` is appended to every output column that is present.

    Args:
        original_df (pd.DataFrame): Rows for the concept; must have
            "output_concept", "category" and "input" columns.
        negative_df (pd.DataFrame): Candidate negative rows; must have a
            "concept_genre" column.
        concept (str): Concept name used to select positive rows.
        metadata (dict): Metadata entry for the concept; its
            "concept_genres_map" maps the concept to a list of genres, the
            first of which is used.
        tokenizer: Tokenizer providing `apply_chat_template` and `decode`.
        binarize (bool): Must be falsy; binarized datasets are not supported.
        train_on_negative (bool): Include negative rows (ignored for DPO).
        is_chat_model (bool): Whether to apply the chat template.
        output_length (int): Currently unused by this function.
        model_name (str): Language-model name; decides whether a system
            message is prepended.
        max_num_of_examples (int, optional): If truthy, keep at most half this
            many positive rows and half this many negative rows.
        use_dpo_loss (bool): DPO mode; only positive rows are used and the
            steered-input column is templated when present.
        steering_prompt_type (str): Prefix of the `<type>_steered_input` column.
        keep_orig_axbench_format (bool): Skip appending the suffix to outputs.

    Returns:
        pd.DataFrame: A new frame; the input frames are not modified.

    Raises:
        NotImplementedError: If `binarize` is truthy.
    """
    # Fail fast, before any tokenizer work. An `assert` would be stripped under
    # `python -O` and let the function silently return None.
    if binarize:
        raise NotImplementedError(
            "Binarizing the dataset is not supported here. Only in original AxBench.")

    suffix_length, suffix_str = get_suffix_length(tokenizer)
    print(f"Suffix length for {model_name}: {suffix_length}, Suffix string: {suffix_str}")
    genre = metadata["concept_genres_map"][concept][0]
    # Positive rows carry the concept; negative rows are matched on genre only.
    positive_df = original_df[(original_df["output_concept"] == concept) & (original_df["category"] == "positive")]
    negative_df = negative_df[(negative_df["concept_genre"] == genre)]
    if max_num_of_examples:
        half = max_num_of_examples // 2
        positive_df = positive_df.head(half)
        negative_df = negative_df.head(half)

    # Not binarizing: this becomes a standard instruction tuning task.
    if not use_dpo_loss and train_on_negative:
        all_df = pd.concat([positive_df, negative_df], axis=0)
    else:
        # For DPO, we only use positive examples. Copy so the column
        # assignments below never write into a slice of the caller's frame.
        all_df = positive_df.copy()

    if not is_chat_model:
        return all_df

    system_messages = []
    if model_name in HAS_SYSTEM_PROMPT_MODELS:
        system_messages = [{"role": "system", "content": "You are a helpful assistant."}]

    def render_prompt(text):
        messages = system_messages + [{"role": "user", "content": text}]
        # Drop the first token id before decoding (presumably the BOS token
        # added by the chat template -- TODO confirm for every supported model).
        nobos = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True)[1:]
        return tokenizer.decode(nobos)

    # Series.map (rather than a row-wise DataFrame.apply) also behaves on an
    # empty frame, where apply(axis=1) returns a DataFrame instead of a Series.
    all_df["input"] = all_df["input"].map(render_prompt)
    steered_column = f"{steering_prompt_type}_steered_input"
    if use_dpo_loss and steered_column in all_df.columns:
        all_df[steered_column] = all_df[steered_column].map(render_prompt)

    # Append the suffix string to the outputs by default. The truncation at the
    # data collator will take care of the rest. The original AxBench format has
    # much shorter outputs and is left untouched when requested.
    if not keep_orig_axbench_format:
        for column in ["output", "winning_output", "losing_output", "prepend_steered_output", "blend_in_steered_output"]:
            if column in all_df.columns:
                all_df[column] = all_df[column].map(lambda text: text + suffix_str)

    # Print sample row data (skipped when there is nothing to show).
    if not all_df.empty:
        print("\n=== Sample Row Data ===")
        sample_row = all_df.iloc[0]
        for column in sample_row.index:
            print(f"\n{column}:")
            print("-" * (len(column) + 1))
            print(f"{sample_row[column]}")
        print("=====================\n")

    return all_df


def partition_list(lst, n):
    """
    Split `lst` into `n` contiguous slices whose lengths differ by at most one.

    The first `len(lst) % n` slices receive one extra element; when `n` exceeds
    `len(lst)` the trailing slices are empty.

    Args:
        lst (list): The list to partition.
        n (int): The number of partitions.

    Returns:
        list of lists: A list containing n sublists, in original order.
    """
    base_size, remainder = divmod(len(lst), n)
    slices = []
    start = 0
    for position in range(n):
        # The leading `remainder` slices absorb the leftover elements.
        stop = start + base_size + (1 if position < remainder else 0)
        slices.append(lst[start:stop])
        start = stop
    return slices


def load_state(dump_dir, rank):
    """
    Return the pickled training state saved for `rank`, or None when no state
    file exists under `dump_dir`.

    Args:
        dump_dir: Directory in which the per-rank state file is stored.
        rank (int): Rank whose state file should be read.

    Returns:
        The unpickled state object, or None if the file is absent.
    """
    checkpoint_file = os.path.join(str(dump_dir), STATE_FILE + f"_rank_{rank}")
    if not os.path.exists(checkpoint_file):
        return None
    # NOTE: pickle is only safe here because the file is produced by this script.
    with open(checkpoint_file, "rb") as handle:
        return pickle.load(handle)


def save_state(dump_dir, state, concept_metadata, rank):
    """
    Persist the resume state for `rank` and append one metadata record.

    Args:
        dump_dir: Output directory; created (with parents) if missing.
        state: Picklable object that overwrites the rank's state file.
        concept_metadata: JSON-serializable record appended as one line to the
            rank's metadata file.
        rank (int): Rank used to name both files.
    """
    out_dir = Path(dump_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # The state file is overwritten on every call.
    checkpoint_file = out_dir / f"{STATE_FILE}_rank_{rank}"
    with checkpoint_file.open("wb") as handle:
        pickle.dump(state, handle)

    # The metadata file grows by one line per call.
    rank_metadata_file = out_dir / f"rank_{rank}_{METADATA_FILE}"
    with rank_metadata_file.open("a") as handle:
        handle.write(json.dumps(concept_metadata) + "\n")

def _load_model_instance(args, device, resized_vocab_size=None):
    """
    Load the base language model in eval mode and move it to `device`.

    Args:
        args: Training arguments; `model_name` and `use_bf16` are read.
        device: Target torch device.
        resized_vocab_size (int, optional): When given, the token embeddings
            are resized to this size after loading (needed when a padding
            token was added to the tokenizer).

    Returns:
        The loaded model.
    """
    if args.use_bf16:
        logger.warning(f"Using bfloat16 for model {args.model_name}")
    torch_dtype = torch.bfloat16 if args.use_bf16 else None
    if "gemma-3" in args.model_name:
        from transformers import Gemma3ForCausalLM
        model_instance = Gemma3ForCausalLM.from_pretrained(
            args.model_name, torch_dtype=torch_dtype)
    else:
        model_instance = AutoModelForCausalLM.from_pretrained(
            args.model_name, torch_dtype=torch_dtype)
    model_instance = model_instance.eval()
    model_instance.to(device)
    if resized_vocab_size is not None:
        model_instance.resize_token_embeddings(resized_vocab_size)
    return model_instance


def _merge_rank_outputs(args, dump_dir, world_size):
    """
    Merge the per-rank files in `dump_dir` into single files (run on rank 0).

    Merges the per-rank metadata, writes the run config, and for every
    configured method merges the per-rank `.pt`, `_top_features.json`,
    `_weight.pt` and `_bias.pt` files that exist. Per-rank weight and bias
    files are deleted after a successful merge.

    Args:
        args: Training arguments; `models`, `model_name`, `layer` and
            `component` are read.
        dump_dir (Path): Directory holding the per-rank files.
        world_size (int): Number of ranks whose files should be merged.
    """
    # Merging metadata. A rank that processed no concept never wrote a
    # metadata file, so missing files are skipped instead of crashing.
    metadata_entries = []
    for r in range(world_size):
        rank_metadata_path = dump_dir / f"rank_{r}_{METADATA_FILE}"
        if not rank_metadata_path.exists():
            logger.warning(f"No metadata file found for rank {r}. Skipping.")
            continue
        with open(rank_metadata_path, "r") as f:
            for line in f:
                metadata_entries.append(json.loads(line))
    # The per-rank files already accumulate every processed concept across
    # resumed runs, so the merged file is rewritten rather than appended to;
    # appending would duplicate entries on every re-run.
    with open(dump_dir / METADATA_FILE, "w") as f:
        for metadata_entry in metadata_entries:
            f.write(json.dumps(metadata_entry) + "\n")

    # Save other config
    config = {"model_name": args.model_name,
            "layer": args.layer,
            "component": args.component}
    config_path = dump_dir / CONFIG_FILE
    with open(config_path, 'w') as f:
        json.dump(config, f)

    for model_name in sorted(args.models.keys()):
        # merge pruned SAEs
        sae_files = [dump_dir / f"rank_{r}_{model_name}.pt" for r in range(world_size)]
        sae_files_existing = [f for f in sae_files if f.exists()]
        if not sae_files_existing:
            logger.warning(f"No SAE files found for model {model_name}. Skipping.")
        else:
            sae_weights = [torch.load(f) for f in sae_files_existing]
            # b_dec is taken from the first file; W_enc is concatenated along
            # dim 1 and every other parameter along dim 0.
            combined_sae_params = {
                "b_dec": sae_weights[0]["b_dec"],
                "W_dec": torch.cat([w["W_dec"] for w in sae_weights], dim=0),
                "W_enc": torch.cat([w["W_enc"] for w in sae_weights], dim=1),
                "b_enc": torch.cat([w["b_enc"] for w in sae_weights], dim=0),
                "threshold": torch.cat([w["threshold"] for w in sae_weights], dim=0),
            }
            torch.save(combined_sae_params, dump_dir / f"{model_name}.pt")
            logger.warning(f"Saved merged SAE weights for model {model_name}")

        # merge top features
        top_features_files = [dump_dir / f"rank_{r}_{model_name}_top_features.json" for r in range(world_size)]
        top_features_files_existing = [f for f in top_features_files if f.exists()]
        if not top_features_files_existing:
            logger.warning(f"No top features files found for model {model_name}. Skipping.")
        else:
            combined_top_features = []
            # Only read the files that exist; some ranks may have produced none.
            for top_feature_file in top_features_files_existing:
                with open(top_feature_file, "r") as f:
                    combined_top_features.extend(json.load(f))
            with open(dump_dir / f"{model_name}_top_features.json", "w") as f:
                json.dump(combined_top_features, f)
            logger.warning(f"Saved merged top features for model {model_name}")

        # Collect per-rank weight and bias files
        weight_files = [dump_dir / f"rank_{r}_{model_name}_weight.pt" for r in range(world_size)]
        bias_files = [dump_dir / f"rank_{r}_{model_name}_bias.pt" for r in range(world_size)]

        # Check if files exist
        weight_files_existing = [f for f in weight_files if f.exists()]
        bias_files_existing = [f for f in bias_files if f.exists()]

        if not weight_files_existing or not bias_files_existing:
            logger.warning(f"No weight or bias files found for model {model_name}. Skipping.")
            continue

        # Load weights and biases
        weights = [torch.load(f) for f in weight_files_existing]
        biases = [torch.load(f) for f in bias_files_existing]

        # Concatenate weights; dictionaries are merged key by key.
        if isinstance(weights[0], dict):
            merged_weight = {
                key: torch.cat([w[key] for w in weights], dim=0)
                for key in weights[0].keys()
            }
        else:
            merged_weight = torch.cat(weights, dim=0)

        # Concatenate biases; dictionaries are merged key by key.
        if isinstance(biases[0], dict):
            merged_bias = {
                key: torch.cat([b[key] for b in biases], dim=0)
                for key in biases[0].keys()
            }
        else:
            merged_bias = torch.cat(biases, dim=0)

        # Save merged weight and bias files
        weight_file = dump_dir / f"{model_name}_weight.pt"
        bias_file = dump_dir / f"{model_name}_bias.pt"
        torch.save(merged_weight, weight_file)
        torch.save(merged_bias, bias_file)
        logger.warning(f"Saved merged weights and biases for model {model_name}")

        # Delete per-rank files; a failed delete is logged, not fatal.
        for f in weight_files_existing + bias_files_existing:
            try:
                f.unlink()
                logger.warning(f"Deleted file {f.name}")
            except Exception as e:
                logger.error(f"Error deleting file {f.name}: {e}")


def main():
    """
    Train every configured method on this rank's share of the concepts, then
    let rank 0 merge the per-rank outputs.

    The concept list is partitioned across ranks. After each concept the rank
    saves a resume state, so a restarted run skips concepts whose id is not
    greater than the last one recorded for that rank.
    """
    args = TrainingArgs(section="train")
    generate_args = DatasetArgs(section="generate")

    # Initialize the process group
    dist.init_process_group(backend='nccl', init_method='env://', 
                          timeout=datetime.timedelta(seconds=6000))

    # Get the rank and world_size from the process group; the local rank comes
    # from the environment.
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    local_rank = int(os.environ.get('LOCAL_RANK', 0))

    # Set the device for this process
    device = torch.device(f'cuda:{local_rank}')
    torch.cuda.set_device(device)

    # Set a unique seed per rank for reproducibility
    set_seed(args.seed + rank)

    if args.overwrite_data_dir and Path(args.overwrite_data_dir).exists():
        logger.warning(f"Overwriting data directory {args.data_dir}")
        args.data_dir = args.overwrite_data_dir
    else:
        args.data_dir = f"{args.dump_dir}/generate"

    # Configure the logger per rank
    logger.setLevel(logging.WARNING)  # Set the logging level as desired

    # Create a logging formatter that includes the rank
    formatter = logging.Formatter(
        fmt=f'%(asctime)s,%(msecs)03d %(levelname)-8s [Rank {rank}] [%(filename)s:%(lineno)d] %(message)s',
        datefmt='%Y-%m-%d:%H:%M:%S'
    )

    # Create a console handler and set its formatter
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(formatter)

    # Add the handler to the logger
    if not logger.handlers:
        logger.addHandler(console_handler)

    # Optionally, create a file handler per rank:
    #     log_file = f'log_rank_{rank}.log'
    #     file_handler = logging.FileHandler(log_file)
    #     file_handler.setFormatter(formatter)
    #     logger.addHandler(file_handler)

    # Load dataset and metadata
    metadata_path = os.path.join(args.data_dir, 'metadata.jsonl')
    metadata = load_metadata(metadata_path)
    df_generator = data_generator(args.data_dir, use_dpo_loss=args.use_dpo_loss)
    # The base file supplies the pool of negative examples.
    if args.use_dpo_loss:
        all_df = pd.read_parquet(os.path.join(args.data_dir, 'dpo_train_data.parquet'))
    else:
        all_df = pd.read_parquet(os.path.join(args.data_dir, 'train_data.parquet'))
    negative_df = all_df[(all_df["output_concept"] == EMPTY_CONCEPT) & (all_df["category"] == "negative")]

    df_list = list(df_generator)
    logger.warning(f"Total number of concept df loaded: {len(df_list)}")
    if args.max_concepts:
        logger.warning(f"All ranks only processing {args.max_concepts} concepts")
        df_list = df_list[:args.max_concepts]

    dump_dir = Path(args.dump_dir) / "train"
    dump_dir.mkdir(parents=True, exist_ok=True)
    
    # save pruned SAE
    sae_params = None # TODO: this is a workaround to avoid breaking the code.
    # if rank == 0:
    #     sae_params = save_pruned_sae(metadata_path, dump_dir)

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name, truncation=True, 
        model_max_length=int(args.output_length),
        padding_side="right", use_fast=False)

    # Partition df_list among ranks
    df_list_per_rank = partition_list(df_list, world_size)
    my_df_list = df_list_per_rank[rank]

    # Load model instance onto device
    model_instance = _load_model_instance(args, device)
    is_chat_model = args.model_name in CHAT_MODELS

    need_resize = tokenizer.unk_token is None and tokenizer.pad_token is None
    if need_resize:
        # raw llama3
        print("adding a special padding token...")
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        model_instance.resize_token_embeddings(len(tokenizer))

    prefix_length = 1 # prefix defaults to 1 for all models due to the BOS token.
    if is_chat_model:
        prefix_length = get_prefix_length(tokenizer)
        logger.warning(f"Chat model prefix length: {prefix_length}")

    state = load_state(dump_dir, rank)
    last_concept_id = state.get("last_concept_id", None) if state else None
    logger.warning(f"Rank {rank} last concept_id processed: {last_concept_id}")

    # Run training for assigned concept_ids
    for concept_id, concept_df in my_df_list:
        concept_id = int(concept_id)
        if last_concept_id is not None and concept_id <= last_concept_id:
            logger.warning(f"Rank {rank} skipping concept_id {concept_id} because it is already processed")
            continue
        logger.warning(f"Training models for concept_id {concept_id} on rank {rank}")
        for model_name in sorted(args.models.keys()):
            model_args = args.models[model_name]
            concept = metadata[concept_id]["concept"]
            logger.warning(f"Training {model_name} with concept {concept}")
            benchmark_model = getattr(axbench, model_name)(
                model_instance, tokenizer, layer=args.layer,
                training_args=model_args,
                lm_model_name=args.model_name,
                device=device, seed=args.seed, use_wandb=args.use_wandb
            )
            low_rank_dimension = model_args.low_rank_dimension \
                if model_args.low_rank_dimension else 1
            benchmark_model.make_model(
                mode="train",
                embed_dim=model_instance.config.hidden_size,
                low_rank_dimension=low_rank_dimension,
                dtype=torch.bfloat16 if args.use_bf16 else None,
                intervention_type=model_args.intervention_type,
                concept_id=concept_id,
                sae_params=sae_params,
                metadata_path=metadata_path,
                dump_dir=dump_dir,
                model_params=model_args,
                dropout=model_args.dropout,
                intervention_positions_dropout=model_args.intervention_positions_dropout,
                preference_pairs=model_args.preference_pairs,
            )
            if model_name not in {"LoReFT", "LoRA", "SFT", "BoW"} and args.use_bf16:
                benchmark_model.ax.to(torch.bfloat16)
            kwargs = {
                "prefix_length": prefix_length,
                "positions": model_args.intervention_positions,
                "exclude_bos": model_args.exclude_bos,
                "metadata_path": metadata_path,
                "use_dpo_loss": args.use_dpo_loss,
                "logging_metadata": {
                    "concept_id": concept_id,
                    "model_name": model_name,
                    "layer": args.layer,
                },
                "wandb_project": args.wandb_project,
                "wandb_name": args.wandb_name,
                "negative_only": model_args.negative_only,
                "preference_pairs": model_args.preference_pairs,
                "steering_prompt_type": model_args.steering_prompt_type,
                "substraction_type": model_args.substraction_type,
            }
            prepared_df = concept_df.copy()
            prepared_df = prepare_df(
                prepared_df, negative_df, concept, metadata[concept_id], tokenizer, 
                binarize=model_args.binarize_dataset, 
                train_on_negative=model_args.train_on_negative,
                use_dpo_loss=args.use_dpo_loss,
                is_chat_model=is_chat_model,
                output_length=int(args.output_length),
                model_name=args.model_name,
                max_num_of_examples=args.max_num_of_examples,
                steering_prompt_type=model_args.steering_prompt_type,
                keep_orig_axbench_format=generate_args.keep_orig_axbench_format,
            )
            benchmark_model.train(prepared_df, **kwargs)
            benchmark_model.save(dump_dir, model_name=f"rank_{rank}_{model_name}")
            if model_name == "SFT":
                # We need to reload the original model after SFT. Reuse the
                # same loader as the initial load so the gemma-3 special case
                # and the padding-token embedding resize are applied again.
                model_instance = _load_model_instance(
                    args, device,
                    resized_vocab_size=len(tokenizer) if need_resize else None)
            if model_name == "LoRA":
                model_instance = benchmark_model.ax_model.unload()
            logger.warning(f"Saved weights and biases for model {model_name} on rank {rank}")
            # Clean up
            del benchmark_model
            torch.cuda.empty_cache()
        # After processing, save state
        current_state = {'last_concept_id': concept_id}
        save_state(dump_dir, current_state, metadata[concept_id], rank)

    # Synchronize all processes
    dist.barrier()

    # Rank 0 merges results
    if rank == 0:
        logger.warning("Rank 0 is merging results.")
        _merge_rank_outputs(args, dump_dir, world_size)

    # Finalize the process group
    dist.destroy_process_group()

    # Remove handlers to prevent duplication if the script is run multiple times
    logger.removeHandler(console_handler)
    # If file_handler is used, remove it as well
    # logger.removeHandler(file_handler)


# Run the training entry point only when executed as a script (e.g. via torchrun), not on import.
if __name__ == "__main__":
    main()

