import os
from dataclasses import dataclass
from typing import Dict, Union, Any, Optional, List, Tuple

import wandb
import datasets
import numpy as np
import torch
import torch.autograd
import torch.onnx
import torch.utils.checkpoint
from argparse_dataclass import ArgumentParser
from sklearn.model_selection import train_test_split
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Subset
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq

from quick_extend.dataset.openwebtext import OpenWebTextDataset
from quick_extend.dataset.booksum import BookSumDataset
from quick_extend.dataset.redpajama import RedPajamaDataset
from quick_extend.models.load_model import ModelConfig, load_model


@dataclass
class TrainConfig(ModelConfig):
    """Run configuration for this training script, extending ``ModelConfig``.

    An instance is produced from the command line by ``parse_args`` and is read
    by ``main``, ``get_deepspeed_config`` and ``Trainer``.
    """

    # Dataset selector used by main(): "owt" -> OpenWebTextDataset,
    # "bs" -> BookSumDataset, "rp" -> RedPajamaDataset; anything else raises ValueError.
    dataset: str = "owt"
    # Passed to Seq2SeqTrainingArguments as learning_rate.
    lr: float = 5e-5
    # Per-device train/eval batch size; also DeepSpeed's train_micro_batch_size_per_gpu.
    batch_size: int = 1
    # Gradient accumulation steps (trainer arguments and DeepSpeed config).
    accumulation_steps: int = 2
    # Passed to Seq2SeqTrainingArguments as save_steps.
    save_steps: int = 100
    # Trainer.compute_loss truncates input_ids/labels to this length; main() also
    # uses it as dataset stride / max_seq_len and as the collator's pad_to_multiple_of.
    seq_len: int = 4096
    # Passed to Seq2SeqTrainingArguments as max_steps.
    max_steps: int = 1000000
    # Base checkpoint directory; main() appends a run-specific sub-directory
    # and uses the result as the trainer's output_dir.
    model_checkpoint_dir: str = "./saves/checkpoints"
    # Forwarded to trainer.train(resume_from_checkpoint=...).
    # NOTE(review): default is None although the annotation is plain `str`.
    load_from_checkpoint: str = None
    # Passed to Seq2SeqTrainingArguments as warmup_steps.
    warmup_steps: int = 5
    # Free-form run name; becomes part of the checkpoint sub-directory name.
    name: str = 'default'
    # test_size given to train_test_split when carving out the validation subset.
    val_split: float = 0.01
    # When set, the model is called with use_efficient_ce=False in Trainer.compute_loss.
    disable_mem_efficient_ce: bool = False
    # When set, Trainer.compute_loss delegates to Trainer.tempering_step.
    enable_tempering: bool = False
    # DeepSpeed ZeRO stage (1, 2 or 3) chosen in get_deepspeed_config; other values raise ValueError.
    zero_stage: int = 1


def compute_a_loss(logits_1, log_pr1_topk, topk_indices, logits_2, ALPHA: float):
    """Return the KL term between ALPHA-scaled first-step logits and a two-step marginal.

    Steps, as performed below:

    1. Form joint log-probabilities ``log_softmax(logits_2) + log_pr1_topk``.
    2. Scale them by ``ALPHA`` and renormalise jointly over the last two axes.
    3. Sum the resulting probabilities over the last axis and scatter them into
       a zero tensor shaped like ``logits_1`` at ``topk_indices``.
    4. Scale ``logits_1`` by ``ALPHA``, renormalise over the last axis, and take
       ``F.kl_div`` (``batchmean``) against the scattered probabilities.

    Shapes, following the inline annotations (bsz = batch, len = sequence):
        logits_1:      (bsz, len, vocab_size)
        log_pr1_topk:  (bsz, len, TOP_K)
        topk_indices:  (bsz, len, TOP_K)
        logits_2:      (bsz, len, TOP_K, vocab_size)

    Returns:
        A scalar tensor.
    """
    vocab_size = logits_1.size(-1)

    # Joint log-probability of (first token, second token) pairs.
    joint_log_prob = torch.log_softmax(logits_2, dim=-1) + log_pr1_topk.unsqueeze(-1)

    # Temperature scaling followed by joint renormalisation over (TOP_K, vocab_size).
    scaled_joint = joint_log_prob * ALPHA
    scaled_joint = scaled_joint - torch.logsumexp(scaled_joint, dim=[-2, -1], keepdim=True)

    # Marginalise out the second token; these are probabilities, not log-probabilities.
    topk_marginal = scaled_joint.exp().sum(dim=-1)  # (bsz, len, TOP_K)
    target_prob = torch.zeros_like(logits_1).scatter(
        dim=-1, index=topk_indices, src=topk_marginal,
    )  # (bsz, len, vocab_size); zero outside the top-k entries

    # Log-probabilities of the temperature-scaled first-step distribution.
    scaled_logits = ALPHA * logits_1
    scaled_log_prob = scaled_logits - torch.logsumexp(scaled_logits, dim=-1, keepdim=True)

    return F.kl_div(
        input=scaled_log_prob.view(-1, vocab_size),
        target=target_prob.view(-1, vocab_size),
        log_target=False,
        reduction='batchmean',
    )


class Trainer(Seq2SeqTrainer):
    """``Seq2SeqTrainer`` subclass with a custom loss and optional tempering term.

    ``compute_loss`` truncates each batch to ``config.seq_len`` tokens, derives
    the attention mask from the pad token id and takes the loss returned by the
    model. When ``config.enable_tempering`` is set, ``tempering_step`` adds an
    auxiliary term computed by ``compute_a_loss``.
    """

    # Settings of the tempering objective used by ``tempering_step``.
    TEMPERING_TOP_K = 32    # candidates kept per position
    TEMPERING_ALPHA = 10    # scaling factor handed to compute_a_loss
    TEMPERING_COEFF = 0.1   # weight of the auxiliary term in the total loss
    # Sequence positions processed per checkpointed compute_a_loss call.
    TEMPERING_CHUNK_LEN = 256

    def __init__(
            self,
            config=None,
            model=None,
            args=None,
            data_collator=None,
            train_dataset=None,
            eval_dataset=None,
            tokenizer=None,
            model_init=None,
            compute_metrics=None,
            callbacks=None,
            optimizers=(None, None),
            preprocess_logits_for_metrics=None,
    ):
        """Create the trainer.

        Args:
            config: The ``TrainConfig`` for this run (reads ``seq_len``,
                ``enable_tempering`` and ``disable_mem_efficient_ce``).
            tokenizer: Required; its ``pad_token_id`` is used to build
                attention masks.
            Remaining arguments are forwarded unchanged to ``Seq2SeqTrainer``.
        """
        # Forward by keyword: the positional order of the parent constructor is
        # not stable across transformers releases.
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )

        # ``self.model`` is assigned by the parent constructor. Re-assigning the
        # ``model`` argument here would replace a model created through
        # ``model_init`` with ``None``.

        self.config: "TrainConfig" = config
        self.pad_token_id = tokenizer.pad_token_id

    def training_step(
            self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], *args, **kwargs
    ) -> torch.Tensor:
        """Delegate to the parent implementation.

        Extra positional/keyword arguments (for example ``num_items_in_batch``
        in newer transformers releases) are passed through untouched.
        """
        return super().training_step(model, inputs, *args, **kwargs)

    def _pad_attention_mask(self, inputs):
        """Return a mask with 1 where ``inputs`` differs from the pad token id, 0 elsewhere."""
        return inputs.ne(self.pad_token_id).to(inputs.dtype)

    def _log_training_losses(self, **losses):
        """Log each scalar tensor in ``losses`` under the key ``training/<name>``."""
        self.log({f"training/{name}": value.item() for name, value in losses.items()})

    def tempering_step(self, model, inputs, target, return_outputs=False):
        """Compute the model loss plus ``TEMPERING_COEFF`` times the tempering loss.

        Args:
            model: Model to run the first forward pass with.
            inputs: Token ids, already truncated by ``compute_loss``.
            target: Labels aligned with ``inputs``.
            return_outputs: When true, return ``(loss, outputs)`` where
                ``outputs`` is the result of the first forward pass.

        Returns:
            The combined loss, or ``(loss, outputs)`` if ``return_outputs``.
        """
        # First pass: regular loss; the cache is kept so the second pass can reuse it.
        outputs = model(
            inputs,
            labels=target,
            attention_mask=self._pad_attention_mask(inputs),
            use_efficient_ce=False,
            use_cache=True,
        )
        loss_model = outputs.loss
        log_pr1 = torch.log_softmax(outputs.logits, dim=-1)  # (bsz, len, vocab_size)

        log_pr1_topk, topk_indices = torch.topk(
            log_pr1, k=self.TEMPERING_TOP_K, dim=-1,
        )  # (bsz, len, TOP_K)

        # Second pass: feed the top-k candidate ids back through the model on
        # top of the cached keys/values, with positions shifted by one.
        # NOTE(review): this call uses ``self.model`` while the first pass uses
        # the ``model`` argument; kept as in the original — confirm it is intended
        # when the trainer wraps the model (DeepSpeed / DDP).
        outputs_2 = self.model(
            topk_indices,
            cache_position=torch.arange(1, 1 + topk_indices.size(1), device=topk_indices.device),
            past_key_values=outputs.past_key_values,
            use_cache=False,  # Do not return past_key_values
        )

        # Evaluate the auxiliary loss per (sample, sequence chunk) under
        # checkpointing, then average the chunk losses.
        a_loss = torch.zeros((), device=log_pr1.device)
        cnt = 0
        chunk_len = self.TEMPERING_CHUNK_LEN
        for b in range(outputs.logits.shape[0]):
            for start in range(0, outputs.logits.shape[1], chunk_len):
                r = (slice(b, b + 1), slice(start, start + chunk_len))
                a_loss_part = torch.utils.checkpoint.checkpoint(
                    compute_a_loss,
                    outputs.logits[r], log_pr1_topk[r], topk_indices[r], outputs_2.logits[r],
                    self.TEMPERING_ALPHA,
                )
                a_loss = a_loss + a_loss_part
                cnt += 1

        a_loss = a_loss / cnt

        loss = loss_model + self.TEMPERING_COEFF * a_loss

        self._log_training_losses(loss_model=loss_model, a_loss=a_loss, loss=loss)

        return (loss, outputs) if return_outputs else loss

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the loss for one batch.

        Args:
            model: Model to run.
            inputs: Batch dict with ``'input_ids'`` and ``'labels'``.
            return_outputs: When true, return ``(loss, outputs)``.
            num_items_in_batch: Accepted so that transformers releases which
                pass this argument can call the method; it is not used.

        Returns:
            The loss tensor, or ``(loss, outputs)`` if ``return_outputs``.
        """
        # NOTE(review): forces training mode on every call, including any call
        # made during evaluation; kept as in the original.
        model.train()

        inputs, target = inputs['input_ids'], inputs['labels']
        inputs = inputs[..., :self.config.seq_len]
        target = target[..., :self.config.seq_len]

        if self.config.enable_tempering:
            return self.tempering_step(model, inputs, target, return_outputs)

        outputs = model(
            inputs,
            labels=target,
            attention_mask=self._pad_attention_mask(inputs),
            use_efficient_ce=not self.config.disable_mem_efficient_ce,
            use_cache=False,
        )
        loss_model = outputs.loss

        loss = loss_model

        self._log_training_losses(loss_model=loss_model, loss=loss)

        return (loss, outputs) if return_outputs else loss


def get_hf_dataset(ds):
    """Expose an indexable dataset as a ``datasets.IterableDataset``.

    ``ds`` must support ``len()`` and integer indexing, with every item being
    an ``(inputs, targets)`` pair. Each pair is yielded as a dict with the keys
    ``'input_ids'`` and ``'labels'``.
    """
    def iter_examples():
        total = len(ds)
        for position in range(total):
            input_ids, labels = ds[position]
            yield dict(input_ids=input_ids, labels=labels)

    return datasets.IterableDataset.from_generator(iter_examples)


def parse_args():
    """Parse the command line into a ``TrainConfig``, print it, and return it."""
    train_config = ArgumentParser(TrainConfig).parse_args()
    print(train_config)
    return train_config


def get_deepspeed_config(config: TrainConfig):
    """Assemble the DeepSpeed configuration dict for this run.

    Returns ``None`` when ``config.using_deepspeed`` is falsy. Otherwise the
    result enables bf16, uses the ZeRO preset matching ``config.zero_stage``,
    turns on activation checkpointing, and copies the micro batch size and
    gradient accumulation steps from ``config``.

    Raises:
        ValueError: if ``config.zero_stage`` is not 1, 2 or 3.
    """
    if not config.using_deepspeed:
        return None

    # One ZeRO preset per supported stage, identified by its "stage" entry.
    zero_presets = (
        {
            "stage": 1,
            "reduce_bucket_size": 5e8,
        },
        {
            "stage": 2,
            "contiguous_gradients": True,
            "overlap_comm": True,
            "reduce_scatter": True,
            "reduce_bucket_size": 5e8,
            "allgather_bucket_size": 5e8,
        },
        {
            "stage": 3,
            "offload_optimizer": {"device": "cpu", "pin_memory": True},
            "offload_param": {"device": "cpu", "pin_memory": True},
            "overlap_comm": True,
            "contiguous_gradients": True,
            "sub_group_size": 1e8,
            "reduce_bucket_size": "auto",
            "stage3_prefetch_bucket_size": 15099494,
            "stage3_param_persistence_threshold": "auto",
            "stage3_max_live_parameters": 1e8,
            "stage3_max_reuse_distance": 1e8,
            "stage3_gather_16bit_weights_on_model_save": True,
            "zero_hpz_partition_size": torch.cuda.device_count(),
        },
    )

    zero_optimization = next(
        (preset for preset in zero_presets if config.zero_stage == preset["stage"]),
        None,
    )
    if zero_optimization is None:
        raise ValueError(f'Unknown zero stage {config.zero_stage}')

    activation_checkpointing = {
        "partition_activations": True,
        "cpu_checkpointing": True,
        "contiguous_memory_optimization": False,
        "number_checkpoints": None,
        "synchronize_checkpoint_boundary": False,
        "profile": False,
    }

    return {
        "bf16": {"enabled": True},
        "zero_optimization": zero_optimization,
        "activation_checkpointing": activation_checkpointing,
        "train_micro_batch_size_per_gpu": config.batch_size,
        "gradient_accumulation_steps": config.accumulation_steps,
        "gradient_clipping": 1.0,
    }


def main(config: TrainConfig):
    """Run one training job described by ``config``.

    Loads the model and tokenizer, builds the selected dataset, splits it into
    train/validation subsets, and trains with ``Trainer``. As a side effect,
    ``config.model_checkpoint_dir`` is extended with a run-specific
    sub-directory and the ``WANDB_PROJECT`` environment variable is set.

    Raises:
        ValueError: if ``config.dataset`` is not one of 'owt', 'bs', 'rp'.
    """
    # Debug switches, disabled by default:
    # torch.autograd.set_detect_anomaly(True)
    # torch.set_float32_matmul_precision('high')
    torch.backends.cudnn.benchmark = True

    os.environ["WANDB_PROJECT"] = "quick_extend"

    launch_blocking = os.environ.get("CUDA_LAUNCH_BLOCKING", "0")
    if launch_blocking == "1":
        print("WARNING: CUDA_LAUNCH_BLOCKING is set to 1, this will slow down the training.", flush=True)

    # Checkpoints of each run go into their own sub-directory.
    filename = f'{config.model.replace("/", "_")}-{config.name}-{config.mix_mode}-{config.dataset}-{config.seq_len}-{config.hip_top_k_elems}-{config.hip_group_size_q}'
    config.model_checkpoint_dir = '/'.join([config.model_checkpoint_dir, filename])

    ds_config = get_deepspeed_config(config)
    model, tokenizer = load_model(model_config=config, ds_config=ds_config)

    # Dataset constructors keyed by the --dataset value; only the chosen one runs.
    dataset_factories = {
        'owt': lambda: OpenWebTextDataset(tokenizer=tokenizer, stride=config.seq_len),
        'bs': lambda: BookSumDataset(tokenizer=tokenizer, need_tokenization=True, max_seq_len=config.seq_len),
        'rp': lambda: RedPajamaDataset(tokenizer=tokenizer, stride=config.seq_len),
    }
    if config.dataset not in dataset_factories:
        raise ValueError(f'Unknown dataset {config.dataset}')
    dataset = dataset_factories[config.dataset]()

    all_indices = list(range(len(dataset)))
    train_idx, val_idx = train_test_split(all_indices, test_size=config.val_split)
    train_data, valid_data = Subset(dataset, train_idx), Subset(dataset, val_idx)

    trainer_config = Seq2SeqTrainingArguments(
        output_dir=config.model_checkpoint_dir,
        report_to=["wandb"],
        logging_steps=1,
        # Optimisation schedule
        learning_rate=config.lr,
        warmup_steps=config.warmup_steps,
        max_steps=config.max_steps,
        gradient_accumulation_steps=config.accumulation_steps,
        per_device_train_batch_size=config.batch_size,
        per_device_eval_batch_size=config.batch_size,
        # Precision / memory
        bf16=True,
        gradient_checkpointing=True,
        # Checkpointing
        save_steps=config.save_steps,
        save_total_limit=5,
        ignore_data_skip=True,
        # Distributed setup
        local_rank=config.local_rank,
        deepspeed=ds_config,
    )

    train_stream = get_hf_dataset(train_data)
    valid_stream = get_hf_dataset(valid_data)
    collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        padding='longest',
        pad_to_multiple_of=config.seq_len,
    )

    trainer = Trainer(
        config=config,
        model=model,
        args=trainer_config,
        train_dataset=train_stream,
        eval_dataset=valid_stream,
        tokenizer=tokenizer,
        data_collator=collator,
    )

    trainer.train(resume_from_checkpoint=config.load_from_checkpoint)


def run():
    """Seed the random number generators, parse the command line, and train."""
    seed()
    train_config = parse_args()
    main(train_config)


def seed(seed=42):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and CUDA) with ``seed``.

    Also sets ``PYTHONHASHSEED`` in ``os.environ`` and switches cuDNN to
    deterministic mode with benchmarking disabled.
    """
    import random
    import numpy as np

    os.environ["PYTHONHASHSEED"] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # covers multi-GPU setups
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


# Script entry point: run() seeds the RNGs, parses the command line and starts training.
if __name__ == "__main__":
    run()
