import math
import os
from typing import Any, Dict, Literal, Optional, Union

import datasets
import numpy as np
import torch
import torch.nn.functional as F
from argparse_dataclass import ArgumentParser
from sklearn.model_selection import train_test_split
from torch import nn
from torch.utils.data import Subset
from transformers import (DataCollatorForSeq2Seq, Seq2SeqTrainer,
                          Seq2SeqTrainingArguments)
from transformers.loss.loss_utils import ForCausalLMLoss
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs

import wandb
from hip_attn.v1_3.attention import HiPAttentionArgs, ScanStage
from hip_research.dataset.openwebtext import OpenWebTextDataset
from hip_research.dataset.pg19 import PG19Dataset
from hip_research.dataset.pg19_long_qa import PG19LongQA
from hip_research.main.jobs.ppl import PplArgs, job_ppl
from hip_research.utils.long_train import Config, get_logger
from torch.utils.checkpoint import checkpoint

# Module-wide logger used by the trainer and the setup helpers below.
log = get_logger()

# Environment switches, read once at import time.
# TRITON_DEBUG=1 turns on the NaN/Inf assertions inside the loss code.
TRITON_DEBUG = os.environ.get("TRITON_DEBUG", "0") == "1"
# True only when WANDB_MODE is exactly "disabled" (unset counts as enabled).
WANDB_DISABLED = os.environ.get("WANDB_MODE", "none") == "disabled"


class Trainer(Seq2SeqTrainer):
    """Seq2SeqTrainer with a block-wise causal-LM loss and custom long-context evaluation.

    The loss helpers take ``outputs.logits`` from the model forward and feed it
    through ``model.lm_head`` themselves, ``config.long_ce_block_size``
    positions at a time. The vocabulary logits for the whole sequence are
    therefore never materialised together. ``config.use_long_ce`` selects
    LongCE (https://arxiv.org/pdf/2410.23771) over plain cross entropy during
    training.
    """

    # Label id excluded from every loss and LongCE weight computed below.
    _IGNORE_INDEX = -100

    def __init__(
        self,
        config=None,
        model=None,
        args=None,
        data_collator=None,
        train_dataset=None,
        eval_dataset=None,
        tokenizer=None,
        model_init=None,
        compute_metrics=None,
        callbacks=None,
        optimizers=(None, None),
        preprocess_logits_for_metrics=None,
    ):
        """Forward the standard trainer arguments and keep the run ``config``.

        ``tokenizer`` must not be None: its ``pad_token_id`` is used to build
        the attention masks in the loss helpers.
        """
        log.info("trainer init")
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )

        self.model = model
        self.config = config
        self.pad_token_id = tokenizer.pad_token_id

    # ------------------------------------------------------------------ helpers

    def _shift_labels(self, target):
        """Return ``target`` shifted one step left on the last dim, padded with the ignore index.

        Position ``t`` is scored against token ``t + 1``, and the final
        position (which has no successor) is ignored. The whole sequence is
        shifted once, before it is split into blocks, so the label that crosses
        each block boundary is kept.
        """
        return F.pad(target[..., 1:], (0, 1), value=self._IGNORE_INDEX)

    def _lm_head_block(self, model, hidden, train):
        """Apply ``model.lm_head`` to one block of ``outputs.logits``.

        During training the projection is activation-checkpointed, so the
        block's logits are recomputed in backward instead of being kept alive.
        """
        if train:
            return checkpoint(model.lm_head, hidden, use_reentrant=False)
        return model.lm_head(hidden)

    def _long_ce_weights(self, long_logits, short_logits, labels, train):
        """Per-token LongCE (train) / LongPPL (eval) weights for one block.

        Intended to run under ``torch.no_grad()``. The result is a flat tensor
        aligned with ``labels.reshape(-1)``. Ignored positions get weight 0,
        and the weights of each row are normalised to sum to 1; rows with no
        selected token stay all-zero.
        """
        # The clamp turns the ignore index into a valid gather index; those
        # positions are zeroed by the mask further down.
        index = labels.unsqueeze(-1).clamp(min=0)

        # Upcast before log-softmax to avoid low-precision log-probabilities.
        long_logits = long_logits.float()
        short_logits = short_logits.float()
        long_logp = (
            long_logits.gather(-1, index)
            - torch.logsumexp(long_logits, dim=-1, keepdim=True)
        ).squeeze(-1)
        short_logp = (
            short_logits.gather(-1, index)
            - torch.logsumexp(short_logits, dim=-1, keepdim=True)
        ).squeeze(-1)

        if TRITON_DEBUG:
            assert torch.isfinite(long_logp).all()
            assert torch.isfinite(short_logp).all()

        # Likelihood ratio p(x | long) / p(x | short).
        ratio = torch.exp(long_logp - short_logp)
        if train:
            # LongCE: every token contributes, weighted by the ratio capped at gamma.
            weights = ratio.clamp(max=self.config.long_ce_gamma)
        else:
            # LongPPL: only tokens whose ratio exceeds alpha AND whose
            # long-context probability exceeds beta are counted.
            keep = (ratio > self.config.long_ppl_alpha) & (
                torch.exp(long_logp) > self.config.long_ppl_beta
            )
            weights = keep.to(ratio.dtype)

        if TRITON_DEBUG:
            assert not torch.any(torch.isnan(weights))

        weights = weights.masked_fill(labels == self._IGNORE_INDEX, 0)
        weights = weights / (weights.sum(dim=-1, keepdim=True) + 1e-20)
        if TRITON_DEBUG:
            assert not torch.any(torch.isnan(weights))
        return weights.reshape(-1)

    # ------------------------------------------------------------------- losses

    def batched_cross_entropy(self, inputs, target, model, train=False):
        """Token-mean next-token cross entropy, with logits computed one block at a time.

        Args:
            inputs: token ids indexed as ``(batch, seq)``.
            target: labels aligned with ``inputs`` (not yet shifted); ``-100``
                marks positions excluded from the loss.
            model: model whose forward ``.logits`` is fed into ``model.lm_head``.
            train: enables gradients and per-block checkpointing of ``lm_head``.

        Returns:
            Scalar tensor: summed CE over the supervised (non-ignored) shifted
            labels divided by their count, so padding does not dilute the mean.
        """
        with torch.autograd.grad_mode.set_grad_enabled(train):
            outputs = model(
                inputs,
                attention_mask=(inputs.ne(self.pad_token_id)).to(inputs.dtype),
                labels=target,
                use_cache=False,
                token_dropout_p=self.config.token_dropout_p,
            )
            hidden = outputs.logits
            block_size = self.config.long_ce_block_size

            shift_labels = self._shift_labels(target.to(hidden.device))
            num_supervised = shift_labels.ne(self._IGNORE_INDEX).sum().clamp(min=1)

            loss = hidden.new_zeros((), dtype=torch.float32)
            for b in range(0, inputs.size(1), block_size):
                logits = self._lm_head_block(
                    model, hidden[:, b : b + block_size], train
                ).float()
                block_labels = shift_labels[:, b : b + block_size].to(logits.device)
                loss = loss + F.cross_entropy(
                    logits.reshape(-1, logits.size(-1)),
                    block_labels.reshape(-1),
                    ignore_index=self._IGNORE_INDEX,
                    reduction="sum",
                )
            return loss / num_supervised.to(loss.device)

    def batched_long_cross_entropy(self, inputs, target, model, train=False):
        """
        long ce from https://arxiv.org/pdf/2410.23771 (equation 7 on page 6)

        Runs the model twice:
        - full context (with grad when ``train``);
        - ``sliding_window=config.long_ce_k``, always without grad and in eval
          mode.

        Each token's CE is weighted by ``_long_ce_weights`` (LongCE when
        training, the LongPPL indicator when evaluating). The weighted CE is
        summed and divided by ``batch * num_blocks``. Empty inputs give a zero
        loss.
        """
        block_size = self.config.long_ce_block_size
        attention_mask = inputs.ne(self.pad_token_id).to(inputs.dtype)
        record_time = os.environ.get("RECORD_TRAINING_TIME", "0") == "1"

        with torch.autograd.grad_mode.set_grad_enabled(train):
            if not train:
                model.eval()

            if record_time:
                start = torch.cuda.Event(enable_timing=True)
                end = torch.cuda.Event(enable_timing=True)
                start.record()

            outputs = model(
                inputs,
                attention_mask=attention_mask,
                labels=target,
                use_cache=False,
                token_dropout_p=self.config.token_dropout_p,
            )
            long_hidden = outputs.logits

            if record_time:
                end.record()
                torch.cuda.synchronize()
                print(f"fwd step time: {start.elapsed_time(end)=}")

        # Short-context reference pass: never differentiated. Training mode is
        # restored afterwards when we are training.
        with torch.no_grad():
            model.eval()
            short_outputs = model(
                inputs,
                attention_mask=attention_mask,
                sliding_window=self.config.long_ce_k,
                use_cache=False,
            )
            if train:
                model.train()
            short_hidden = short_outputs.logits

        batch_size = inputs.size(0)
        num_blocks = math.ceil(inputs.size(1) / block_size)
        if batch_size == 0 or num_blocks == 0:
            return long_hidden.new_zeros((), dtype=torch.float32)
        denom = batch_size * num_blocks

        shift_labels = self._shift_labels(target.to(long_hidden.device))

        with torch.autograd.grad_mode.set_grad_enabled(train):
            loss = long_hidden.new_zeros((), dtype=torch.float32)
            for b in range(0, inputs.size(1), block_size):
                long_logits = self._lm_head_block(
                    model, long_hidden[:, b : b + block_size], train
                )
                # Only used to build the (no-grad) weights; no graph needed.
                with torch.no_grad():
                    short_logits = model.lm_head(short_hidden[:, b : b + block_size])
                if TRITON_DEBUG:
                    assert not torch.any(torch.isnan(long_logits))
                    assert not torch.any(torch.isnan(short_logits))

                labels = shift_labels[:, b : b + block_size].to(long_logits.device)
                with torch.no_grad():
                    weights = self._long_ce_weights(
                        long_logits, short_logits, labels, train
                    )

                block_loss = F.cross_entropy(
                    long_logits.reshape(-1, long_logits.size(-1)).float(),
                    labels.reshape(-1),
                    ignore_index=self._IGNORE_INDEX,
                    reduction="none",
                )
                if TRITON_DEBUG:
                    assert not torch.any(torch.isnan(block_loss))

                loss = loss + (block_loss * weights).sum() / denom
            return loss

    def compute_loss(
        self,
        model,
        inputs,
        return_outputs=False,
        num_items_in_batch: Optional[int] = None,
    ):
        """Training loss on ``input_ids``/``labels`` truncated to ``config.seq_len``.

        ``num_items_in_batch`` is accepted for API compatibility but unused. The
        block-wise losses do not keep the model outputs, so with
        ``return_outputs=True`` the second element is ``None``.
        """
        model.train()

        input_ids = inputs["input_ids"][..., : self.config.seq_len]
        target = inputs["labels"][..., : self.config.seq_len]

        if self.config.use_long_ce:
            loss = self.batched_long_cross_entropy(input_ids, target, model, train=True)
        else:
            loss = self.batched_cross_entropy(input_ids, target, model, train=True)

        if TRITON_DEBUG:
            assert not torch.any(torch.isnan(loss))

        return (loss, None) if return_outputs else loss

    # --------------------------------------------------------------- evaluation

    def evaluate(
        self,
        eval_dataset=None,
        ignore_keys=None,
        metric_key_prefix: str = "eval",
    ) -> Dict[str, float]:
        """Evaluate and log metrics to wandb when a run is active.

        When ``config.dataset == "pg19"``, perplexity is computed with
        ``job_ppl`` at several strides. Otherwise LongPPL and PPL are averaged
        over ``eval_dataset`` (default ``self.eval_dataset``). The model's
        ``hip_train_config`` attribute is removed for the duration and is always
        restored, even if evaluation raises.
        """
        _mdl = self.model.base_model.model if self.config.use_lora else self.model

        hip_train_config = _mdl.hip_train_config
        delattr(_mdl, "hip_train_config")
        try:
            if self.config.dataset == "pg19":
                metrics = self._evaluate_pg19_ppl(metric_key_prefix)
            else:
                # NOTE(review): no_lm_head is never cleared after eval — confirm intended.
                setattr(_mdl, "no_lm_head", True)
                metrics = self._evaluate_long_ppl(eval_dataset, metric_key_prefix)
        finally:
            setattr(_mdl, "hip_train_config", hip_train_config)

        if wandb.run is not None:
            wandb.log(metrics, step=self.state.global_step)
        return metrics

    def _evaluate_pg19_ppl(self, metric_key_prefix):
        """Perplexity from ``job_ppl`` on PG19 at strides 32K/64K/128K, with ``count = 2**20 // stride``."""
        self.model.eval()
        total_tokens = 1024 * 1024
        metrics = {}
        for stride in (32768, 65536, 131072):
            metrics[f"{metric_key_prefix}/{stride}/ppl"] = job_ppl(
                args=PplArgs(
                    stride=stride,
                    dataset="pg19",
                    overwrite=True,
                    name="long_trainer",
                    count=total_tokens // stride,
                ),
                model=self.model,
                tokenizer=self.tokenizer,
                device=self.model.device,
                quite=False,
                quite_step=True,
            )
        return metrics

    def _evaluate_long_ppl(self, eval_dataset, metric_key_prefix):
        """Average LongPPL and PPL over a dataset, one example at a time.

        Each example's ``input_ids``/``labels`` tensors are truncated, then
        padded, to ``config.seq_len`` and moved to CUDA. Raises ValueError if
        the dataset yields no examples.
        """
        self.model.eval()
        model = self.model
        ds = self.eval_dataset if eval_dataset is None else eval_dataset
        seq_len = self.config.seq_len

        long_ppl_key = f"{metric_key_prefix}/long-ppl"
        ppl_key = f"{metric_key_prefix}/ppl"
        metrics = {long_ppl_key: 0, ppl_key: 0}
        n_examples = 0

        with torch.no_grad():
            for example in ds:
                n_examples += 1
                inputs = example["input_ids"].unsqueeze(0)[..., :seq_len].cuda()
                target = example["labels"].unsqueeze(0)[..., :seq_len].cuda()
                print(f"eval: {inputs.size()=}")

                inputs = F.pad(
                    inputs, (0, seq_len - inputs.size(1)), value=self.pad_token_id
                )
                target = F.pad(
                    target, (0, seq_len - target.size(1)), value=self._IGNORE_INDEX
                )

                # Long CE
                loss = self.batched_long_cross_entropy(inputs, target, model, train=False)
                metrics[long_ppl_key] += loss.item()

                # Normal CE
                loss = self.batched_cross_entropy(inputs, target, model, train=False)
                metrics[ppl_key] += loss.item()

        if n_examples == 0:
            raise ValueError("evaluation dataset yielded no examples")

        # Convert mean CE into perplexity.
        metrics[long_ppl_key] = np.exp(metrics[long_ppl_key] / n_examples)
        metrics[ppl_key] = np.exp(metrics[ppl_key] / n_examples)
        return metrics

    # ----------------------------------------------------------------- training

    def training_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        num_items_in_batch=None,
    ) -> torch.Tensor:
        """One forward/backward step, followed by per-parameter gradient-norm checks.

        Returns the detached loss, divided by ``gradient_accumulation_steps``
        when the model does not take loss kwargs and no ``compute_loss_func`` is
        set. Raises NotImplementedError under apex and FloatingPointError on a
        non-finite gradient norm.
        """
        model.train()
        if callable(getattr(self.optimizer, "train", None)):
            self.optimizer.train()

        inputs = self._prepare_inputs(inputs)

        with self.compute_loss_context_manager():
            loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)

        del inputs

        if self.args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training

        if self.use_apex:
            raise NotImplementedError("apex mixed precision is not supported by this Trainer")

        # Finally we need to normalize the loss for reporting
        if not self.model_accepts_loss_kwargs and self.compute_loss_func is None:
            loss = loss / self.args.gradient_accumulation_steps

        record_time = os.environ.get("RECORD_TRAINING_TIME", "0") == "1"
        if record_time:
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()

        self.accelerator.backward(loss)

        if record_time:
            end.record()
            torch.cuda.synchronize()
            print(f"bwd time: {start.elapsed_time(end)=}")

        self._check_and_log_grad_norms(model)
        return loss.detach()

    def _check_and_log_grad_norms(self, model):
        """Check every trainable gradient's L2 norm is finite, then print or wandb-log the norms.

        Each ``.item()`` forces a device sync, so this is a per-step cost.
        """
        norms = {}
        for name, param in model.named_parameters():
            if not param.requires_grad or param.grad is None:
                continue
            norm = param.grad.detach().norm().item()
            if not math.isfinite(norm):
                raise FloatingPointError(f"non-finite gradient norm for {name}: {norm}")
            norms[f"grad_norm/{name}"] = norm
            if WANDB_DISABLED:
                print(name, norm)
        # wandb.log raises unless wandb.init() has run; guard like evaluate() does.
        if not WANDB_DISABLED and wandb.run is not None:
            wandb.log(norms, step=self.state.global_step)


def get_hf_dataset(ds):
    """Expose an indexable dataset of ``(inputs, targets)`` pairs as a HF ``IterableDataset``.

    Items are produced lazily in index order as ``{"input_ids": ..., "labels": ...}``.
    """

    def _examples():
        position = 0
        total = len(ds)
        while position < total:
            input_ids, labels = ds[position]
            yield {"input_ids": input_ids, "labels": labels}
            position += 1

    return datasets.IterableDataset.from_generator(_examples)


def parse_args():
    """Parse command-line flags into a ``Config``, log it, and return it."""
    config = ArgumentParser(Config).parse_args()
    log.info(config)
    return config


def get_hip_config(config: Config, layer_idx: int):
    """Build the ``HiPAttentionArgs`` for attention layer ``layer_idx``.

    The preset is selected by the ``PRESET`` environment variable; only
    ``"default"`` exists, and any other value raises NotImplementedError.
    ``config.disable_hip_tune`` turns off ``enable_hip_tune``.
    """
    preset_name = os.environ.get("PRESET", "default")
    if preset_name != "default":
        raise NotImplementedError(f"{preset_name=} not implemented")

    # (stage_chunk_size, stage_k) for each successive scan stage; the other
    # stage parameters are shared.
    stage_specs = ((256, None), (32, 32768), (8, 8192))
    stages = [
        ScanStage(
            stage_block_size_q=64,
            stage_block_stride_q=4,
            stage_chunk_size=chunk_size,
            stage_k=top_k,
            stage_stride=1,
        )
        for chunk_size, top_k in stage_specs
    ]

    # The first three layers use the "streaming" scan-extend backend; the rest use "relative".
    scan_backend = "streaming" if layer_idx < 3 else "relative"

    return HiPAttentionArgs(
        sliding_window_size=1024,
        sink_token_size=256,
        using_extend=True,
        need_apply_rope=True,
        second_stage_k=2048,
        stages=stages,
        model_context_length=65536,
        scan_extend_backend=scan_backend,
        sa_extend_backend="streaming",
        block_sparse_block_size_q=stages[-1].stage_block_size_q,
        enable_hip_tune=not config.disable_hip_tune,
    )


def init_model(config: Config):
    """Load the HiP-attention Llama model and tokenizer for ``config.model``.

    Every ``LlamaAttention`` module gets the run config, its per-layer
    ``HiPAttentionArgs`` (from ``get_hip_config``) and ``config.method``. When
    ``config.use_lora`` is set and ``config.lora_r > 0``, the model is wrapped
    with PEFT LoRA on q/k/v/o projections. Each PEFT LoRA ``Linear`` is then
    swapped for ``CustomLoraLinear`` around the same base layer.

    Returns:
        ``(model, tokenizer)``.
    """
    import transformers
    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

    from hip_attn.v1_3.models.llama import (LlamaAttention, LlamaConfig,
                                            LlamaForCausalLM)

    # Register the "hip_attention" name so transformers accepts
    # attn_implementation="hip_attention" below.
    ALL_ATTENTION_FUNCTIONS.update({"hip_attention": (lambda x: x)})

    tokenizer = transformers.AutoTokenizer.from_pretrained(config.model)
    model_config = LlamaConfig.from_pretrained(
        config.model,
        attn_implementation="hip_attention",
    )
    model_config.pooler_method = config.pooler_method
    model_config.pooler_config = config.pooler_config
    model = LlamaForCausalLM.from_pretrained(config.model, config=model_config)

    model.hip_train_config = config

    num_layers = len(model.model.layers)
    attention_modules = [m for m in model.modules() if isinstance(m, LlamaAttention)]
    for layer_idx, attn in enumerate(attention_modules):
        attn.args = config
        attn.hip_attn_args = get_hip_config(config, layer_idx)
        attn.attention_method = config.method
        # Fresh list per layer, as before, so layers never share one mutable list.
        attn.tree_dense_layers = list(range(num_layers))

    if (config.lora_r > 0) and config.use_lora:
        from peft import LoraConfig, TaskType, get_peft_model
        from peft.tuners.lora.layer import Linear

        from hip_attn.v1_3.models.linear import CustomLoraLinear

        lora_alpha = config.lora_r // 2
        lora_dropout = 0.15

        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=config.lora_r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            target_modules=[
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
            ],
            modules_to_save=[
                "token_pooler_fn",
                "output_unpooler_fn",
                "gate_prob_estimater",
                "input_layernorm",
                "post_attention_layernorm",
            ],
        )
        model = get_peft_model(model, peft_config)

        # Collect the targets before replacing anything, so the module tree is not
        # mutated while named_modules() walks it. get_submodule resolves each parent
        # directly instead of rebuilding a name->module dict per replacement.
        lora_linears = [
            (name, module)
            for name, module in model.named_modules()
            if isinstance(module, Linear)
        ]
        for name, module in lora_linears:
            parent_name, _, attr_name = name.rpartition(".")
            parent_module = model.get_submodule(parent_name) if parent_name else model
            setattr(
                parent_module,
                attr_name,
                CustomLoraLinear(
                    base_layer=module.base_layer,
                    adapter_name="default",
                    r=config.lora_r,
                    lora_alpha=lora_alpha,
                    lora_dropout=lora_dropout,
                ),
            )

        model.print_trainable_parameters()

    return model, tokenizer


def main(config: Config):
    """Train the model described by ``config``, or with ``DO_EVAL=1`` only evaluate it.

    Environment variables read here:
        DETECT_ANOMALY=1: enable torch autograd anomaly detection.
        CUDA_LAUNCH_BLOCKING=1: only logs a slowdown warning.
        DO_EVAL=1: skip training and evaluate the test split. Only the
            ``pg19-longqa`` dataset has one; any other dataset raises ValueError.
        PG19_LONGQA_PATH: data directory for ``pg19-longqa``. It defaults to
            the previously hard-coded path.

    Also sets ``WANDB_PROJECT``. It mutates ``config.model_checkpoint_dir`` by
    appending a run-specific sub-directory.
    """
    if os.environ.get("DETECT_ANOMALY", "0") == "1":
        torch.autograd.set_detect_anomaly(True)

    # torch.set_float32_matmul_precision('high')
    # NOTE(review): run() calls seed(), which sets cudnn.benchmark=False for
    # determinism; this line turns it back on. Confirm which is intended.
    torch.backends.cudnn.benchmark = True

    os.environ["WANDB_PROJECT"] = "quick_extend"

    if os.environ.get("CUDA_LAUNCH_BLOCKING", "0") == "1":
        log.info(
            "WARNING: CUDA_LAUNCH_BLOCKING is set to 1, this will slow down the training.",
        )

    filename = f'{config.model.replace("/", "_")}-{config.name}-{config.dataset}-{config.seq_len}-{config.get_hash()}'
    config.model_checkpoint_dir = os.path.join(config.model_checkpoint_dir, filename)

    model, tokenizer = init_model(config=config)
    tokenizer.pad_token = tokenizer.eos_token

    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=config.lr,
        weight_decay=config.weight_decay,
    )

    def warmup(current_step: int):
        # Linear warmup to 1.0, then cosine decay to 0 at config.max_steps.
        if current_step < config.warmup_steps:
            return float(current_step / config.warmup_steps)
        tick = max(
            0.0,
            float(config.max_steps - current_step)
            / float(max(1, config.max_steps - config.warmup_steps)),
        )
        return (math.cos((1 - tick) * math.pi) + 1) / 2

    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup)

    def train_val_split(dataset):
        train_idx, val_idx = train_test_split(
            list(range(len(dataset))), test_size=config.val_split
        )
        return Subset(dataset, train_idx), Subset(dataset, val_idx)

    test_data = None
    if config.dataset == "owt":
        dataset = OpenWebTextDataset(
            tokenizer=tokenizer, stride=config.seq_len, offset_labels=False
        )
        train_data, valid_data = train_val_split(dataset)
    elif config.dataset == "pg19":
        dataset = PG19Dataset(
            tokenizer=tokenizer, stride=config.seq_len, offset_labels=False
        )
        train_data, valid_data = train_val_split(dataset)
    elif config.dataset == "pg19-longqa":
        path = os.environ.get("PG19_LONGQA_PATH", "/home/anno_0/pg19-hierarchical-qa/")
        train_data = PG19LongQA(tokenizer, path, split="train")
        valid_data = PG19LongQA(tokenizer, path, split="validation")
        test_data = get_hf_dataset(PG19LongQA(tokenizer, path, split="test"))
    else:
        raise ValueError(f"Unknown dataset {config.dataset}")

    trainer_config = Seq2SeqTrainingArguments(
        logging_steps=1,
        bf16=True,
        bf16_full_eval=True,
        output_dir=config.model_checkpoint_dir,
        gradient_accumulation_steps=config.accumulation_steps,
        max_steps=config.max_steps,
        gradient_checkpointing=True,
        save_total_limit=config.save_total_limit,
        save_steps=config.save_steps,
        per_device_train_batch_size=config.batch_size,
        per_device_eval_batch_size=config.batch_size,
        learning_rate=config.lr,
        ignore_data_skip=True,
        warmup_steps=config.warmup_steps,
        eval_strategy="steps",
        eval_steps=config.eval_steps,
        run_name=config.run_name,
    )

    trainer = Trainer(
        config=config,
        model=model,
        args=trainer_config,
        data_collator=DataCollatorForSeq2Seq(
            tokenizer=tokenizer,
            padding="longest",
            pad_to_multiple_of=config.seq_len,
        ),
        train_dataset=get_hf_dataset(train_data),
        eval_dataset=get_hf_dataset(valid_data),
        tokenizer=tokenizer,
        optimizers=(optimizer, lr_scheduler),
    )

    if os.environ.get("DO_EVAL", "0") == "1":
        if test_data is None:
            raise ValueError(
                f"DO_EVAL=1 needs a dataset with a test split; {config.dataset!r} has none"
            )
        trainer.model = trainer.model.cuda()
        # Evaluate the HF-wrapped test split: Trainer.evaluate indexes examples by
        # "input_ids"/"labels", which the raw (inputs, targets) datasets do not support.
        metrics = trainer.evaluate(test_data, metric_key_prefix="test")
        print(f"test metrics: {metrics=}")
    else:
        trainer.train(resume_from_checkpoint=config.checkpoint)
        if test_data is not None:
            metrics = trainer.evaluate(test_data, metric_key_prefix="test")
            print(f"test metrics: {metrics=}")

def run():
    """CLI entry point: seed all RNGs, parse the Config from argv, then run ``main``."""
    seed()
    config = parse_args()
    main(config)


def seed(seed=42):
    """Seed Python, NumPy and torch (CPU and every CUDA device) RNGs and force deterministic cuDNN.

    ``PYTHONHASHSEED`` is exported too. Note that it only affects interpreters
    started afterwards; hash randomisation of the running process is fixed at
    startup.
    """
    import random

    os.environ["PYTHONHASHSEED"] = str(seed)

    random.seed(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if use multi-GPU

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


if __name__ == "__main__":
    run()
