import os
import datetime
import warnings
import json
import random
import wandb

from dataclasses import dataclass, field
from functools import partial
from typing import Optional

import torch
import deepspeed
import transformers
# from torch.utils.data import Dataset
from transformers import Trainer, DataCollatorForLanguageModeling
from transformers.models.qwen3 import Qwen3ForCausalLM, Qwen3Config
from transformers.models.llama import LlamaForCausalLM, LlamaConfig
from transformers.trainer_utils import get_last_checkpoint
from deepspeed.accelerator import get_accelerator
from torch.distributed import barrier
from dataset import load_dataset

from rnsa.qwen3 import RNSAQwen3ForCausalLM, RNSAQwen3Config
from rnsa.qwen2 import RNSAQwen2ForCausalLM, RNSAQwen2Config
from rnsa.llama import RNSALlamaForCausalLM, RNSALlamaConfig
from rnsa.phi3 import RNSAPhi3ForCausalLM, RNSAPhi3Config

# Suppress every FutureWarning for the whole process to keep training logs quiet.
warnings.simplefilter("ignore", FutureWarning)


def ds_param_count(model, trainable_only=False):
    """Count the elements in ``model``'s parameters, DeepSpeed-aware.

    A parameter's ``ds_numel`` attribute is preferred over ``numel()`` when
    present, so partitioned parameters report their full size.

    Args:
        model: Object exposing ``parameters()``.
        trainable_only: When True, parameters with ``requires_grad`` False are skipped.

    Returns:
        Total number of elements (int).
    """
    total = 0
    for param in model.parameters():
        if not param.requires_grad and trainable_only:
            continue
        total += getattr(param, "ds_numel", param.numel())
    return total

def truncate(s, max_length=127):
    """Sanitize ``s`` for use as a W&B artifact name and cap its length.

    Every character that is not an ASCII letter, an ASCII digit, ``_`` or ``.``
    is replaced with ``_``. This covers ``-``, ``/`` and ``,`` as well as spaces,
    colons and other characters that artifact names reject.

    If the sanitized string is longer than ``max_length``, it is cut and
    ``"..."`` appended so the result is exactly ``max_length`` characters. When
    ``max_length < 3`` there is no room for the ellipsis, and a plain cut is used.

    Args:
        s: Input string, e.g. a run name.
        max_length: Maximum length of the returned string.

    Returns:
        The sanitized, possibly truncated string.
    """
    s = "".join(
        ch if (ch.isascii() and ch.isalnum()) or ch in "_." else "_"
        for ch in s
    )
    if len(s) <= max_length:
        return s
    if max_length < 3:
        # A negative slice plus "..." would exceed max_length; hard-cut instead.
        return s[:max(max_length, 0)]
    return s[:max_length - 3] + "..."


class RNSATrainer(Trainer):
    """``transformers.Trainer`` subclass used for RNSA training.

    Adds an optional no-grad "vanilla" forward pass for logit-distillation
    losses. It also logs the loss components returned by the model on every
    ``compute_loss`` call.
    """

    def __init__(
            self,
            *args,
            base_loss='ntp',
            **kwargs
        ):
        """Create the trainer.

        Args:
            *args: Positional arguments forwarded unchanged to ``Trainer``.
            base_loss: Loss mode name. If it contains ``'logits_distil'``, a
                vanilla forward pass is run before each training forward.
                It is keyword-only so that positional ``Trainer`` arguments
                (``model``, ``args``, ...) are not captured by it.
            **kwargs: Forwarded to ``Trainer``, except ``tokenizer``. That one
                is removed and stored on ``self._tokenizer``.
        """
        self.base_loss = base_loss
        # NOTE(review): removed from the kwargs, so Trainer itself never
        # receives the tokenizer.
        self._tokenizer = kwargs.pop("tokenizer", None)
        super().__init__(*args, **kwargs)

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Run the forward pass, log the loss components, and return the loss.

        For ``*logits_distil*`` losses, first runs ``model(..., vanilla_forward=True)``
        without gradients. Its logits are passed to the main forward as
        ``inputs["base_logits"]``.

        Returns:
            ``loss``, or ``(loss, outputs)`` when ``return_outputs`` is True.

        Raises:
            ValueError: if the model output contains no ``loss`` (presumably
                because ``inputs`` has no labels).
        """
        if 'logits_distil' in self.base_loss:
            with torch.no_grad():
                ori_outputs = model(**inputs, vanilla_forward=True)
            inputs["base_logits"] = ori_outputs.logits
            # Only the logits are needed from here on. Drop the rest of the
            # vanilla outputs before returning cached blocks to the allocator.
            # get_accelerator().empty_cache() already covers torch.cuda.
            del ori_outputs
            get_accelerator().empty_cache()

        outputs = model(**inputs)

        loss = outputs.get("loss")
        if loss is None:
            raise ValueError("Model output has no 'loss'; make sure the inputs contain labels.")

        logs = {"total_loss": loss.item()}
        # Optional loss components exposed on RNSA model outputs.
        for component in ("forget_loss", "base_loss"):
            value = getattr(outputs, component, None)
            if value is not None:
                logs[component] = value.item()
        self.log(logs)

        return (loss, outputs) if return_outputs else loss


@dataclass
class ModelArguments:
    """Model-side command-line options, parsed by ``transformers.HfArgumentParser``.

    Every field except ``model_name_or_path`` is copied onto the model config
    by ``update_config``.
    """

    # Also used as the model/tokenizer path when model_name_or_path is omitted,
    # and as the lookup key into chat_template/templates.json.
    base_model: str = field(
        default="meta-llama/Meta-Llama-3.1-8B",
        metadata={"help": "The base model name or path. Used when model_name_or_path is not set."},
    )
    model_name_or_path: Optional[str] = field(default=None, metadata={"help": "The local model path if any."})
    forget_gate: str = field(
        default="fg4",
        metadata={"help": "The forget gate implementation to use, e.g. 'fg4' (default), 'fg3', 'fg2', 'fg1'."},
    )
    base_loss: str = field(
        default="ntp",
        metadata={"help": "The base loss to use. Options: 'ntp', 'fw_logits_distil', 'rv_logits_distil'."},
    )
    attn_impl: str = field(
        default="rnsa_flex",
        metadata={"help": "The attention implementation to use, e.g. 'rnsa_flex' (default), 'memeffi', 'rnsa_memeffi'."},
    )
    # Values < 1 are interpreted as a fraction of the sequence length (per the help text).
    memory_size: float = field(
        default=1024,
        metadata={"help": "The memory size of the model. If < 1, it represents the fraction of the sequence length."},
    )
    forget_weight: float = field(
        default=1.0,
        metadata={"help": "The forget weight of the model."},
    )
    forget_gate_bias_init: float = field(
        default=0.0,
        metadata={"help": "The forget gate bias init of the model."},
    )
    use_cache: bool = field(
        default=False,
        metadata={"help": "Whether to use cache in the model."},
    )
    skip_layers: Optional[int] = field(
        default=0,
        metadata={"help": "Number of layers to skip for compression. If set to 0, no layers are skipped."},
    )
    fg_dropout: float = field(
        default=0.0,
        metadata={"help": "Dropout rate for forget gate."},
    )


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """``transformers.TrainingArguments`` extended with RNSA and dataset options.

    ``optim``, ``resume_from_checkpoint``, ``overwrite_output_dir`` and
    ``gradient_checkpointing`` redeclare upstream fields with the defaults and
    help shown here.
    """
    # Pinned explicitly rather than relying on the upstream default.
    optim: str = field(default="adamw_torch")
    # Copied to config.max_seq_len when not None (update_config), and passed as
    # the tokenizer's model_max_length.
    # NOTE(review): annotated `int` but defaults to None; Optional[int] would be more accurate.
    training_max_length: int = field(
        default=None,
        metadata={"help": "Maximum sequence length in training."},
    )
    # '|'-separated substrings. Any parameter whose name contains one of them is
    # unfrozen; the same list selects which weights are written to rnsa_weights.pth.
    trainable_params: str = field(
        default="self_attn.f_proj|self_attn.forget_gate|self_attn.v_proj|self_attn.q_proj|self_attn.k_proj",
        metadata={"help": "compressor trainable parameters."},
    )
    # Either a checkpoint directory, "auto" to pick the latest checkpoint in
    # output_dir, or the literal string "None" (treated as no checkpoint).
    resume_from_checkpoint: Optional[str] = field(
        default=None,
        metadata={
            "help": "Path to a checkpoint to resume training from.",
        },
    )
    # Dataset options; presumably consumed by dataset.load_dataset, which
    # receives the whole TrainingArguments object — confirm there.
    dataset_name: Optional[str] = field(
        default="./cache_dir/qwq-sftopenr1-newtemplate",
        metadata={"help": "The name of the dataset to use."},
    )
    # -1 presumably means "no truncation" — TODO confirm in load_dataset.
    max_samples: Optional[int] = field(
        default=-1,
        metadata={"help": "For debugging purposes, truncate the number of samples."},
    )
    dataset_path: Optional[str] = field(
        default=None,
        metadata={"help": "The path to the dataset if any."},
    )
    # True: save the full model with trainer.save_model.
    # False: save only the trainable weights to rnsa_weights.pth.
    save_entire_model: bool = field(
        default=False,
        metadata={"help": "Save entire model."},
    )
    # When False, train() exits early if rnsa_weights.pth already exists in output_dir.
    overwrite_output_dir: bool = field(
        default=False,
        metadata={"help": "Overwrite the output directory."},
    )
    gradient_checkpointing: bool = field(
        default=False,
        metadata={"help": "Whether to use gradient checkpointing."},
    )
    # Copied to config.logit_block_size. Its meaning is defined by the rnsa model classes.
    logit_block_size: int = field(
        default=-1,
        metadata={"help": "The block size for logit computation."},
    )


def update_config(config, model_args, training_args):
    """Copy the RNSA command-line options onto ``config`` and return it.

    Attributes are copied in a fixed order from ``model_args`` and
    ``training_args``. ``config.max_seq_len`` is overwritten only when
    ``training_args.training_max_length`` is set.

    Args:
        config: Model config object (mutated in place).
        model_args: Parsed ``ModelArguments``.
        training_args: Parsed ``TrainingArguments``.

    Returns:
        The same ``config`` object.
    """
    copied_attrs = (
        (model_args, "base_model"),
        (model_args, "forget_gate"),
        (model_args, "memory_size"),
        (model_args, "forget_weight"),
        (model_args, "forget_gate_bias_init"),
        (training_args, "trainable_params"),
        (model_args, "attn_impl"),
        (model_args, "use_cache"),
        (model_args, "base_loss"),
        (model_args, "skip_layers"),
        (model_args, "fg_dropout"),
        (training_args, "logit_block_size"),
    )
    for source, attr in copied_attrs:
        setattr(config, attr, getattr(source, attr))

    max_len = training_args.training_max_length
    if max_len is not None:
        config.max_seq_len = max_len
    return config


def _select_model_classes(model_name_or_path):
    """Return ``(model_cls, config_cls)`` for the RNSA architecture named in the path.

    Matching is a case-insensitive substring test. "qwen3" is checked before
    the generic "qwen" so Qwen3 checkpoints are not routed to the Qwen2 classes.

    Raises:
        ValueError: if no supported architecture name appears in the path.
    """
    name = model_name_or_path.lower()
    if "qwen3" in name:
        print("Using Qwen3 model")
        return RNSAQwen3ForCausalLM, RNSAQwen3Config
    if "qwen" in name:
        print("Using Qwen2 model")
        return RNSAQwen2ForCausalLM, RNSAQwen2Config
    if "llama" in name:
        print("Using Llama model")
        return RNSALlamaForCausalLM, RNSALlamaConfig
    if "phi-3" in name or "phi-4" in name:
        print("Using Phi model")
        return RNSAPhi3ForCausalLM, RNSAPhi3Config
    raise ValueError("Model not supported. Currently only qwen2, qwen3, llama, and phi-3/phi-4 models are supported.")


def _resolve_resume_checkpoint(training_args):
    """Normalize ``training_args.resume_from_checkpoint`` in place and return it.

    Accepted values:
    - None, or the string "None": start from scratch.
    - "auto": resume from the latest checkpoint in ``output_dir``, or start from
      scratch if there is none.
    - anything else: treated as a checkpoint directory that must exist.

    Raises:
        ValueError: if an explicit checkpoint path is not a directory.
    """
    resume = training_args.resume_from_checkpoint
    if resume == "None":
        resume = None

    if resume == "auto":
        output_dir = training_args.output_dir
        # get_last_checkpoint lists output_dir, so guard against it not existing yet.
        last_checkpoint = get_last_checkpoint(output_dir) if os.path.isdir(output_dir) else None
        if last_checkpoint is None:
            print(f"No checkpoint found in {output_dir}. Starting training from scratch.")
            resume = None
        else:
            print(f"Found checkpoint {last_checkpoint}. Resuming training.")
            resume = last_checkpoint

    if resume is not None:
        if not os.path.isdir(resume):
            raise ValueError(f"Checkpoint {resume} does not exist.")
        print(f"Resuming from checkpoint: {resume}")

    training_args.resume_from_checkpoint = resume
    return resume


def _save_trainable_weights(trainer, model, training_args):
    """Save the config plus only the trainable weights to ``rnsa_weights.pth``.

    Weights are selected by substring match against the '|'-separated
    ``model.config.trainable_params``. If a W&B run is active, the weights file
    and config.json are also logged as a W&B model artifact.
    """
    wrapped = trainer.model
    state_dict = (wrapped.module if hasattr(wrapped, "module") else wrapped).state_dict()

    model.config.save_pretrained(training_args.output_dir)
    patterns = model.config.trainable_params.split("|")
    attn_gate_state_dict = {
        k: v for k, v in state_dict.items() if any(pattern in k for pattern in patterns)
    }
    path = os.path.join(training_args.output_dir, "rnsa_weights.pth")
    torch.save(attn_gate_state_dict, path)
    print(f"Saved attention gate weights to {path}")

    # Upload only when a W&B run is active (e.g. not in disabled mode).
    if wandb.run is not None:
        # run_name may be None; fall back to output_dir, the upstream default for run names.
        artifact_name = truncate(training_args.run_name or training_args.output_dir)
        artifact = wandb.Artifact(name=artifact_name, type="model")
        artifact.add_file(path)
        artifact.add_file(os.path.join(training_args.output_dir, "config.json"))
        wandb.log_artifact(artifact)


def train():
    """Fine-tune the RNSA parameters of a causal LM with DeepSpeed and the HF Trainer.

    Steps:
    1. Parse the command-line arguments.
    2. Build the RNSA model and tokenizer.
    3. Freeze every parameter except those matched by ``--trainable_params``.
    4. Train.
    5. Save either the full model (``--save_entire_model``) or only the
       trainable weights to ``rnsa_weights.pth``.

    Training is skipped if ``rnsa_weights.pth`` already exists in the output
    directory and ``--overwrite_output_dir`` is not set.
    """
    deepspeed.init_distributed(dist_backend="nccl", init_method="env://", timeout=datetime.timedelta(minutes=120))

    parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
    model_args, training_args = parser.parse_args_into_dataclasses()

    # Compute the path once so the skip message names the file that was actually checked.
    weights_path = os.path.join(training_args.output_dir, "rnsa_weights.pth")
    if os.path.exists(weights_path) and not training_args.overwrite_output_dir:
        print(f"Attn gate weights already exist at {weights_path}, skip training.")
        return

    if model_args.model_name_or_path is None:
        model_args.model_name_or_path = model_args.base_model

    model_cls, config_cls = _select_model_classes(model_args.model_name_or_path)

    config = config_cls.from_pretrained(
        model_args.model_name_or_path,
    )
    config = update_config(config, model_args, training_args)

    model = model_cls.from_pretrained(
        model_args.model_name_or_path,
        load_rnsa_weights=False,
        config=config,
        torch_dtype=torch.bfloat16,
    )
    # HF helper that makes the input-embedding output require grad. This lets
    # gradients reach trainable layers while the embeddings stay frozen.
    model.enable_input_require_grads()
    print(model)
    print("Using model:", model_args.model_name_or_path)
    print("Config:", model.config)
    print("tokenier name:", model_args.model_name_or_path,
          "model name:", model_args.base_model, "training max length:", training_args.training_max_length)

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        padding_side="right",
        model_max_length=training_args.training_max_length,
        trust_remote_code=True,
        use_fast=True,
    )

    # Some chat templates remove the tokens between <think> and </think>. When a
    # modified template exists for this base model, use it so those tokens are trained on.
    with open("chat_template/templates.json", "r") as f:
        chat_template = json.load(f)
    if model_args.base_model in chat_template:
        tokenizer.chat_template = chat_template[model_args.base_model]

    total_num_params = ds_param_count(model, trainable_only=False)
    total_trainable_params = 0
    trainable_params = training_args.trainable_params.split("|")
    for n, p in model.named_parameters():
        # A parameter is trainable iff its name contains one of the configured substrings.
        p.requires_grad = any(pattern in n for pattern in trainable_params)
        if p.requires_grad:
            total_trainable_params += getattr(p, "ds_numel", p.numel())

    print(f"Total trainable parameters: {total_trainable_params} ({total_trainable_params / total_num_params * 100:.2f}%)")

    # -1 when RANK is not set (single-process run).
    rank = int(os.environ.get('RANK', -1))
    # Non-zero ranks block here until rank 0 has finished load_dataset (presumably
    # so rank 0 populates any on-disk dataset cache first); rank 0 releases them below.
    if rank > 0:
        barrier()

    dataset = load_dataset(
        training_args=training_args,
        tokenizer=tokenizer,
    )

    print("Dataset size:", len(dataset))

    if rank == 0:
        barrier()

    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    print("Output directory:", training_args.output_dir)

    trainer = RNSATrainer(
        model=model,
        base_loss=model_args.base_loss,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=dataset,
        eval_dataset=None,
        data_collator=data_collator,
    )

    _resolve_resume_checkpoint(training_args)

    trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)

    print("Saving model...")
    trainer.save_state()
    if training_args.save_entire_model:
        trainer.save_model(output_dir=training_args.output_dir)
    elif rank <= 0:
        # Rank 0 saves in distributed runs. Rank -1 (RANK unset) is the only
        # process of a non-distributed run, so it must save too.
        _save_trainable_weights(trainer, model, training_args)


if __name__ == "__main__":
    # Print tensors with high precision and without scientific notation.
    torch.set_printoptions(precision=10, sci_mode=False)
    # Seed the Python, torch and transformers RNGs so runs are reproducible.
    seed = 42
    for seed_fn in (random.seed, torch.manual_seed, transformers.set_seed):
        seed_fn(seed)
    train()
