from functools import partial
import math
import os
import json
import logging
import wandb
import torch
import numpy as np
from datasets import disable_caching
import dataclasses
from transformers import (
    GPTNeoForCausalLM,
    GPTNeoConfig,
    OpenAIGPTLMHeadModel,
    OpenAIGPTConfig,
    LlamaForCausalLM,
    LlamaConfig,
    MegaConfig,
    MegaForCausalLM,
)
from torchscale.architecture.config import RetNetConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel, MambaConfig
from zoology.config import ModelConfig, ModuleConfig
from zoology.model import LanguageModel

from latte_trans.trainer.torch_trainer import Trainer
from latte_trans.experiments.base import BaseTask
from latte_trans.models.tasks.lm_torch import (
    HuggWrapper,
    MambaWrapper,
    LinearTransWrapper,
    LMHeadVanilla,
    RetNetWrapper,
    ZoologyWrapper,
)
from latte_trans.evals.lang_eval_pt import LanguageEvaluator
from latte_trans.experiments.utils import parse_args
from latte_trans.config import LMTaskConfig
from latte_trans.preproc.pile import PileStream


def get_dp(config):
    """Build the data processor, tokenizer and tokenized data for a dataset.

    Args:
        config: Task configuration. ``dataset_name`` selects the dataset and
            ``base_dir`` is the root under which ``input/<dataset>`` cache
            directories live. Depending on the dataset, ``max_seq_len`` and
            ``disable_cache`` are read as well.

    Returns:
        A tuple ``(dp, tokenizer, raw_data, tok_data)``. ``raw_data`` is
        ``None`` for ``"owt"``, whose tokenized data is loaded from disk.

    Raises:
        ValueError: If ``config.dataset_name`` is not a supported dataset.
    """
    supported = ("wiki103", "pile", "owt", "owt2", "tiny-stories")
    # Fail fast (before the heavy imports below) instead of reaching the
    # return statement with unbound locals.
    if config.dataset_name not in supported:
        raise ValueError(
            f"Unknown dataset_name {config.dataset_name!r}; "
            f"expected one of {supported}"
        )

    from transformers import AutoTokenizer
    from latte_trans.preproc.lm_dp import Wiki103DP
    from latte_trans.preproc.openweb import OpenWeb, OpenWebTextDP2
    from datasets import DatasetDict
    from latte_trans.preproc.toks import SpecialToksGPT2TokenizerFast
    from latte_trans.preproc.tiny_stories import TinyStories

    if config.dataset_name == "wiki103":
        cache_dir = os.path.join(config.base_dir, "input", "wikitext-103")
        tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
        dp = Wiki103DP(tokenizer=tokenizer, cache_dir=cache_dir)
        raw_data = dp.get_raw_data()
        tok_data = dp.tokenize(raw_data, add_special_tokens=True)
    elif config.dataset_name == "pile":
        cache_dir = os.path.join(config.base_dir, "input", "pile")
        tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
        dp = PileStream(tokenizer=tokenizer, cache_dir=cache_dir, num_load_procs=None)
        raw_data = dp.get_raw_data()
        tok_data = dp.tokenize(raw_data, max_seq_len=config.max_seq_len)
    elif config.dataset_name == "owt":
        cache_dir = os.path.join(config.base_dir, "input", "tok_openweb")
        tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
        # The processor is still constructed because callers use it (e.g. for
        # its collate function); the tokenized data itself comes from disk.
        dp = OpenWeb(
            tokenizer=tokenizer,
            cache_dir=cache_dir,
            num_load_procs=None,
            max_seq_len=config.max_seq_len,
        )
        raw_data = None
        tok_data = DatasetDict.load_from_disk(cache_dir)
        tok_data.set_format("pt")
    elif config.dataset_name == "owt2":
        cache_dir = os.path.join(config.base_dir, "input", "openweb_mask")
        tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
        dp = OpenWebTextDP2(
            tokenizer=tokenizer,
            cache_dir=cache_dir,
            num_load_procs=8,
            max_seq_len=config.max_seq_len,
        )
        raw_data = dp.get_raw_data()
        tok_data = dp.tokenize(raw_data)
    else:  # "tiny-stories" (guaranteed by the validation above)
        # NOTE: the directory name is misspelled ("tiny-sories"); it is kept
        # unchanged so that already-populated cache directories are reused.
        cache_dir = os.path.join(config.base_dir, "input", "tiny-sories")
        tokenizer = SpecialToksGPT2TokenizerFast.from_pretrained("gpt2")
        # Leave one core free but always use at least one worker.
        # os.cpu_count() may return None, and the previous min(1, n - 1)
        # evaluated to 0 workers on single-core hosts.
        num_workers = max(1, (os.cpu_count() or 1) - 1)
        dp = TinyStories(
            tokenizer=tokenizer,
            cache_dir=cache_dir,
            max_seq_len=config.max_seq_len,
            num_load_procs=num_workers,
        )
        raw_data = dp.get_raw_data()
        tok_data = dp.tokenize(raw_data, force_preproc=config.disable_cache)
    return dp, tokenizer, raw_data, tok_data


def get_zoology_mixer(config: LMTaskConfig):
    """Return the zoology ``ModuleConfig`` for the mixer named by ``config.block_type``.

    Every entry of the lookup table is built before the selection is made, so
    ``max_seq_len``, ``hidden_dim``, ``nheads``, ``dropout`` and ``state_dim``
    are read from ``config`` regardless of which mixer is requested.

    Raises:
        KeyError: If ``config.block_type`` is not one of the table's keys.
    """
    seq_len = config.max_seq_len

    def spec(target, **mixer_kwargs):
        # Shape expected by ModuleConfig: import path plus constructor kwargs.
        return {"name": target, "kwargs": mixer_kwargs}

    table = {
        "attention": spec(
            "zoology.mixers.attention.MHA",
            dropout=0.1,
            num_heads=1,
        ),
        "hyena": spec(
            "zoology.mixers.hyena.Hyena",
            l_max=seq_len,
            d_model=config.hidden_dim,
            num_heads=config.nheads,
            dropout=config.dropout,
            filter_order=config.state_dim,
        ),
        "rwkv": spec(
            "zoology.mixers.rwkv.RWKVTimeMixer",
            l_max=seq_len,
        ),
        "base_conv": spec(
            "zoology.mixers.base_conv.BaseConv",
            l_max=seq_len,
            # one kernel size per layer, for four layers
            kernel_size=[3, -1, 3, -1],
        ),
        "base_conv_explicit": spec(
            "zoology.mixers.base_conv.BaseConv",
            l_max=seq_len,
            # one kernel size per layer, for four layers
            kernel_size=[3, -1, 3, -1],
            implicit_long_conv=True,
        ),
        "h3": spec(
            "zoology.mixers.h3.H3",
            l_max=seq_len,
            # makes it mathematically equivalent to Hyena
            d_state=config.state_dim,
            d_model=config.hidden_dim,
            head_dim=config.hidden_dim // config.nheads,
        ),
    }
    selected = table[config.block_type]
    return ModuleConfig(**selected)


class LMTask(BaseTask):
    def __init__(self, config) -> None:
        print("Config is ", config)
        self.config = config
        self.report_to = "none"
        self.wandb_run = None

        self.out_dir = os.path.join(self.config.base_dir, "out_latte", self.config.name)
        os.makedirs(self.out_dir, exist_ok=True)
        # dump config file in model dir for debug
        with open(os.path.join(self.out_dir, "config.json"), "w+") as f:
            a = dataclasses.asdict(config)
            json.dump(a, f)

    def get_model(self, config, tokenizer):
        match config.block_type:
            case "gptneo":
                config = GPTNeoConfig(  # GPTNeoXConfig(
                    bos_token_id=tokenizer.bos_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                    max_position_embeddings=config.max_seq_len,
                    hidden_size=config.hidden_dim,
                    intermediate_size=config.hidden_dim * 4,
                    num_attention_heads=config.nheads,
                    num_hidden_layers=config.nlayers,
                    vocab_size=tokenizer.vocab_size,
                )
                base_model = GPTNeoForCausalLM(config)
                model = HuggWrapper(base_model, pad_id=tokenizer.pad_token_id)
            case "openai-gpt":
                config = OpenAIGPTConfig(  # GPTNeoXConfig(
                    bos_token_id=tokenizer.bos_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                    pad_token_id=tokenizer.pad_token_id,
                    n_positions=config.max_seq_len,
                    hidden_size=config.hidden_dim,
                    n_embd=config.hidden_dim,
                    intermediate_size=config.hidden_dim * 4,
                    num_attention_heads=config.nheads,
                    num_hidden_layers=config.nlayers,
                    vocab_size=tokenizer.vocab_size,
                )
                base_model = OpenAIGPTLMHeadModel(config)
                model = HuggWrapper(base_model, pad_id=tokenizer.pad_token_id)
            case "lamma":
                config = LlamaConfig(  # GPTNeoXConfig(
                    bos_token_id=tokenizer.bos_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                    pad_token_id=tokenizer.pad_token_id,
                    n_positions=config.pos_embed_max_len,
                    hidden_size=config.hidden_dim,
                    n_embd=config.hidden_dim,
                    intermediate_size=config.hidden_dim * 4,
                    num_attention_heads=config.nheads,
                    num_hidden_layers=config.nlayers,
                    vocab_size=tokenizer.vocab_size,
                    rope_scaling={"type": "linear", "factor": 8.0},
                )
                base_model = LlamaForCausalLM(config)
                model = HuggWrapper(base_model, pad_id=tokenizer.pad_token_id)

            case "mamba":
                config = MambaConfig(
                    d_model=config.hidden_dim,
                    d_intermediate=4 * config.hidden_dim,
                    n_layer=config.nlayers,
                    ssm_cfg={"d_state": config.state_dim, "expand": 1},
                    vocab_size=tokenizer.vocab_size,
                )
                base_model = MambaLMHeadModel(config=config)
                param_count = sum(
                    p.numel()
                    for p in base_model.backbone.parameters()
                    if p.requires_grad
                )
                print(f"Number of parameters Backbone: {param_count / 1000000} M")
                model = MambaWrapper(base_model)
            case "linear-transformer":
                model = LinearTransWrapper(
                    d_model=self.config.hidden_dim,
                    sequence_length=config.pos_embed_max_len,
                    vocab_size=tokenizer.vocab_size,
                    attention_type="causal-linear",
                    d_query=config.hidden_dim // config.nheads,
                    dropout=config.dropout,
                    attention_dropout=config.dropout_att,
                    n_layers=self.config.nlayers,
                    n_heads=self.config.nheads,
                    pad_id=tokenizer.pad_token_id,
                )
            case "retnet":
                config = RetNetConfig(
                    max_target_positions=config.pos_embed_max_len,
                    decoder_embed_dim=config.hidden_dim,
                    decoder_value_embed_dim=config.hidden_dim,
                    decoder_ffn_embed_dim=config.hidden_dim * 4,
                    decoder_layers=config.nlayers,
                    decoder_retention_heads=config.nheads,
                    vocab_size=tokenizer.vocab_size,
                    xpos_rel_pos=True,
                    xpos_scale_base=self.config.pos_embed_max_len,
                )
                model = RetNetWrapper(pad_id=tokenizer.pad_token_id, config=config)
            case "rwkv":
                mixer_config = get_zoology_mixer(config)
                model_config = ModelConfig(
                    vocab_size=tokenizer.vocab_size,
                    d_model=config.hidden_dim,
                    n_layers=config.nlayers,
                    embed_dropout=0.0,
                    resid_dropout=config.dropout,
                    max_position_embeddings=config.pos_embed_max_len,
                    sequence_mixer=mixer_config,
                    # block_type="TransformerBlock",
                    block_type="TransformerSota",
                )
                model = ZoologyWrapper(LanguageModel(config=model_config))

            case "mega":
                # Initializing a Mega configuration
                configuration = MegaConfig(
                    bos_token_id=tokenizer.bos_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                    pad_token_id=tokenizer.pad_token_id,
                    max_positions=config.pos_embed_max_len,
                    hidden_size=config.hidden_dim,
                    n_embd=config.hidden_dim,
                    intermediate_size=config.hidden_dim * 4,
                    num_attention_heads=config.nheads,
                    num_hidden_layers=config.nlayers,
                    nffn_hidden_size=config.hidden_dim,  # config.state_dim,
                    vocab_size=tokenizer.vocab_size,
                    chunk_size=config.att_block_len,
                    bidirectional=False,
                    use_chunking=True,
                    attention_probs_dropout_prob=config.dropout_att,
                    dropout_prob=config.dropout,
                    is_decoder=True,
                )

                # Initializing a model (with random weights) from the configuration
                base_model = MegaForCausalLM(configuration)
                model = HuggWrapper(base_model, pad_id=tokenizer.pad_token_id)
            case "transformer" | "transformer-sota":
                model = LMHeadVanilla(
                    config,
                    vocab_size=tokenizer.vocab_size,
                    pad_id=tokenizer.pad_token_id,
                )

        return model

    def train(self):
        dp, tokenizer, raw_data, tokenized_data = get_dp(self.config)
        data_collator = dp.get_collate_fn(return_type="torch")

        model = self.get_model(self.config, tokenizer)
        # LOG.info("Data sizes: %s", raw_data)
        # logger.info("Tokenized Data: %s", tokenized_data, main_process_only=True)

        # logger.info("Tokenized Data: %s", tokenized_data["train"][0]["input_ids"], main_process_only=True)
        # logger.info(
        #     "Detokenized Data: %s %s",
        #     tokenizer.bos_token_id,
        #     tokenizer.decode(
        #         tokenized_data["train"][0]["input_ids"],
        #         decode_tok=True,
        #         skip_special_tokens=False,
        #     ),
        #     main_process_only=True
        # )
        evaluator = LanguageEvaluator(
            model,
            tokenizer,
            tokenized_data["validation"],  # .select(np.arange(320)),
            data_collator=data_collator,
            config=self.config,
        )

        # inputs_order = ("input_ids", "attention_mask", "labels")
        inputs_order = ("input_ids", "labels")
        trainer = Trainer(
            config=self.config,
            out_dir=self.out_dir,
            model=model,
            train_data=tokenized_data["train"],
            data_collator=data_collator,
            evaluator=evaluator,
            model_inputs_orded=inputs_order,
        )
        if not self.config.check_path is None:
            trainer.train(self.config.check_path)
        else:
            trainer.train()


def main():
    """Command-line entry point: load the task config, then evaluate or train.

    Dataset caching is turned off first when the loaded configuration has
    ``disable_cache`` set.
    """
    cli = parse_args()
    config = LMTaskConfig.load(
        yaml_file=cli.config_file,
        base_dir=cli.base_dir,
        name=cli.name,
    )

    if config.disable_cache:
        disable_caching()

    task = LMTask(config)
    # Pick the requested mode and run it.
    run = task.evaluate if cli.evaluate else task.train
    run()


# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
