# Training script for InsNet insertion language models (Hugging Face Trainer, optional DeepSpeed).
import os.path

# Press Shift+F10 to execute it or replace it with your code.
# Press Double Shift to search everywhere for classes, files, tool windows, actions, and settings.

from modeling_insnet_deepagg import InsNetForInsertionLM, insnet_configs
from torch.utils.data import Dataset
from torch.utils.data import SequentialSampler
import nltk
import torch
from IPython import embed
from transformers import Trainer, TrainingArguments, TrainerCallback, default_data_collator, PreTrainedTokenizerFast
import deepspeed

deepspeed.runtime.lr_schedules
import numpy as np
from torch.utils.data import Sampler

import numpy as np
import torch
from torch.utils.data import Sampler, DistributedSampler


class NumPyPCGSampler(Sampler):
    """Random sampler that shuffles dataset indices with NumPy's PCG64DXSM generator.

    A single ``np.random.Generator`` is created in ``__init__`` and advanced by
    every call to ``__iter__``, so each pass draws a new permutation while the
    whole sequence of passes is reproducible for a fixed ``seed``.

    Args:
        data_source: Sized collection to sample from; only its length is used.
        seed: Seed passed to ``np.random.PCG64DXSM``.
    """

    def __init__(self, data_source, seed=None):
        super().__init__(data_source)
        self.data_source = data_source
        self.seed = seed
        bit_generator = np.random.PCG64DXSM(seed=seed)
        self.rng = np.random.Generator(bit_generator)

    def __len__(self):
        return len(self.data_source)

    def __iter__(self):
        n_items = len(self.data_source)
        shuffled = self.rng.permutation(n_items)
        return iter(shuffled.tolist())

import math
class DistributedSamplerPCG64DXSM(DistributedSampler):
    """``DistributedSampler`` that draws its shuffle from NumPy's PCG64DXSM generator.

    Padding/truncation to ``total_size`` and the rank-strided split are the same
    as in the parent ``__iter__``; only the permutation source differs.
    """

    def __iter__(self):
        n_items = len(self.dataset)  # type: ignore[arg-type]

        if not self.shuffle:
            order = list(range(n_items))
        else:
            # Seeded with seed + epoch: reproducible, and different per epoch.
            generator = np.random.Generator(np.random.PCG64DXSM(self.seed + self.epoch))
            order = generator.permutation(n_items).tolist()

        if self.drop_last:
            # Discard the tail so the length divides evenly across replicas.
            order = order[: self.total_size]
        else:
            # Re-use leading indices (cycling if needed) to reach total_size.
            missing = self.total_size - len(order)
            if missing > len(order):
                order = order + (order * math.ceil(missing / len(order)))[:missing]
            else:
                order = order + order[:missing]
        assert len(order) == self.total_size

        # Every num_replicas-th index, starting at this rank.
        shard = order[self.rank : self.total_size : self.num_replicas]
        assert len(shard) == self.num_samples

        return iter(shard)

def str2bool(st):
    """Parse a command-line boolean value for use as an argparse ``type=``.

    Case-insensitively (surrounding whitespace ignored) accepts
    ``true/t/yes/y/on/1`` as True and ``false/f/no/n/off/0`` as False.
    A ``bool`` argument is returned unchanged.

    Raises:
        argparse.ArgumentTypeError: for any other value, so argparse reports a
            usage error instead of silently guessing (the previous rule treated
            any string containing a "t" as True, so "1" or "yes" became False).
    """
    if isinstance(st, bool):
        return st
    value = str(st).strip().lower()
    if value in ("true", "t", "yes", "y", "on", "1"):
        return True
    if value in ("false", "f", "no", "n", "off", "0"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (st,))
from prepare_multisent_data import RandomPermutationDocDatasetWithBPE

class RandomPermutationDatasetWithBPE(Dataset):
    """Map-style dataset yielding BPE token ids plus a segment-level generation order.

    Each item's ``'text'`` is tokenized with a leading space and split into
    contiguous segments. A segment starts at the first token and at every token
    whose decoded text is not purely alphabetic, so purely alphabetic tokens are
    attached to the segment before them. ``permutation_ids`` lists token
    positions segment by segment. The segments are in their original order when
    ``autoregressive`` is set, and in a random order (``np.random.permutation``)
    otherwise.

    Args:
        tokenizer: Object providing ``tokenize``, ``convert_tokens_to_string``,
            ``convert_tokens_to_ids`` and ``pad_token_id``.
        hf_dataset: Indexable collection whose items have a ``'text'`` field.
        max_len: Length of the returned tensors; longer texts are truncated.
        autoregressive: Keep segments left-to-right instead of shuffling them.
    """

    def __init__(self, tokenizer, hf_dataset, max_len=2048, autoregressive=False):
        self.data = hf_dataset
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.autoregressive = autoregressive

    def __len__(self):
        return len(self.data)

    def _segment_bounds(self, tokens):
        """Return ``(start, end)`` pairs (end exclusive) that partition ``tokens``."""
        # Token 0 always opens a segment so that leading alphabetic tokens are
        # not dropped from the permutation (they previously were).
        starts = [
            token_i
            for token_i, token in enumerate(tokens)
            if token_i == 0
            or not self.tokenizer.convert_tokens_to_string([token]).isalpha()
        ]
        ends = starts[1:] + [len(tokens)]
        return list(zip(starts, ends))

    def __getitem__(self, item):
        """Return ``{"input_ids", "permutation_ids"}`` as ``(max_len,)`` long tensors.

        ``input_ids`` is right-padded with ``tokenizer.pad_token_id``;
        ``permutation_ids`` holds token positions in generation order and is
        right-padded with ``-1``.
        """
        target = self.data[item]['text']
        # Truncate so the ids always fit the fixed-size output tensors
        # (longer texts previously raised a shape-mismatch error).
        tokens = self.tokenizer.tokenize(" " + target)[: self.max_len]
        segments = self._segment_bounds(tokens)

        if self.autoregressive:
            segment_order = range(len(segments))
        else:
            segment_order = np.random.permutation(len(segments))
        permutation_ids_clean = [
            position
            for segment_i in segment_order
            for position in range(*segments[segment_i])
        ]

        input_ids_clean = self.tokenizer.convert_tokens_to_ids(tokens)
        input_ids = torch.full((self.max_len,), self.tokenizer.pad_token_id, dtype=torch.long)
        permutation_ids = torch.full((self.max_len,), -1, dtype=torch.long)
        input_ids[: len(input_ids_clean)] = torch.tensor(input_ids_clean, dtype=torch.long)
        permutation_ids[: len(permutation_ids_clean)] = torch.tensor(permutation_ids_clean, dtype=torch.long)

        return {
            "input_ids": input_ids,
            "permutation_ids": permutation_ids,
        }


import datasets
import argparse

def _has_checkpoint(output_dir):
    """Return True if ``output_dir`` contains a ``checkpoint-<step>`` subdirectory."""
    if not os.path.isdir(output_dir):
        return False
    prefix = "checkpoint-"
    return any(
        name.startswith(prefix)
        and name[len(prefix):].isdigit()
        and os.path.isdir(os.path.join(output_dir, name))
        for name in os.listdir(output_dir)
    )


def main():
    """Parse CLI arguments, build the dataset and model, and run training.

    Loads the dataset shards under ``/data2/slimpajama-20250201``, wraps them in
    ``RandomPermutationDocDatasetWithBPE``, and loads the warm-up checkpoint
    ``./checkpoints/insnet-v2-<size>-<factorization>-warmup/checkpoint-160000``.
    Training then continues with ``CustomTrainer``, writing checkpoints and a
    plain-text ``log`` file to ``checkpoints/insnet-v2-<size>-<factorization>``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--batch_size",
        default=2016,
        type=int,
        required=False,
        help="Target sequences per optimizer step (used to derive gradient accumulation steps).",
    )
    parser.add_argument(
        "--model_size",
        default="small",
        type=str,
        required=False,
        help="Model size; selects insnet_configs['insnet-<size>'] and the checkpoint paths.",
    )
    parser.add_argument(
        "--mini_batch_size",
        default=18,
        type=int,
        required=False,
        help="Per-device micro-batch size.",
    )
    parser.add_argument(
        "--local_rank",
        default=-1,
        type=int,
        required=False,
        help="Local process rank from the distributed launcher (falls back to $LOCAL_RANK).",
    )
    parser.add_argument(
        "--factorization",
        default="insertion",
        type=str,
        required=False,
        help="Generation order: 'autoregressive' or 'insertion'; also used in checkpoint paths.",
    )
    parser.add_argument(
        "--deepspeed",
        default=None,
        type=str,
        required=False,
        help="Path to a DeepSpeed config file.",
    )
    parser.add_argument(
        "--resumed",
        default=True,
        type=str2bool,
        required=False,
        help="Resume from the latest checkpoint in the output directory if one exists.",
    )

    args = parser.parse_args()

    # Some launchers (e.g. torchrun) export LOCAL_RANK instead of passing --local_rank.
    if args.local_rank == -1:
        args.local_rank = int(os.environ.get("LOCAL_RANK", -1))
    # Fail fast, before the expensive dataset and model loading below.
    if args.local_rank == -1:
        print("Error: no node id specified")
        raise SystemExit(1)

    shard_dirs = [
        "/data2/slimpajama-20250201/%d" % start
        for start in range(0, 5 * 10000000 + 1, 10000000)
    ]
    raw_dataset = datasets.concatenate_datasets([datasets.load_from_disk(d) for d in shard_dirs])

    checkpoint_name = "insnet-v2-%s" % args.model_size
    output_dir = "checkpoints/" + checkpoint_name + "-" + args.factorization

    tokenizer = PreTrainedTokenizerFast(tokenizer_file="insnet_tokenizer.json")
    tokenizer.pad_token = "<|endoftext|>"
    tokenizer.eos_token = "<|endoftext|>"
    tokenizer.bos_token = "<|endoftext|>"

    dataset = RandomPermutationDocDatasetWithBPE(
        tokenizer=tokenizer,
        hf_dataset=raw_dataset,
        max_len=2048,
        autoregressive=(args.factorization == "autoregressive"),
        bidirectional_rate=0.25,
    )
    dataset.vanilla_permutation = True
    dataset.shuffle_sentence = False

    # Raises KeyError for an unknown --model_size.
    config = insnet_configs["insnet-%s" % args.model_size]

    # The previous version first built a randomly initialised model from `config`
    # and then immediately replaced it with this one; that wasted construction
    # has been removed.
    ckpt_step = 160000
    model = InsNetForInsertionLM.from_pretrained(
        "./checkpoints/insnet-v2-%s-%s-warmup/checkpoint-%d" % (args.model_size, args.factorization, ckpt_step)
    )

    # NOTE(review): `config` is not passed to `from_pretrained` above, so these
    # assignments only mutate the shared insnet_configs entry and do not affect
    # `model`. Confirm whether they were meant for `model.config`.
    config.ffn_residual = "preln"
    config.position_encoding = "alibi"
    config.hidden_dropout = 0.0

    model.config.pad_token = "<|endoftext|>"
    model.config.eos_token = "<|endoftext|>"
    model.config.bos_token = "<|endoftext|>"

    class CustomTrainer(Trainer):
        """Trainer whose loss keeps only the first ``batch_maxlen`` generation steps per example."""

        @property
        def tokenizer(self):
            # Alias `tokenizer` to `processing_class`, the attribute newer
            # transformers versions use for it.
            return self.processing_class

        @tokenizer.setter
        def tokenizer(self, processing_class) -> None:
            self.processing_class = processing_class

        def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
            """Run the model on a truncated, re-indexed batch and return the loss.

            Only the first ``batch_maxlen`` steps of each ``permutation_ids`` row
            are kept. The tokens at those positions are gathered in left-to-right
            order, and ``permutation_ids`` is rewritten as indices into that
            shortened sequence.

            The returned loss has the value of the mean model loss, but its
            gradient is divided by ``batch_maxlen + 2``. When ``return_outputs``
            is true, ``(loss, model_output)`` is returned, as Trainer expects.
            """
            input_ids = inputs['input_ids']
            permutation_ids = inputs['permutation_ids']

            # Earlier runs grew this during training (126 -> 254 -> 510 -> 1022).
            batch_maxlen = 1022
            # Padding steps (-1) are pointed at the last input index so they sort to the end.
            pad_position = input_ids.size(-1) - 1

            permutation_ids_trunc = torch.where(
                permutation_ids != -1, permutation_ids, pad_position
            )[:, 0:batch_maxlen]

            selected_ids = permutation_ids_trunc.sort(dim=-1)
            input_ids_trunc = input_ids.gather(dim=-1, index=selected_ids.values)
            # For each kept step: rank of its position among the kept positions,
            # i.e. its index into `input_ids_trunc`.
            permutation_ids_trunc = selected_ids.indices.argsort(dim=-1)
            # NOTE(review): ranks are at most batch_maxlen - 1 (< pad_position), so
            # this never matches and padding steps keep a rank instead of -1.
            # If -1 was intended, build the mask from the pre-sort values, but
            # first verify the model accepts -1 here.
            permutation_ids_trunc = torch.where(
                permutation_ids_trunc != pad_position,
                permutation_ids_trunc,
                -1,
            )

            term_ids = inputs['term_ids'][:, 0:batch_maxlen].contiguous()
            bidirectional = inputs['bidirectional'][:, 0:batch_maxlen].contiguous()

            model_output = model(
                input_ids=input_ids_trunc.contiguous(),
                permutation_ids=permutation_ids_trunc.contiguous(),
                term_ids=term_ids,
                bidirectional=bidirectional,
            )

            if self.state.is_local_process_zero:
                with torch.no_grad():
                    latest = {
                        "term_loss": model_output.termination_loss.mean(dim=0).cpu().detach().to(torch.float32),
                        "pos_loss": model_output.position_loss.sum(dim=-1).mean(dim=0).cpu().item(),
                        "tok_loss": model_output.token_loss.sum(dim=-1).mean(dim=0).cpu().item(),
                    }
                    previous = getattr(self.state, "most_recent_output", None)
                    if previous is not None and all(key in previous for key in latest):
                        # Exponential moving average with decay 0.99, read by LogCallback.
                        self.state.most_recent_output = {
                            key: 0.01 * latest[key] + 0.99 * previous[key] for key in latest
                        }
                    else:
                        self.state.most_recent_output = latest

            loss = model_output.loss.mean()
            # Same value as `loss` (so logged losses are unscaled), gradient
            # scaled by 1 / (batch_maxlen + 2).
            scaled_loss = loss.detach() + (loss - loss.detach()) / (batch_maxlen + 2.0)
            if return_outputs:
                return scaled_loss, model_output
            return scaled_loss

    class LogCallback(TrainerCallback):
        """Append compact smoothed-loss records to ``<output_dir>/log`` on the local main process."""

        def on_epoch_begin(self, args: TrainingArguments, state, control, **kwargs):
            # Truncate the log file only near the very start of training.
            if state.is_local_process_zero and state.epoch < 0.01:
                with open(os.path.join(args.output_dir, "log"), "w"):
                    pass

        def on_step_end(self, args: TrainingArguments, state, control, **kwargs):
            # Keep state.log_history empty so it does not grow over the run.
            state.log_history = []

        def on_log(self, args: TrainingArguments, state, control, logs: dict = None,
                   **kwargs):
            """Rewrite ``logs`` in place with smoothed losses and append it to the log file.

            Records missing ``epoch``, ``loss`` or ``learning_rate`` are left
            untouched. Otherwise ``logs`` is cleared. It is then refilled with
            epoch/step/LR, the optional grad norm, and the smoothed
            ``term_loss``/``pos_loss``/``tok_loss``. When the smoothed
            position/token loss is NaN, it is left empty.
            """
            state.log_history = []
            if logs is None:
                return
            if isinstance(logs.get('grad_norm'), torch.Tensor):
                logs['grad_norm'] = logs['grad_norm'].item()

            if not state.is_local_process_zero:
                return
            if not all(key in logs for key in ("epoch", "loss", "learning_rate")):
                return

            epoch = logs["epoch"]
            learning_rate = logs["learning_rate"]
            grad_norm = logs.get("grad_norm")
            logs.clear()
            try:
                recent = state.most_recent_output
                if np.isnan(recent["pos_loss"]).item() or np.isnan(recent["tok_loss"]).item():
                    print("Warning: NaN detected, skipping logging this batch.")
                else:
                    record = {
                        "epoch": epoch,
                        "step": state.global_step,
                        "LR": learning_rate,
                    }
                    if grad_norm is not None:
                        record["grad_norm"] = int(grad_norm * 1000.) / 1000.
                    # Truncate (not round) to 5 decimal places.
                    record["term_loss"] = (recent["term_loss"] * 100000.).to(torch.long).to(torch.float32) / 100000.
                    record["pos_loss"] = int(recent["pos_loss"] * 100000.) / 100000.
                    record["tok_loss"] = int(recent["tok_loss"] * 100000.) / 100000.
                    logs.update(record)
            except (ValueError, OverflowError):
                # int() raises on a NaN (ValueError) or infinite (OverflowError) value.
                print(state.most_recent_output)

            if logs:
                with open(os.path.join(args.output_dir, "log"), "a") as log_file:
                    print(logs, file=log_file)

    torch._dynamo.config.cache_size_limit = 16

    # NOTE(review): torch.cuda.device_count() counts GPUs on this node only, so on
    # multi-node runs the effective global batch is roughly batch_size * num_nodes.
    gradient_accumulation_steps = (
        args.batch_size // torch.cuda.device_count() + args.mini_batch_size - 1
    ) // args.mini_batch_size

    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        per_device_train_batch_size=args.mini_batch_size,
        per_device_eval_batch_size=args.mini_batch_size,
        auto_find_batch_size=False,
        eval_strategy="no",
        do_train=True,
        do_eval=False,
        adam_beta1=0.9,
        gradient_accumulation_steps=gradient_accumulation_steps,
        adam_beta2=0.9,
        adam_epsilon=1e-6,
        learning_rate=1e-4,
        logging_strategy="steps",
        logging_steps=10,
        save_strategy="steps",
        save_steps=1000,
        warmup_steps=1000,
        weight_decay=1e-5,
        max_grad_norm=1.,
        num_train_epochs=1,
        local_rank=args.local_rank,
        label_names=["input_ids",
                     "permutation_ids",
                     "term_ids",
                     "bidirectional"
                     ],
        tf32=True,
        bf16=True,
        dataloader_pin_memory=True,
        dataloader_num_workers=8,
        dataloader_prefetch_factor=32,
        dataloader_persistent_workers=True,
        deepspeed=args.deepspeed,
    )

    trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        tokenizer=tokenizer,
        data_collator=default_data_collator,
        callbacks=[LogCallback()],
    )

    # Honour --resumed; previously it was ignored, and only checkpoint-1000 or
    # checkpoint-10000 were recognised as resumable.
    resume = bool(args.resumed) and _has_checkpoint(output_dir)
    trainer.train(resume_from_checkpoint=resume)

    # Debugging aid: open an interactive IPython shell after training.
    # NOTE(review): there is no rank guard, so this runs on every process.
    embed()
    exit()


if __name__ == '__main__':
    main()

# See PyCharm help at https://www.jetbrains.com/help/pycharm/
