import os
from dataclasses import dataclass, field, asdict
from typing import Optional
import warnings
# Ignore every FutureWarning process-wide.
# NOTE(review): this filter is installed before the third-party imports below,
# presumably so warnings raised while importing them are silenced too - confirm
# before reordering the import block.
warnings.filterwarnings("ignore", category=FutureWarning)
import logging
# Configure the root logger: INFO level, "<timestamp> - <level> - <message>" lines.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
from datasets import load_dataset, concatenate_datasets, DatasetDict
import transformers
import trl
from transformers.data.data_collator import DataCollatorMixin
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, List
from dataclasses import asdict
import torch


@dataclass
class SoFTConfig:
    """Extra command-line options for the SoFT training script.

    ``train()`` parses this dataclass with ``HfArgumentParser`` and copies every
    field onto the ``trl.SFTConfig`` instance with ``setattr``. Within this file
    only ``bmsft_phase_train`` (forwarded to the trainer) and
    ``qwen_liger_version`` (stored on ``model.config``) are read explicitly; the
    remaining fields are consumed elsewhere, so their descriptions below are
    hedged.
    """

    # Presumably the number of SoT tokens (the "N" of the method) - TODO confirm
    # against the trainer; not read in this file.
    N_num_sot_tokens: int = field(default=4)
    # Presumably how many leading tokens take part in the matching phase (the
    # "L" of the method) - TODO confirm; not read in this file.
    L_first_matching_tokens: int = field(default=500)
    # Per-epoch update-step counts; empty by default. Presumably precomputed
    # offline and consumed by the trainer - TODO confirm.
    precomputed_num_update_steps_per_epoch_list: List[int] = field(default_factory=list)
    # Original author's note: 1 is the safe value, guaranteeing at least one
    # full gradient-accumulation cycle. Exact semantics live in the trainer.
    instantaneous_pad_batch_size: int = field(default=1)
    # Forwarded to the trainer as its ``bmsft_phase_train`` keyword argument.
    bmsft_phase_train: bool = field(default=True)
    # Stored on ``model.config.qwen_liger_version``; ``None`` means "not set",
    # hence ``Optional[str]`` rather than plain ``str``.
    qwen_liger_version: Optional[str] = field(default=None)
    # Debug switch, enabled by default; its effect is defined outside this file.
    debug_randomized_matching: bool = field(default=True)

@dataclass
class TrainingConfig:
    """Script-level training options parsed from the command line.

    Creating an instance has a side effect: ``__post_init__`` exports the
    Weights & Biases project and entity through the ``WANDB_PROJECT`` and
    ``WANDB_ENTITY`` environment variables.
    """

    # Model identifier passed to ``from_pretrained`` for both model and tokenizer.
    model_name: str = field(default="Qwen/Qwen2.5-32B-Instruct")
    # Copied to ``args.max_seq_length`` by ``train()``.
    block_size: int = field(default=32768)
    # Exported as WANDB_PROJECT unless ``None``.
    wandb_project: Optional[str] = field(default="s1")
    # Exported as WANDB_ENTITY unless ``None`` (the empty default is exported as-is).
    wandb_entity: Optional[str] = field(default="")
    # Passed to ``datasets.load_dataset`` by ``train()``.
    train_file_path: Optional[str] = field(default='')
    # Not read anywhere in this file.
    dagger: bool = field(default=False)

    def __post_init__(self) -> None:
        """Export the W&B settings to the environment.

        Both fields are ``Optional[str]``, and ``os.environ`` rejects ``None``
        with a ``TypeError``. A ``None`` value is therefore skipped, leaving any
        value already present in the environment untouched.
        """
        exports = (
            ("WANDB_PROJECT", self.wandb_project),
            ("WANDB_ENTITY", self.wandb_entity),
        )
        for env_name, value in exports:
            if value is not None:
                os.environ[env_name] = value



@dataclass
class DataCollatorForScalableSoTLanguageModelingPadding(DataCollatorMixin):
    """Collate examples into nested lists of unpadded 1-D tensors.

    Despite the class name, ``torch_call`` does no padding or concatenation:
    ``pad_token_id`` and ``pad_to_multiple_of`` are accepted but never read by
    it. Every output value is a list (one entry per example) of lists (one
    entry per sequence of that example) of ``torch.long`` tensors.
    """

    pad_token_id: int
    # When False, the two ``*position_ids`` keys are left out of the output.
    return_position_ids: bool = True
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"

    def torch_call(self, examples: list[Union[list[int], Any, dict[str, Any]]]) -> dict[str, Any]:
        """Build the full-sequence and "matching" prefix views of each example.

        Each example is indexed with the keys ``"input_ids"`` (an iterable of
        token-id sequences), ``"len_pre_skip"``, ``"len_assistant_skip"`` and
        ``"len_in_matching"`` (one prefix length per target).

        Per sequence this produces:

        * the full tensors - ids, an all-ones attention mask, positions
          ``0..len-1``, and labels whose first ``len_assistant_skip`` entries
          are ``-100``;
        * the matching tensors - the first ``len_in_matching[i % num_targets]``
          entries of the ids, mask and positions (slices of the full tensors,
          not copies) plus a cloned label tensor whose first ``len_pre_skip``
          entries are ``-100``.

        NOTE(review): ``-100`` is presumably the loss ignore-index, and the
        sample values left by the original author (``len_pre_skip=173`` with
        ``len_in_matching=1173`` for a 1000-token matching window) suggest the
        prefix length already includes ``len_pre_skip`` - confirm.

        Returns:
            A dict with ``input_ids``, ``attention_mask``, ``labels`` and their
            ``matching_*`` counterparts, plus ``position_ids`` and
            ``matching_position_ids`` when ``return_position_ids`` is true.
        """
        field_names = ("input_ids", "attention_mask", "labels", "position_ids")
        full_batch = {name: [] for name in field_names}
        prefix_batch = {name: [] for name in field_names}

        for example in examples:
            sequences = example["input_ids"]  # NUM_SOT_TOKENS * NUM_TARGETS sequences
            think_prefix_len = example["len_pre_skip"]
            assistant_prefix_len = example["len_assistant_skip"]
            # TODO add "assistant_tag_end_id", which is also just a scalar
            prefix_lengths = example["len_in_matching"]  # one entry per target
            n_targets = len(prefix_lengths)

            example_full = {name: [] for name in field_names}
            example_prefix = {name: [] for name in field_names}

            for seq_index, token_ids in enumerate(sequences):
                tokens = torch.tensor(token_ids, dtype=torch.long)

                # Completion-only labels: hide everything before the assistant turn.
                full_labels = tokens.clone()
                full_labels[:assistant_prefix_len] = -100
                ones_mask = torch.ones_like(tokens, dtype=torch.long)
                positions = torch.arange(len(tokens), dtype=torch.long)

                example_full["input_ids"].append(tokens)
                example_full["labels"].append(full_labels)
                example_full["attention_mask"].append(ones_mask)
                example_full["position_ids"].append(positions)

                # Sequences cycle through the targets, so the prefix length is
                # looked up modulo the number of targets.
                cut = prefix_lengths[seq_index % n_targets]
                prefix_tokens = tokens[:cut]
                prefix_labels = prefix_tokens.clone()
                prefix_labels[:think_prefix_len] = -100

                example_prefix["input_ids"].append(prefix_tokens)
                example_prefix["labels"].append(prefix_labels)
                example_prefix["attention_mask"].append(ones_mask[:cut])
                example_prefix["position_ids"].append(positions[:cut])

            for name in field_names:
                full_batch[name].append(example_full[name])
                prefix_batch[name].append(example_prefix[name])

        # No padding / concatenation here: the nested lists are returned as-is.
        batch = {
            "input_ids": full_batch["input_ids"],
            "attention_mask": full_batch["attention_mask"],
            "labels": full_batch["labels"],
            "matching_input_ids": prefix_batch["input_ids"],
            "matching_attention_mask": prefix_batch["attention_mask"],
            "matching_labels": prefix_batch["labels"],
        }
        if self.return_position_ids:
            batch["position_ids"] = full_batch["position_ids"]
            batch["matching_position_ids"] = prefix_batch["position_ids"]
        return batch



def train():
    """Run one fine-tuning job end to end.

    Steps, in order: parse ``TrainingConfig``, ``trl.SFTConfig`` and
    ``SoFTConfig`` with ``HfArgumentParser``; copy the SoFT options onto the
    trainer arguments; load the model, dataset and tokenizer; pick a pad token
    by model family; build the collator and the trainer; train; save the model
    and tokenizer to ``args.output_dir``; finally wait on the accelerator
    barrier.

    NOTE(review): ``trl.ScalableSoFTTrainer`` is not defined in this file; it
    presumably comes from a patched ``trl`` installation - confirm.
    """
    arg_parser = transformers.HfArgumentParser((TrainingConfig, trl.SFTConfig, SoFTConfig))
    config, args, soft_args = arg_parser.parse_args_into_dataclasses()

    # Attach every SoFT option to the trainer arguments as a plain attribute.
    for option_name, option_value in asdict(soft_args).items():
        setattr(args, option_name, option_value)

    log_config = {**asdict(config), **asdict(args)}
    logging.info(f"Training config: {log_config}")

    # Model: only the 70B checkpoints get explicit loading options.
    # "low_cpu_mem_usage" is deliberately absent for 70B: under FSDP it is more
    # efficient to set "cpu_ram_efficient_loading": true in fsdp_config.json.
    load_options = {}
    if "70B" in config.model_name:
        load_options = {
            "device_map": "auto",
            "torch_dtype": "auto",
            "attn_implementation": "flash_attention_2",
            "use_cache": False,
        }
    model = transformers.AutoModelForCausalLM.from_pretrained(config.model_name, **load_options)

    dataset = load_dataset(config.train_file_path)

    tokenizer = transformers.AutoTokenizer.from_pretrained(config.model_name, use_fast=True)
    # Pad with a token the model family never uses; the first matching family
    # wins, and other model names keep the tokenizer's own pad token.
    pad_token_by_family = (
        ("Llama", "<|reserved_special_token_5|>"),
        ("Qwen", "<|fim_pad|>"),
    )
    for family, family_pad_token in pad_token_by_family:
        if family in config.model_name:
            tokenizer.pad_token = family_pad_token
            break

    # Label masking (-100) is done by the custom collator rather than a stock
    # completion-only collator.
    collator = DataCollatorForScalableSoTLanguageModelingPadding(
        pad_token_id=tokenizer.convert_tokens_to_ids(tokenizer.pad_token),
        return_position_ids=True,
    )

    args.dataset_text_field = "text"
    args.max_seq_length = config.block_size
    model.config.qwen_liger_version = args.qwen_liger_version

    train_split = dataset["train"]
    # Fall back to the training split when the dataset ships no "test" split.
    eval_split = dataset["test"] if "test" in dataset else dataset["train"]
    trainer = trl.ScalableSoFTTrainer(
        model,
        args=args,
        data_collator=collator,
        train_dataset=train_split,
        eval_dataset=eval_split,
        bmsft_phase_train=args.bmsft_phase_train,
    )

    trainer.train()
    trainer.save_model(output_dir=args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    trainer.accelerator.wait_for_everyone()


# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    train()


