import os
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # del
# os.environ["CUDA_VISIBLE_DEVICES"] = "4,5"
import numpy as np
import torch
import multiprocessing
from functools import partial
from datasets import Dataset
from datasets import load_dataset, load_from_disk, concatenate_datasets
# from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from peft import LoraConfig, TaskType, get_peft_model
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    SchedulerType,
    Trainer,
    TrainingArguments,
    set_seed
)
from peft import AutoPeftModelForCausalLM
from trl import SFTConfig, SFTTrainer
import argparse

def find_sublists(main_list, sub_list):
    """Return the end offset (exclusive) of every occurrence of ``sub_list``.

    Each contiguous window of ``main_list`` with the same length as
    ``sub_list`` is compared against it; for every match the index just past
    the window's last element is collected. Overlapping matches are all
    reported, in increasing order.
    """
    width = len(sub_list)
    end_offsets = []
    for start in range(len(main_list) - width + 1):
        stop = start + width
        if main_list[start:stop] == sub_list:
            end_offsets.append(stop)
    return end_offsets

def _pad_block(input_ids, labels, max_length, pad_token_id):
    """Return copies of ``input_ids``/``labels`` right-padded to ``max_length``.

    Inputs are padded with ``pad_token_id``; labels are padded with -100.
    """
    pad = max_length - len(input_ids)
    return input_ids + [pad_token_id] * pad, labels + [-100] * pad


def concatenate_sequences(dataset_chunk, tokenizer, max_length=1024, eos_token_id=None):
    """Tokenize texts and pack them into fixed-length blocks.

    Every text in ``dataset_chunk['text']`` is tokenized, terminated with an
    EOS token and appended to the current block. When the next sequence would
    not fit, the current block is right-padded to ``max_length`` and a new
    block is started. Labels mirror the input ids (including each sequence's
    terminating EOS); only the padding positions are set to -100.

    Args:
        dataset_chunk: Mapping with a ``'text'`` entry holding an iterable of
            strings. Other columns are ignored.
        tokenizer: Callable returning a mapping with an ``'input_ids'`` list
            when called as ``tokenizer(text, truncation=True,
            max_length=..., padding=False)``; must expose ``eos_token_id``
            if ``eos_token_id`` is not given.
        max_length: Exact length of every returned block.
        eos_token_id: Token id used as sequence terminator and as padding
            value. Defaults to ``tokenizer.eos_token_id``.

    Returns:
        ``(input_ids, labels)``: two lists of equal length whose rows all
        have exactly ``max_length`` entries.

    Raises:
        ValueError: If ``max_length`` is smaller than 1 or no EOS token id
            can be determined.
    """
    if max_length < 1:
        raise ValueError(f"max_length must be at least 1, got {max_length}")
    # Compare against None explicitly: 0 is a legitimate token id and must
    # not fall through to the tokenizer default.
    if eos_token_id is None:
        eos_token_id = tokenizer.eos_token_id
    if eos_token_id is None:
        raise ValueError("eos_token_id is not set and the tokenizer defines none")

    concatenated_input_ids = []
    concatenated_labels = []

    current_input_ids = []
    current_labels = []

    for text in dataset_chunk['text']:
        result = tokenizer(text, truncation=True, max_length=max_length, padding=False)
        # Keep one slot free for the terminating EOS so that a single
        # sequence can never make a block longer than max_length.
        token_ids = result['input_ids'][:max_length - 1]

        # Flush only a non-empty block; otherwise a long first sequence would
        # emit a block made purely of padding.
        overflow = len(current_input_ids) + len(token_ids) + 1 > max_length
        if current_input_ids and overflow:
            padded_ids, padded_labels = _pad_block(
                current_input_ids, current_labels, max_length, eos_token_id
            )
            concatenated_input_ids.append(padded_ids)
            concatenated_labels.append(padded_labels)
            current_input_ids = []
            current_labels = []

        # Append the sequence followed by its EOS; the EOS is a real target.
        current_input_ids.extend(token_ids)
        current_input_ids.append(eos_token_id)
        current_labels.extend(token_ids)
        current_labels.append(eos_token_id)

    # Emit the trailing, partially filled block.
    if current_input_ids:
        padded_ids, padded_labels = _pad_block(
            current_input_ids, current_labels, max_length, eos_token_id
        )
        concatenated_input_ids.append(padded_ids)
        concatenated_labels.append(padded_labels)

    return concatenated_input_ids, concatenated_labels

def process_dataset_in_chunks(dataset, tokenizer, num_chunks=8, max_length=1024, eos_token_id=None):
    """Pack a dataset with ``concatenate_sequences`` using a process pool.

    The dataset is split into contiguous slices that together cover every
    example (slice sizes differ by at most one), each slice is processed by
    ``concatenate_sequences`` in its own worker process, and the per-slice
    results are concatenated in slice order.

    Args:
        dataset: Object supporting ``len()`` and slice indexing
            (``dataset[a:b]``); each slice is handed to
            ``concatenate_sequences``.
        tokenizer: Forwarded to ``concatenate_sequences``.
        num_chunks: Upper bound on the number of slices / worker processes.
            Fewer are used when the dataset has fewer examples.
        max_length: Forwarded to ``concatenate_sequences``.
        eos_token_id: Forwarded to ``concatenate_sequences``.

    Returns:
        ``(input_ids, labels)`` lists gathered from all slices.

    Raises:
        ValueError: If ``num_chunks`` is smaller than 1.
    """
    if num_chunks < 1:
        raise ValueError(f"num_chunks must be at least 1, got {num_chunks}")

    total = len(dataset)
    if total == 0:
        return [], []

    # Never create more slices than examples, and spread the remainder over
    # the first slices so that no trailing examples are dropped.
    num_workers = min(num_chunks, total)
    base_size, remainder = divmod(total, num_workers)
    dataset_chunks = []
    start = 0
    for index in range(num_workers):
        stop = start + base_size + (1 if index < remainder else 0)
        dataset_chunks.append(dataset[start:stop])
        start = stop

    concatenate_with_args = partial(
        concatenate_sequences,
        tokenizer=tokenizer,
        max_length=max_length,
        eos_token_id=eos_token_id,
    )

    # One worker per slice; pool.map preserves slice order in its results.
    with multiprocessing.Pool(num_workers) as pool:
        results = pool.map(concatenate_with_args, dataset_chunks)

    concatenated_input_ids = []
    concatenated_labels = []
    for chunk_input_ids, chunk_labels in results:
        concatenated_input_ids.extend(chunk_input_ids)
        concatenated_labels.extend(chunk_labels)

    return concatenated_input_ids, concatenated_labels

def main(args):
    """Fine-tune a causal language model with TRL's ``SFTTrainer`` and save it.

    Loads ``data/<args.data_path>.jsonl`` as the training split, loads either
    the base model ``args.model`` or (with ``args.load_peft``) a PEFT
    checkpoint from ``args.load_path`` that is merged into its base weights,
    trains for ``args.epochs`` epochs and writes the result to
    ``<output_dir>/saved_model_<epochs>_epochs``.

    Args:
        args: Parsed command-line namespace. Reads ``model``, ``data_path``,
            ``model_name``, ``lr``, ``batch_size``, ``gradient_accumulation``,
            ``add_mcq``, ``gradient_norm``, ``warmup_ratio``,
            ``weight_decay``, ``seed``, ``epochs``, ``load_peft`` and
            ``load_path``.
    """
    # Honour the --seed flag instead of a hard-coded constant.
    set_seed(args.seed)
    dataset = load_dataset('json', data_files=f"data/{args.data_path}.jsonl", split='train')

    tokenizer = AutoTokenizer.from_pretrained(args.model)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = 'left'
    # NOTE(review): this tokenizer is only used below to size the embedding
    # matrix; it is not handed to SFTTrainer, so the pad-token / padding-side
    # settings above do not reach the trainer. Confirm whether that is intended.

    os.makedirs('./model', exist_ok=True)
    if args.add_mcq:
        # NOTE(review): unlike the default branch, this path is not placed
        # under ./model. Kept as-is; confirm whether that is intended.
        output_dir = f"./{args.model_name}-{args.data_path}_mix/{args.lr}"
    else:
        output_dir = f"./model/{args.model_name}-{args.data_path}-full/{args.lr}"

    # model_init_kwargs is intentionally not set: the model is instantiated
    # below (already with flash_attention_2), and SFTTrainer only applies
    # model_init_kwargs when it is given a model *name* to load itself.
    training_args = SFTConfig(
        output_dir=output_dir,
        learning_rate=args.lr,
        num_train_epochs=args.epochs,
        save_strategy="no",
        bf16=True,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation,
        report_to="tensorboard",
        warmup_ratio=args.warmup_ratio,
        weight_decay=args.weight_decay,
        max_grad_norm=args.gradient_norm,
        lr_scheduler_type=SchedulerType.COSINE,
        logging_steps=10,
        seed=args.seed,
    )

    if args.load_peft:
        print(f"load from {args.load_path}...")
        model = AutoPeftModelForCausalLM.from_pretrained(
            args.load_path,
            attn_implementation="flash_attention_2",
            torch_dtype=torch.bfloat16,
        )
        model.resize_token_embeddings(len(tokenizer))
        # Fold the adapter weights into the base model so that training
        # continues on a plain (non-PEFT) model.
        model = model.merge_and_unload()
    else:
        model = AutoModelForCausalLM.from_pretrained(
            args.model,
            attn_implementation="flash_attention_2",
            torch_dtype=torch.bfloat16,
        )
        model.resize_token_embeddings(len(tokenizer))

    model.enable_input_require_grads()
    model.gradient_checkpointing_enable()
    model.train()

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
    )
    trainer.train()

    trainer.save_model(f"{output_dir}/saved_model_{args.epochs}_epochs")
    
if __name__ == "__main__":
    # Command-line entry point: parse the options, validate the combinations
    # that would otherwise fail deep inside model loading, then run main().
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, default='',
                        help="Model name or path used to load the tokenizer and, without --load_peft, the model.")
    parser.add_argument("--data_path", type=str, default='',
                        help="Dataset name; training data is read from data/<data_path>.jsonl.")
    parser.add_argument("--model_name", type=str, default='llama2-7B',
                        help="Label used when building the output directory name.")
    parser.add_argument("--lr", type=float, default=1e-4,
                        help="Learning rate.")
    parser.add_argument("--batch_size", type=int, default=8,
                        help="Per-device training batch size.")
    parser.add_argument("--gradient_accumulation", type=int, default=1,
                        help="Number of gradient accumulation steps.")
    # The next two flags carry no help text here; see main() for how (or
    # whether) each of them is consumed.
    parser.add_argument("--reph_times", type=int, default=1)
    parser.add_argument("--add_mcq", action="store_true",
                        help="Use the '<model_name>-<data_path>_mix' output directory.")
    parser.add_argument("--gradient_norm", type=float, default=1)
    parser.add_argument("--warmup_ratio", type=float, default=0.05,
                        help="Warmup ratio for the learning-rate schedule.")
    parser.add_argument("--weight_decay", type=float, default=0.01,
                        help="Weight decay.")
    # See main() for how (or whether) the seed is consumed.
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--epochs", type=int, default=1,
                        help="Number of training epochs.")
    parser.add_argument("--load_peft", action="store_true",
                        help="Load a PEFT checkpoint from --load_path and merge it before training.")
    parser.add_argument("--load_path", type=str,
                        help="Path of the PEFT checkpoint (required with --load_peft).")

    args = parser.parse_args()

    # Fail fast with a clear message instead of passing '' / None into
    # from_pretrained().
    if not args.model:
        parser.error("--model must be provided")
    if args.load_peft and not args.load_path:
        parser.error("--load_path is required when --load_peft is set")

    main(args)