import os
from pathlib import Path
from typing import Dict, Union, Any

import datasets
import torch
import torch.autograd
import torch.onnx
import torch.utils.checkpoint
from sklearn.model_selection import train_test_split
from torch import nn
from torch.utils.data import DataLoader, random_split
from torch.utils.data import Subset
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq

from hip.dataset.alpaca import AlpacaDataset
from hip.dataset.booksum import BookSumDataset
from hip.dataset.labdataset import LabDataset
from hip.dataset.lmsys import LmsysChatDataset
from hip.dataset.openwebtext import OpenWebTextDataset
from hip.trainer.common import TrainConfig, parse_args, load_model, load_tokenizer
from hip.utils import seed


# Debugging aid: uncomment the next line to turn on autograd anomaly detection.
# torch.autograd.set_detect_anomaly(True)
# Runs at import time: sets torch's float32 matmul precision mode to 'high'.
torch.set_float32_matmul_precision('high')


class LabDataModule:
    """Build the dataset named by ``config.dataset`` and serve data loaders.

    The dataset is created by :meth:`get_dataset`, split into train and
    validation parts by :meth:`setup`, and exposed through the
    ``*_dataloader`` methods.  The test split is an alias of the validation
    split.

    Args:
        config: Training configuration; ``dataset``, ``seq_len``, ``model``
            and ``batch_size`` are read from it.
        num_workers: Worker count forwarded to every ``DataLoader``.
        data_dir: Directory forwarded to ``LabDataset`` (wikitext datasets).
        download: Download flag forwarded to ``LabDataset``.
        train_size: Fraction of the wikitext datasets kept for training; the
            hold-out part is additionally capped at 100 items.
    """

    def __init__(
            self,
            config: TrainConfig,
            num_workers: int = 0,
            data_dir: Union[Path, str] = "data",
            download: bool = True,
            train_size: float = 0.9,
    ):
        super().__init__()
        self.config = config
        self.data_dir = data_dir
        self.block_size = config.seq_len
        self.download = download
        self.num_workers = num_workers
        self.train_size = train_size
        self.dataset = None
        self.tokenizer = load_tokenizer(config.model)
        self.bsize = config.batch_size

    def get_dataset(self):
        """Instantiate and return the dataset selected by ``config.dataset``.

        Raises:
            ValueError: If ``config.dataset`` is not a supported name.
        """
        name = self.config.dataset
        if name in ['wikitext2', 'wikitext103']:
            dataset = LabDataset(
                data_dir=self.data_dir,
                block_size=self.block_size,
                download=self.download,
                tokenizer=self.tokenizer,
                dataset=name,
            )
        elif name in ['alpaca']:
            dataset = AlpacaDataset(
                tokenizer=self.tokenizer,
            )
        elif name in ['booksum']:
            dataset = BookSumDataset(
                tokenizer=self.tokenizer,
            )
        elif name in ['openwebtext']:
            dataset = OpenWebTextDataset(
                tokenizer=self.tokenizer,
                stride=self.block_size,
            )
        elif name == 'lmsys':
            dataset = LmsysChatDataset(
                tokenizer=self.tokenizer,
            )
        else:
            raise ValueError(f"Unsupported dataset: {name!r}")
        return dataset

    def prepare_data(self):
        """Build the dataset once and discard it (result is not stored)."""
        self.get_dataset()

    def setup(self, stage: str):
        """Create the dataset and populate the splits required by ``stage``.

        Args:
            stage: ``"fit"`` creates ``train_data`` and ``val_data``;
                ``"test"`` sets ``test_data`` to the existing ``val_data``
                (so a ``"fit"`` setup must have happened before); ``None``
                does both.

        Raises:
            ValueError: If ``config.dataset`` is not a supported name.
        """
        self.dataset = self.get_dataset()
        name = self.config.dataset
        wants_fit = stage == "fit" or stage is None
        wants_test = stage == "test" or stage is None

        if name in ['wikitext2', 'wikitext103']:
            if wants_fit:
                total = len(self.dataset)
                # random_split needs integer lengths that sum to len(dataset);
                # the fractional product (e.g. 500 * (1 - 0.9) == 49.999...)
                # must therefore be rounded before it is used as a length.
                holdout = min(100, int(round(total * (1 - self.train_size))))
                self.train_data, self.val_data = random_split(
                    self.dataset, lengths=[total - holdout, holdout]
                )
        elif name in ['booksum', 'alpaca', 'openwebtext', 'lmsys']:
            if wants_fit:
                # 95 % / 5 % index split; the subsets index into self.dataset.
                train_idx, val_idx = train_test_split(
                    list(range(len(self.dataset))), test_size=0.05
                )
                self.train_data = Subset(self.dataset, train_idx)
                self.val_data = Subset(self.dataset, val_idx)
        else:
            raise ValueError(f"Unsupported dataset: {name!r}")

        if wants_test:
            # The validation split doubles as the test split.
            self.test_data = self.val_data

    def train_dataloader(self):
        """Return a ``DataLoader`` over the training split."""
        return DataLoader(self.train_data, num_workers=self.num_workers, batch_size=self.bsize)

    def val_dataloader(self):
        """Return a ``DataLoader`` over the validation split."""
        return DataLoader(self.val_data, num_workers=self.num_workers, batch_size=self.bsize)

    def test_dataloader(self):
        """Return a ``DataLoader`` over the test split."""
        return DataLoader(self.test_data, num_workers=self.num_workers, batch_size=self.bsize)


def get_hf_dataset(ds):
    """Wrap an indexable ``(inputs, targets)`` dataset as an ``IterableDataset``.

    Args:
        ds: Object supporting ``len()`` and integer indexing, where each item
            unpacks into exactly two values.

    Returns:
        A ``datasets.IterableDataset`` yielding one
        ``{'input_ids': ..., 'labels': ...}`` dict per item, in index order.
    """
    def _iter_examples():
        # Walk the dataset by position; each item is an (inputs, targets) pair.
        for position in range(len(ds)):
            input_ids, labels = ds[position]
            example = {'input_ids': input_ids, 'labels': labels}
            yield example

    hf_dataset = datasets.IterableDataset.from_generator(_iter_examples)
    return hf_dataset


class Trainer(Seq2SeqTrainer):
    """``Seq2SeqTrainer`` with a custom loss and an optional distillation teacher.

    :meth:`compute_loss` combines token-level cross-entropy with optional
    knowledge-distillation terms against ``teacher`` (hidden-state MSE and
    logit KL divergence) and an optional attention sparsity regulariser
    weighted by ``config.sparsity_reg``.

    Args:
        config: Training configuration; ``seq_len``, ``disable_kd``,
            ``sparsity_reg`` and ``using_fsdp`` are read from it.
        model: Student model to train.
        teacher: Teacher model used when ``config.disable_kd`` is false.
        tokenizer: Tokenizer; its ``pad_token_id`` is used to build the
            attention mask, so it must be provided.
        Remaining arguments are forwarded unchanged to ``Seq2SeqTrainer``.
    """

    def __init__(
            self,
            config=None,
            model=None,
            teacher=None,
            args=None,
            data_collator=None,
            train_dataset=None,
            eval_dataset=None,
            tokenizer=None,
            model_init=None,
            compute_metrics=None,
            callbacks=None,
            optimizers=(None, None),
            preprocess_logits_for_metrics=None,
    ):
        # Keyword arguments make the call independent of the positional
        # parameter order of the parent constructor.
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )

        # self.model is left as assigned by the parent constructor:
        # re-assigning the ``model`` argument here would replace a model
        # built from ``model_init`` with None.
        self.teacher = teacher

        # Buffers filled by validation_step and drained by
        # on_validation_epoch_end.
        self.validation_preds = []
        self.validation_targets = []

        self.config = config
        self.pad_token_id = tokenizer.pad_token_id

    @staticmethod
    def _as_float(value):
        """Return ``value`` as a Python float, accepting tensors or numbers."""
        if isinstance(value, torch.Tensor):
            return value.item()
        return float(value)

    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """Delegate one training step to the parent implementation."""
        return super().training_step(model, inputs)

    def compute_loss(self, model, inputs, return_outputs=False):
        """Compute the training loss for one batch and log its components.

        Args:
            model: Model handed in by the training loop.
                NOTE(review): the forward pass below uses ``self.model``
                rather than this argument - confirm that is intended.
            inputs: Batch dict with ``'input_ids'`` and ``'labels'``; both are
                truncated to ``config.seq_len`` along the last dimension.
            return_outputs: When true, return ``(loss, outputs)`` instead of
                the bare loss.

        Returns:
            The scalar loss tensor, or ``(loss, model_output)`` when
            ``return_outputs`` is true.
        """
        use_kd = not self.config.disable_kd
        use_sparsity = self.config.sparsity_reg != 0

        if self.teacher is not None:
            self.teacher.eval()
        # NOTE(review): this forces train mode on every call, including calls
        # made from an evaluation loop.
        self.model.train()

        input_ids = inputs['input_ids'][..., :self.config.seq_len]
        target = inputs['labels'][..., :self.config.seq_len]

        output_teacher = None
        if use_kd:
            # The teacher only provides targets; no gradients are needed.
            with torch.no_grad():
                output_teacher = self.teacher(input_ids, output_hidden_states=True)

        output = self.model(
            input_ids,
            attention_mask=input_ids.ne(self.pad_token_id).to(input_ids.dtype),
            output_hidden_states=use_kd,
            output_attn_sparsity_loss=use_sparsity,
        )
        logits = output.logits
        vocab_size = logits.shape[-1]

        # Cross-entropy is evaluated in float32 on flattened tokens.
        loss_model = torch.nn.functional.cross_entropy(
            logits.reshape(-1, vocab_size).to(torch.float32),
            target.reshape(-1)
        )

        loss_kd_hidden = 0
        loss_kd_logits = 0
        if use_kd:
            # Mean (over teacher layers) of the per-layer hidden-state MSE.
            for teacher_layer, student_layer in zip(output_teacher.hidden_states, output.hidden_states):
                loss_kd_hidden += torch.nn.functional.mse_loss(
                    student_layer.to(torch.float32),
                    teacher_layer.to(torch.float32),
                )
            loss_kd_hidden = loss_kd_hidden / len(output_teacher.hidden_states)

            # KL divergence between student and teacher token distributions;
            # reshape (instead of view) also accepts non-contiguous logits.
            loss_kd_logits = torch.nn.functional.kl_div(
                logits.reshape(-1, vocab_size).to(torch.float32).log_softmax(-1),
                output_teacher.logits.reshape(-1, vocab_size).to(torch.float32).softmax(-1),
                reduction='batchmean',
            )
            loss = loss_model * 0.1 + (loss_kd_hidden + loss_kd_logits) * 2.5
        else:
            loss = loss_model

        sparsity_loss = None
        if use_sparsity:
            # Layers reporting None contribute nothing to the sum but still
            # count in the denominator.
            layer_means = [
                layer_sparsity.mean()
                for layer_sparsity in output.attn_sparsity_loss
                if layer_sparsity is not None
            ]
            sparsity_loss = sum(layer_means) / len(output.attn_sparsity_loss)
            loss = loss + self.config.sparsity_reg * sparsity_loss

        log_dict = {"training/loss_model": loss_model.item()}
        if use_kd:
            if loss_kd_hidden > 0:
                log_dict["training/loss_kd_hidden"] = self._as_float(loss_kd_hidden)
            if loss_kd_logits > 0:
                log_dict["training/loss_kd_logits"] = self._as_float(loss_kd_logits)
        if sparsity_loss is not None:
            # A plain number is possible here when every layer entry is None.
            log_dict["training/sparsity_loss"] = self._as_float(sparsity_loss)
        log_dict["training/loss"] = loss.item()
        self.log(log_dict)

        # Honour the (loss, outputs) contract used by the parent class when it
        # asks for the model outputs as well.
        if return_outputs:
            return loss, output
        return loss

    def validation_step(self, batch, batch_idx):
        """Evaluate one ``(inputs, target)`` batch and buffer its predictions.

        Logs ``val/loss`` and appends CPU copies of the logits and targets to
        the buffers consumed by :meth:`on_validation_epoch_end`.
        NOTE(review): nothing in the visible part of this file calls this
        method.
        """
        self.model.eval()

        inputs, target = batch
        inputs = inputs[..., :self.config.seq_len]
        target = target[..., :self.config.seq_len]

        with torch.no_grad():
            output = self.model(
                inputs,
                attention_mask=(inputs != self.pad_token_id).to(inputs.dtype),
                labels=target,
            ).logits
            loss = torch.nn.functional.cross_entropy(
                output.view(-1, output.shape[-1]),
                target.view(-1)
            )
        self.log({"val/loss": loss.item()})

        self.validation_preds.append(output.cpu())
        self.validation_targets.append(target.cpu())

    def on_validation_epoch_end(self):
        """Compute, print and log perplexity over the buffered validation batches.

        The buffers are cleared afterwards.  The metric runs on ``'cuda'``
        when ``config.using_fsdp`` is set and on ``'cpu'`` otherwise.
        """
        # Imported lazily so torchmetrics is only needed when this runs.
        from torchmetrics.text.perplexity import Perplexity
        with torch.no_grad():
            device = 'cuda' if self.config.using_fsdp else 'cpu'
            calculator = Perplexity(ignore_index=-100).to(device)
            for preds, target in zip(self.validation_preds, self.validation_targets):
                calculator.update(preds.to(device), target.to(device))
            ppl = calculator.compute()
        ppl = ppl.item()
        print('val/ppl', ppl)
        self.log({"val/ppl": ppl})

        self.validation_preds.clear()
        self.validation_targets.clear()


def _checkpoint_name(config):
    """Return the run name encoding model, dataset and method hyper-parameters."""
    if config.method == 'hip':
        return f'{config.model}-{config.dataset}-{config.name}-{config.seq_len}-bq{config.block_size_q}-bk{config.block_size_k}-k{config.k}-sp{config.sparsity_reg}'
    if config.method == 'none':
        return f'{config.model}-{config.dataset}-{config.name}-{config.seq_len}'
    if config.method == 'reformer':
        return f'{config.model}-{config.method}-{config.name}-{config.dataset}-{config.seq_len}-k{config.k}'
    if config.method == 'performer':
        return f'{config.model}-{config.method}-{config.name}-{config.dataset}-{config.seq_len}'
    raise ValueError(f"Unsupported attention method: {config.method!r}")


def _deepspeed_config(config):
    """Return the DeepSpeed config dict (bf16, ZeRO stage 3 with CPU offload)."""
    return {
        "bf16": {
            "enabled": True
        },
        "zero_optimization": {
            "stage": 3,
            "offload_optimizer": {
                "device": "cpu",
                "pin_memory": True
            },
            "offload_param": {
                "device": "cpu",
                "pin_memory": True
            },
            "overlap_comm": True,
            "contiguous_gradients": True,
            "sub_group_size": 1e8,
            "reduce_bucket_size": "auto",
            "stage3_prefetch_bucket_size": "auto",
            "stage3_param_persistence_threshold": "auto",
            "stage3_max_live_parameters": 1e8,
            "stage3_max_reuse_distance": 1e8,
            "stage3_gather_16bit_weights_on_model_save": True,
            "zero_hpz_partition_size": torch.cuda.device_count(),
        },
        "activation_checkpointing": {
            "partition_activations": True,
            "cpu_checkpointing": True,
            "contiguous_memory_optimization": False,
            "number_checkpoints": None,
            "synchronize_checkpoint_boundary": False,
            "profile": False
        },
        "train_micro_batch_size_per_gpu": config.batch_size,
        "gradient_accumulation_steps": config.accumulation_steps,
        "gradient_clipping": 1.0,
    }


def main(config: TrainConfig):
    """Train the model described by ``config``.

    Creates the output directories, appends a run-specific sub-directory to
    ``config.model_checkpoint_dir`` (the attribute is updated in place), loads
    the student model and - unless ``config.disable_kd`` is set - a teacher,
    prepares the data, and starts training, resuming from
    ``config.load_from_checkpoint``.

    Raises:
        ValueError: If ``config.method`` is not one of ``'hip'``, ``'none'``,
            ``'reformer'`` or ``'performer'``.
    """
    os.environ["WANDB_PROJECT"] = "hip-attention"

    os.makedirs('./saves/dev/wandb', exist_ok=True)
    os.makedirs('./saves/dev/checkpoint', exist_ok=True)

    filename = _checkpoint_name(config)
    config.model_checkpoint_dir = config.model_checkpoint_dir + '/' + filename

    model = load_model(train_config=config, method=config.method)
    # The teacher is the same architecture loaded with method='none'.
    teacher = None
    if not config.disable_kd:
        teacher = load_model(train_config=config, method='none', is_teacher=True)

    datamodule = LabDataModule(config=config)
    datamodule.setup("fit")

    trainer_config = Seq2SeqTrainingArguments(
        logging_steps=1,
        bf16=True,
        output_dir=config.model_checkpoint_dir,
        gradient_accumulation_steps=config.accumulation_steps,
        max_steps=config.max_steps,
        report_to=["wandb"],
        gradient_checkpointing=True,
        save_total_limit=3,
        save_steps=config.save_steps,
        per_device_train_batch_size=config.batch_size,
        per_device_eval_batch_size=config.batch_size,
        learning_rate=config.lr,
        ignore_data_skip=True,
        warmup_steps=config.warmup_steps,
        local_rank=config.local_rank,
        deepspeed=_deepspeed_config(config) if config.using_deepspeed else None,
    )

    trainer = Trainer(
        config=config,
        model=model,
        teacher=teacher,
        args=trainer_config,
        train_dataset=get_hf_dataset(datamodule.train_data),
        eval_dataset=get_hf_dataset(datamodule.val_data),
        tokenizer=datamodule.tokenizer,
        data_collator=DataCollatorForSeq2Seq(
            tokenizer=datamodule.tokenizer,
            padding='longest',
            pad_to_multiple_of=config.seq_len,
        ),
    )

    trainer.train(resume_from_checkpoint=config.load_from_checkpoint)


def run():
    """Seed the run, parse the training configuration and start training."""
    seed()
    config = parse_args()
    main(config)


# Script entry point: start training only when this file is executed directly.
if __name__ == "__main__":
    run()
