import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model
from datasets import Dataset
from trl import SFTTrainer
import pandas as pd
import json
import glob
from collections import defaultdict
import wandb
from typing import Dict, Any, Literal

# Model identifiers accepted by the `checkpoint` parameter of `train_model`,
# which passes the chosen value to `from_pretrained` for both model and tokenizer.
ModelCheckpoint = Literal[
    'mistralai/Mistral-7B-Instruct-v0.1',
    'HuggingFaceH4/mistral-7b-sft-beta'
]

def load_chatgpt_data(folder_path):
    """Load chat transcripts from every ``*.jsonl`` file under *folder_path*.

    The directory tree is searched recursively. Each file must be named
    ``<prefix>_<number>.<ext>`` (the integer is taken from the text between
    the first underscore and the following dot) and must contain a single
    JSON object with a ``"messages"`` key.

    Files are grouped by their parent directory. Directories are visited in
    sorted path order and, within a directory, files are ordered by the
    integer in their name, so the row order does not depend on the order in
    which the filesystem happens to list entries.

    Args:
        folder_path: Root directory to search.

    Returns:
        A ``pandas.DataFrame`` with a single ``messages`` column holding the
        ``"messages"`` value of each file, one row per file.

    Raises:
        IndexError: If a file name contains no underscore.
        ValueError: If the file-name token is not an integer, or the file
            content is not valid JSON.
        KeyError: If a file's JSON object has no ``"messages"`` key.
    """
    pattern = os.path.join(folder_path, "**/*.jsonl")
    folder_messages = defaultdict(list)

    for file_path in glob.glob(pattern, recursive=True):
        folder_name, file_name = os.path.split(file_path)
        message_number = int(file_name.split('_')[1].split('.')[0])

        # Explicit encoding: the platform default is locale-dependent and can
        # mis-decode non-ASCII transcript text.
        with open(file_path, 'r', encoding='utf-8') as file:
            data = json.load(file)
        folder_messages[folder_name].append((message_number, data['messages']))

    all_message_lists = []
    # Sort the folders so the output order is reproducible across machines.
    for folder_name in sorted(folder_messages):
        numbered = sorted(folder_messages[folder_name], key=lambda item: item[0])
        all_message_lists.extend(messages for _, messages in numbered)

    return pd.DataFrame({'messages': all_message_lists})

def train_model(
    checkpoint: ModelCheckpoint = 'HuggingFaceH4/mistral-7b-sft-beta',
    enableTest: bool = False,
    bnb_params: Dict[str, Any] | None = None,
    wandb_config: Dict[str, Any] | None = None,
    peft_config: Dict[str, Any] | None = None,
    trainer_config: Dict[str, Any] | None = None
) -> None:
    """Fine-tune a quantized causal LM with LoRA adapters on the chat data.

    The function loads the training (and optionally test) conversations via
    ``load_chatgpt_data``, renders them with the tokenizer's chat template,
    tokenizes them to a fixed length, wraps the quantized base model with a
    LoRA adapter, trains it with ``SFTTrainer`` and saves the resulting model
    to the trainer's output directory.

    Args:
        checkpoint: Model identifier passed to ``from_pretrained`` for both
            the model and the tokenizer.
        enableTest: When true, the ``test`` split is also loaded and passed
            to the trainer as ``eval_dataset``.
        bnb_params: Keyword arguments for ``BitsAndBytesConfig``. When
            omitted, 4-bit loading with float16 compute is used.
        wandb_config: Run configuration logged via ``wandb.init``. When
            omitted, a small default describing the run is used.
        peft_config: Keyword arguments for ``LoraConfig``. When omitted, a
            default rank-16 configuration on the attention projections is
            used.
        trainer_config: Keyword arguments for ``TrainingArguments``.
            Required in practice; it determines where the model is saved.

    Raises:
        ValueError: If ``trainer_config`` is ``None``. Raised before any
            model is loaded so a misconfigured run fails immediately.
    """
    if trainer_config is None:
        raise ValueError(
            "trainer_config is required: pass the TrainingArguments keyword "
            "arguments (including 'output_dir')."
        )

    # Single source of truth for the sequence length used by both the
    # tokenizer and the trainer.
    max_seq_length = 1024
    text_column = 'template_formatted_conversation_turns'

    data_root = os.path.join('..', '..', 'data', 'finetune_data', 'chatgpt_data')
    train_df = load_chatgpt_data(os.path.join(data_root, 'train'))
    test_df = load_chatgpt_data(os.path.join(data_root, 'test')) if enableTest else None

    # Fall back to a default quantization setup, mirroring how the wandb and
    # PEFT configs below fall back to defaults.
    default_bnb_params = {
        "load_in_4bit": True,
        "bnb_4bit_compute_dtype": torch.float16
    }
    bnb_config = BitsAndBytesConfig(**(bnb_params or default_bnb_params))
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        quantization_config=bnb_config,
        device_map={"": 0},
        use_cache=False,  # KV cache is disabled for training
        torch_dtype=torch.float16,  # dtype requested for the loaded weights
    )

    # Configure model for training
    model.config.pretraining_tp = 1
    model.gradient_checkpointing_enable()  # Enable gradient checkpointing to save memory

    # Load tokenizer; the EOS token doubles as the padding token.
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.add_eos_token = True

    def render_conversation(messages):
        return tokenizer.apply_chat_template(messages, tokenize=False)

    train_df[text_column] = train_df['messages'].apply(render_conversation)
    if enableTest:
        test_df[text_column] = test_df['messages'].apply(render_conversation)

    # Use default wandb config if none provided
    default_wandb_config = {
        "architecture": "Mistral-7B-SFT",
        "dataset": "Internal Testing + SONA",
    }
    wandb_config = wandb_config or default_wandb_config

    wandb.init(
        project="mistral_finetuning",
        config=wandb_config
    )

    model = prepare_model_for_kbit_training(model)

    # Default PEFT config
    default_peft_config = {
        "r": 16,
        "lora_alpha": 32,
        "target_modules": ["q_proj", "v_proj", "k_proj", "o_proj"],
        "lora_dropout": 0.05,
        "bias": "all",
        "task_type": "CAUSAL_LM"
    }

    # Use provided peft_config if available, otherwise use default
    peft_config = peft_config or default_peft_config
    lora_config = LoraConfig(**peft_config)
    model = get_peft_model(model, lora_config)

    train_dataset = Dataset.from_pandas(train_df[[text_column]])
    test_dataset = Dataset.from_pandas(test_df[[text_column]]) if enableTest else None

    training_args = TrainingArguments(**trainer_config)

    def tokenize_function(example):
        # No return_tensors here: each example must produce flat Python lists
        # so the dataset stores one sequence per row and the label list can
        # be built directly.
        result = tokenizer(
            example[text_column],
            truncation=True,
            padding="max_length",
            max_length=max_seq_length,
        )
        # Causal LM labels mirror the inputs. Padding positions are set to
        # -100 so the loss ignores them. The attention mask is used rather
        # than the token id because the pad token is the EOS token and real
        # EOS tokens must stay in the loss.
        result["labels"] = [
            token_id if is_real_token else -100
            for token_id, is_real_token in zip(result["input_ids"], result["attention_mask"])
        ]
        return result

    # Apply tokenization to the dataset
    tokenized_train_dataset = train_dataset.map(
        tokenize_function,
        remove_columns=train_dataset.column_names,
        desc="Tokenizing dataset"
    )
    if enableTest:
        tokenized_test_dataset = test_dataset.map(
            tokenize_function,
            remove_columns=test_dataset.column_names,
            desc="Tokenizing dataset"
        )

    # Every example is already padded to max_seq_length, so this collator
    # only has to stack the features into tensors.
    def data_collator(features):
        return tokenizer.pad(
            features,
            padding=True,
            return_tensors="pt",
        )

    # NOTE(review): the `tokenizer` and `max_seq_length` keyword names depend
    # on the installed TRL version - confirm against the pinned release.
    trainer_kwargs = {
        "model": model,
        "train_dataset": tokenized_train_dataset,
        "tokenizer": tokenizer,
        "args": training_args,
        "data_collator": data_collator,
        "max_seq_length": max_seq_length
    }

    if enableTest:
        trainer_kwargs["eval_dataset"] = tokenized_test_dataset

    trainer = SFTTrainer(**trainer_kwargs)
    trainer.train()

    # trainer_config is a plain dict, so the output directory is read from
    # the TrainingArguments object built from it.
    trainer.model.save_pretrained(training_args.output_dir)
    
if __name__ == "__main__":
    # Everything for this run is gathered in one mapping whose keys are the
    # keyword arguments of train_model.
    run_settings = {
        # bitsandbytes quantization options
        "bnb_params": {
            "load_in_4bit": True,
            "bnb_4bit_compute_dtype": torch.float16,
        },
        # Metadata recorded with the Weights & Biases run
        "wandb_config": {
            "architecture": "Mistral-7B-SFT",
            "dataset": "Internal Testing + SONA",
        },
        # LoRA adapter options
        "peft_config": {
            "r": 16,
            "lora_alpha": 32,
            "target_modules": ["q_proj", "v_proj", "k_proj", "o_proj"],
            "lora_dropout": 0.05,
            "bias": "all",
            "task_type": "CAUSAL_LM",
        },
        # TrainingArguments options
        "trainer_config": {
            "per_device_train_batch_size": 2,
            "gradient_accumulation_steps": 4,
            "num_train_epochs": 5,
            "learning_rate": 1e-4,
            "max_grad_norm": 0.3,  # gradient clipping threshold
            "fp16": False,
            "logging_steps": 2,
            "save_strategy": "epoch",
            "output_dir": "mistral-7b-rlhf-25jan-experimentSplit",
            "optim": "paged_adamw_8bit",
            "lr_scheduler_type": "cosine",
            "warmup_ratio": 0.03,
            "weight_decay": 0.05,
            "gradient_checkpointing": True,
            "report_to": "wandb",
            "remove_unused_columns": False,
            "dataloader_pin_memory": True,
        },
    }
    train_model(**run_settings)