import os
import ast 

import json
from datasets import load_dataset
import datasets
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments
)
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM
from peft import LoraConfig, PeftConfig, PeftModelForCausalLM, get_peft_model
from accelerate import infer_auto_device_map, PartialState
import pandas as pd 


import torch
from torch.utils.data import Dataset
import numpy as np
import argparse
import random

def find_all_linear_names(args, model):
    """Collect the leaf names of every ``torch.nn.Linear`` submodule of *model*.

    The leaf name is the last dot-separated component of the qualified module
    name reported by ``model.named_modules()`` (e.g. ``"layers.0.q_proj"``
    contributes ``"q_proj"``). Duplicates are collapsed.

    Args:
        args: Not used by this function; kept so existing call sites work.
        model: Module whose ``named_modules()`` are scanned.

    Returns:
        A list of unique leaf names, with ``'lm_head'`` excluded. The order
        of the list is unspecified because it is built from a set.
    """
    linear_leaf_names = {
        qualified_name.rsplit('.', 1)[-1]
        for qualified_name, submodule in model.named_modules()
        if isinstance(submodule, torch.nn.Linear)
    }

    # The output head is never returned (original note: needed for 16-bit).
    linear_leaf_names.discard('lm_head')

    return list(linear_leaf_names)

def convert_to_conv_format(row, remove_reasoning):
    """Convert one raw SFT record into a prompt/completion conversation pair.

    Args:
        row: Mapping with the keys ``'instruction'`` (used as the system
            message), ``'input'`` (used as the user message) and ``'output'``
            (used as the assistant message).
        remove_reasoning: When true, drop everything up to and including the
            first ``</think>`` tag from the assistant output and strip the
            surrounding whitespace from what remains. An output that contains
            no ``</think>`` tag has no reasoning trace to remove and is kept
            unchanged.

    Returns:
        ``{'prompt': [system_msg, user_msg], 'completion': [assistant_msg]}``
        where every message is a ``{"role": ..., "content": ...}`` dict.

    Raises:
        KeyError: If *row* lacks one of the required keys.
    """
    instruction = row['instruction']
    user_message = row['input']
    assistant_response = row['output']

    if remove_reasoning:
        # partition() never raises on a missing tag (unlike split(...)[1]) and
        # keeps the entire answer even if it mentions '</think>' again later.
        _, closing_tag, answer = assistant_response.partition('</think>')
        if closing_tag:
            assistant_response = answer.strip()

    prompt = [{"role": "system", "content": instruction},
              {"role": "user", "content": user_message}]
    completion = [{"role": "assistant", "content": assistant_response}]
    return {'prompt': prompt, 'completion': completion}

if __name__ == "__main__":
    # Script entry point: parse CLI options, build a prompt/completion dataset
    # from a JSON file, wrap the base model with LoRA adapters, train it with
    # SFTTrainer and save the result under --output_model_path.
    parser = argparse.ArgumentParser()
    parser.add_argument("--base_model", type=str, required=True,
                        help="Base model to fine tune.")
    parser.add_argument("--sft_data_path", type=str, required=True,
                        help="Dataset path")
    parser.add_argument("--output_model_path", type=str, required=True,
                        help="Path for trained model and checkpoints.")
    parser.add_argument("--logging_steps", default=20, type=int, required=False,
                        help="Logging steps parameter.")
    parser.add_argument("--per_device_train_batch_size", default=16, type=int, required=False,
                        help="Per device batch size parameter.")
    parser.add_argument("--gradient_accumulation_steps", default=8, type=int, required=False,
                        help="Gradient accumulation parameter.")
    parser.add_argument("--learning_rate", default=2e-4, type=float, required=False,
                        help="Learning rate parameter.")
    parser.add_argument("--lora_r", default=32, type=int, required=False,
                        help="Lora rank")
    parser.add_argument("--lora_alpha", default=64, type=int, required=False,
                        help="Lora alpha")
    parser.add_argument("--epochs", default=1, type=int, required=False,
                        help="Number of epochs to train")
    parser.add_argument("--remove_reasoning", action='store_true',
                        help="Whether to remove reasoning traces or not")
    args = parser.parse_args()

    ####################################
    # Dataset Initialization
    ####################################

    # NOTE(review): this tokenizer is loaded but never passed to the trainer
    # below; presumably SFTTrainer resolves its own tokenizer from the model.
    # Confirm against the installed TRL version. The call is kept so that
    # whatever loading/validation it performs still happens as before.
    tokenizer = AutoTokenizer.from_pretrained(args.base_model)

    with open(args.sft_data_path, encoding='utf8') as f_in:
        json_data = json.load(f_in)  # Load the entire JSON file

    # One prompt/completion record per entry of the JSON file.
    data = [
        convert_to_conv_format(record, remove_reasoning=args.remove_reasoning)
        for record in json_data
    ]
    if not data:
        # Fail with a clear message instead of an IndexError on data[0] below.
        raise ValueError(f"No training examples found in {args.sft_data_path}")

    print("EXAMPLE DATA POINT:")
    print("###############################")
    print(data[0])
    print("###############################")
    train_data_hf = datasets.Dataset.from_pandas(pd.DataFrame(data=data))

    ####################################
    # Model Initialization
    ####################################
    model = AutoModelForCausalLM.from_pretrained(
        args.base_model,
        quantization_config=None,
        torch_dtype='auto',
        device_map="auto",
        trust_remote_code=True,
    )

    # Attach LoRA adapters to every linear layer name reported by the helper.
    target_modules = find_all_linear_names(args, model)
    peft_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=0.1,
        task_type="CAUSAL_LM",
        target_modules=target_modules,
        inference_mode=False,
    )
    model = get_peft_model(model, peft_config)

    model.train()

    # Create saved model name: "<base model name>_<variant>".
    # normpath() drops a trailing path separator, which would otherwise make
    # basename() return an empty string for a local path like "models/foo/".
    base_model_name = os.path.basename(os.path.normpath(args.base_model))
    variant = 'standardrr' if args.remove_reasoning else 'reasonrr'
    model_save_name = f'{base_model_name}_{variant}'

    model_output_dir = os.path.join(args.output_model_path, model_save_name)

    ####################################
    # Training
    ####################################
    # NOTE(review): the step estimate uses the per-device batch size only and
    # does not multiply by the number of training processes — confirm this is
    # intended for multi-GPU launches.
    effective_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps
    steps_per_epoch = len(train_data_hf) / effective_batch_size
    # Aim for about five checkpoints per epoch. Clamp to at least 1 so small
    # datasets do not produce save_steps=0, which is not a usable interval.
    save_steps = max(1, int(steps_per_epoch // 5))
    print(f"Saving model every {save_steps} steps", flush=True)
    train_args = TrainingArguments(
        output_dir=model_output_dir,
        do_train=True,
        save_strategy='steps',
        save_steps=save_steps,
        logging_steps=args.logging_steps,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        num_train_epochs=args.epochs,
        seed=23,
        disable_tqdm=False,
        dataloader_pin_memory=False,
    )

    trainer = SFTTrainer(
        model=model,
        args=train_args,
        train_dataset=train_data_hf,
    )

    trainer.train()
    trainer.save_model(model_output_dir)