import pickle

from datasets import load_dataset,Dataset
from trl import SFTConfig, SFTTrainer
from transformers import Trainer

import json
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainerCallback, TrainerState, TrainerControl
from datasets import DatasetDict, Dataset
from tqdm import tqdm
from trl import SFTConfig, SFTTrainer
import re
import argparse
import warnings
from accelerate import Accelerator
from accelerate.utils import gather_object
import os
import torch.distributed as dist
from datetime import timedelta
import time
from transformers import DataCollatorWithPadding

# Silence every warning category for the whole process (same filter as
# warnings.filterwarnings("ignore")).
# NOTE(review): this also hides deprecation notices from transformers/trl.
warnings.simplefilter("ignore")



class CustomTrainer(Trainer):
    """``Trainer`` whose next-token loss can be limited to selected tokens.

    Batches may carry a ``loss_mask`` tensor alongside the model inputs (as
    produced by ``CustomDataCollator``). It is removed before the forward pass
    and used to weight the per-token cross-entropy, so only positions where
    the mask is non-zero are supervised. Labels are the batch's own
    ``input_ids``.
    """

    def compute_loss(self, model, inputs, num_items_in_batch=None, return_outputs=False):
        """Return the (masked) causal-LM cross-entropy for one batch.

        Args:
            model: causal language model; called as ``model(**inputs)`` and
                its output is read via ``.get("logits")``.
            inputs: mapping with ``input_ids`` and optionally
                ``attention_mask`` / ``loss_mask``. It is not modified.
            num_items_in_batch: accepted for compatibility with ``Trainer``
                versions that pass it; unused. The loss is averaged over this
                batch's supervised tokens.
            return_outputs: if True, return ``(loss, outputs)``, which is the
                form ``Trainer.prediction_step`` requests during evaluation.

        Returns:
            Scalar loss tensor, or ``(loss, outputs)`` when ``return_outputs``.
        """
        # Shallow copy so removing loss_mask does not mutate the caller's batch.
        inputs = dict(inputs)
        loss_mask = inputs.pop("loss_mask", None)
        if loss_mask is None:
            # No explicit supervision mask: at least keep padding out of the loss.
            loss_mask = inputs.get("attention_mask")
        labels = inputs["input_ids"]

        outputs = model(**inputs)
        # Upcast so the softmax over the vocabulary is computed in float32.
        logits = outputs.get("logits").float()

        # The logit at position t predicts the token at position t + 1.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous().long()
        flat_logits = shift_logits.view(-1, shift_logits.size(-1))
        flat_labels = shift_labels.view(-1)

        if loss_mask is None:
            loss = torch.nn.CrossEntropyLoss()(flat_logits, flat_labels)
        else:
            # Shift the mask exactly like the labels so the two stay aligned.
            shift_loss_mask = loss_mask[..., 1:].contiguous().float()
            per_token = torch.nn.CrossEntropyLoss(reduction="none")(flat_logits, flat_labels)
            per_token = per_token.view(shift_labels.size()) * shift_loss_mask
            # The epsilon turns a batch with no supervised tokens into 0 instead of NaN.
            loss = per_token.sum() / (shift_loss_mask.sum() + 1e-5)

        return (loss, outputs) if return_outputs else loss




class CustomDataCollator:
    """Left-pad tokenized samples and build the masks ``CustomTrainer`` consumes.

    Each feature must contain ``input_ids`` (a sequence of token ids). The
    returned batch holds three ``(batch, max_len)`` tensors:

    * ``input_ids`` (long): left-padded with ``tokenizer.pad_token_id``, or
      ``tokenizer.eos_token_id`` when the tokenizer defines no pad token.
    * ``attention_mask`` (long): 1 on real tokens, 0 on padding.
    * ``loss_mask`` (float): 1 from the first occurrence of
      ``loss_start_token`` through the end of the sequence, 0 elsewhere
      (always 0 on padding). A sample without the marker gets an all-zero
      mask and so contributes no loss. An empty marker supervises every
      real token.

    Sequences longer than ``max_length`` keep only their first
    ``max_length`` tokens (``None`` disables the limit).
    """

    def __init__(self, tokenizer, loss_start_token="Thought:", max_length=8192):
        self.tokenizer = tokenizer
        self.loss_start_token = loss_start_token
        self.max_length = max_length

    def _pad_token_id(self):
        """Return the padding id: the tokenizer's pad token, else its eos token."""
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = self.tokenizer.eos_token_id
        if pad_id is None:
            raise ValueError("tokenizer defines neither pad_token_id nor eos_token_id")
        return pad_id

    def _marker_ids(self):
        """Return the token ids of ``loss_start_token`` as a 1-D long tensor."""
        tokens = self.tokenizer.tokenize(self.loss_start_token)
        return torch.tensor(self.tokenizer.convert_tokens_to_ids(tokens), dtype=torch.long)

    @staticmethod
    def _find_marker(ids, marker):
        """Return the index of the first occurrence of ``marker`` in ``ids``, or None."""
        width = marker.numel()
        if width == 0:
            # Empty marker: supervise the whole (unpadded) sequence.
            return 0
        if width > ids.numel():
            # unfold() raises when the window is longer than the sequence.
            return None
        windows = ids.unfold(0, width, 1)
        hits = (windows == marker).all(dim=-1).nonzero(as_tuple=True)[0]
        return hits[0].item() if hits.numel() > 0 else None

    def __call__(self, features):
        # Truncate first so max_length also bounds the padded batch width.
        sequences = [
            torch.tensor(feature["input_ids"], dtype=torch.long)[: self.max_length]
            for feature in features
        ]
        max_len = max(len(seq) for seq in sequences)
        pad_id = self._pad_token_id()
        marker = self._marker_ids()

        padded_input_ids, loss_masks, attention_masks = [], [], []
        for seq in sequences:
            pad_len = max_len - len(seq)
            # Left padding: the real tokens are right-aligned.
            padded_input_ids.append(
                torch.cat([torch.full((pad_len,), pad_id, dtype=torch.long), seq])
            )

            attention_mask = torch.zeros(max_len, dtype=torch.long)
            attention_mask[pad_len:] = 1

            loss_mask = torch.zeros(max_len, dtype=torch.float)
            # Search only the real tokens so padding can never match the marker.
            start = self._find_marker(seq, marker)
            if start is not None:
                loss_mask[pad_len + start:] = 1

            loss_masks.append(loss_mask)
            attention_masks.append(attention_mask)

        return {
            "input_ids": torch.stack(padded_input_ids),
            "loss_mask": torch.stack(loss_masks),
            "attention_mask": torch.stack(attention_masks),
        }



def preprocess_data(data, tokenizer, max_seq_length):
    """Render each conversation with the tokenizer's chat template and tokenize it.

    Args:
        data: iterable of samples; each sample's ``'messages'`` entry is
            passed to ``tokenizer.apply_chat_template``.
        tokenizer: Hugging Face tokenizer that owns the chat template.
        max_seq_length: maximum number of tokens kept per sample; longer
            samples are truncated by the tokenizer.

    Returns:
        ``datasets.Dataset`` with a single ``input_ids`` column holding one
        list of token ids per sample.
    """
    input_ids_list = []
    for sample in tqdm(data, desc="tokenizing"):
        text = tokenizer.apply_chat_template(sample['messages'], tokenize=False)
        # Chat templates render the model's special tokens themselves; letting
        # the tokenizer add its own (e.g. a BOS) on top would duplicate them.
        encoded = tokenizer(
            text,
            truncation=True,
            max_length=max_seq_length,
            add_special_tokens=False,
        )
        input_ids_list.append(list(encoded['input_ids']))

    df = pd.DataFrame({'input_ids': input_ids_list})
    return Dataset.from_pandas(df)



# ==========================
# Process-wide environment. NCCL reads these variables when the process group
# is created, so they must be set before dist.init_process_group() below.
# TOKENIZERS_PARALLELISM is set before the tokenizer is first used.
# NOTE(review): PyTorch honours only one of BLOCKING_WAIT / ASYNC_ERROR_HANDLING
# when both are enabled — confirm which behaviour is intended.
os.environ['TORCH_NCCL_BLOCKING_WAIT'] = '1'
os.environ['TORCH_NCCL_ASYNC_ERROR_HANDLING'] = '1'
os.environ['TOKENIZERS_PARALLELISM'] = "False"

# ==========================
# Experiment configuration
max_seq_length = 8192
model_path = '/share_data/data1/models/Qwen/Qwen2.5-7B-Instruct'
version = 2
experiment_name = f'WorkBench_SFT_v{version}'
print(experiment_name)
output_path = f'metaflow_neurips/output/{experiment_name}'

SFT_data_path = f'./input/SFT_data_v{version}.pkl'

# Assumed to be a trusted, locally produced file: unpickling can execute
# arbitrary code, so never point this at external data.
with open(SFT_data_path, 'rb') as fp:
    dataset = pickle.load(fp)

tokenizer = AutoTokenizer.from_pretrained(model_path)
# Each launched process tokenizes the full dataset; truncation to
# max_seq_length happens here.
dataset = preprocess_data(dataset, tokenizer, max_seq_length)

# =========================
use_flash_attention = True

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    use_cache=False,  # the KV cache is not needed for training
    attn_implementation="flash_attention_2" if use_flash_attention else "sdpa",
)

# ==========================
# Explicit init so the 6-hour collective timeout applies; skipped when running
# under PyCharm (PYCHARM_HOSTED=1), e.g. for single-process debugging.
if os.getenv('PYCHARM_HOSTED') != '1':
    dist.init_process_group(backend='nccl', timeout=timedelta(hours=6))

# NOTE(review): this Accelerator is never referenced below (Trainer creates its
# own); kept so the process-wide accelerate state is initialised as before.
accelerator = Accelerator(mixed_precision='bf16')

# SFTConfig serves as the TrainingArguments for the plain Trainer subclass
# below. NOTE(review): SFT-only options such as max_seq_length are presumably
# not consumed by it; truncation is done in preprocess_data.
training_args = SFTConfig(
    max_seq_length=max_seq_length,
    output_dir=output_path,
    learning_rate=1e-5,
    num_train_epochs=5,
    bf16=True,
    logging_steps=1,
    gradient_checkpointing=True,
    save_strategy='epoch',
    per_device_train_batch_size=1,
    gradient_accumulation_steps=2,
    max_grad_norm=0.3,
)

# The collator supervises every token from the first assistant header onwards.
loss_start_token = "<|im_start|>assistant"
data_collator = CustomDataCollator(tokenizer=tokenizer, loss_start_token=loss_start_token, max_length=max_seq_length)

trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    # NOTE(review): recent transformers deprecate `tokenizer=` in favour of
    # `processing_class=`; switch once the pinned version supports it.
    tokenizer=tokenizer,
    data_collator=data_collator,  # dynamic (left) padding + loss masks
)

trainer.train()