from functools import partial
from pathlib import Path
import os
import json
import torch
import numpy as np
from datasets import disable_caching, DatasetDict
import dataclasses
from transformers import (
    AutoTokenizer,
    LlamaConfig,
    LlamaForCausalLM,
    Gemma2ForCausalLM,
    Gemma2Model,
)
from accelerate import Accelerator

from latte_trans.trainer.torch_trainer import Trainer
from latte_trans.experiments.base import BaseTask
from latte_trans.evals.lang_eval_pt import LanguageEvaluator
from latte_trans.experiments.utils import parse_args
from latte_trans.config import LMTaskConfig
from latte_trans.preproc.tiny_stories import TinyStories
from latte_trans.models.tasks.lm_pret_pt import HuggWrapper, HuggWrapper2
from latte_trans.preproc.slim_pajama import SlimPajama
from latte_trans.preproc.book_corpus import BookCorpusLong
from latte_trans.preproc.openweb import OpenWeb
from latte_trans.preproc.lm_dp import Wiki103DP


def _load_tokenizer(config, name_or_path):
    """Load a tokenizer that pads and truncates on the right.

    Downloads are cached under ``<config.base_dir>/input/cache_hugg``.
    """
    return AutoTokenizer.from_pretrained(
        name_or_path,
        cache_dir=os.path.join(config.base_dir, "input/cache_hugg"),
        truncation_side="right",
        padding_side="right",
    )


def get_dp(config, force_preproc):
    """Build the data processor, tokenizer and datasets for ``config.dataset_name``.

    Supported dataset names are ``"tiny-stories"``, ``"pajama"``,
    ``"bookcorpus"``, ``"owt"`` and ``"wiki103"``. Only ``"tiny-stories"`` uses
    the tokenizer named by ``config.hugg_chk`` (falling back to
    ``allenai/longformer-base-4096``); every other dataset uses the
    ``google/gemma-2-2b`` tokenizer.

    Args:
        config: Task configuration. Reads ``hugg_chk``, ``dataset_name``,
            ``base_dir`` and ``max_seq_len``.
        force_preproc: Forwarded to the data processor's ``tokenize`` call
            (as ``force_preproc`` for tiny-stories, as ``force_tok`` for
            bookcorpus / owt / wiki103). Not used for ``"pajama"``, whose
            tokenized data is loaded directly from disk.

    Returns:
        A ``(dp, tokenizer, raw_data, tok_data)`` tuple. ``raw_data`` is
        ``None`` for ``"pajama"``.

    Raises:
        ValueError: If ``config.dataset_name`` is not a supported dataset.
    """
    path_name = config.hugg_chk if config.hugg_chk else "allenai/longformer-base-4096"
    print("Path Name is: ", path_name)
    gemma_tokenizer_name = "google/gemma-2-2b"
    dataset_name = config.dataset_name

    if dataset_name == "tiny-stories":
        tokenizer = _load_tokenizer(config, path_name)
        # NOTE(review): "tiny-sories" looks like a typo; it is kept verbatim so
        # that caches written by earlier runs are still found.
        cache_dir = os.path.join(config.base_dir, "input", "tiny-sories")
        dp = TinyStories(
            tokenizer=tokenizer,
            cache_dir=cache_dir,
            max_seq_len=config.max_seq_len,
            # os.cpu_count() may return None; always use at least one process.
            num_load_procs=max(1, (os.cpu_count() or 1) - 1),
        )
        raw_data = dp.get_raw_data()
        tok_data = dp.tokenize(raw_data, force_preproc=force_preproc)
    elif dataset_name == "pajama":
        tokenizer = _load_tokenizer(config, gemma_tokenizer_name)
        # We directly load the tokenized data; it is slow to tokenize on the fly.
        # NOTE(review): hard-coded absolute path, not derived from base_dir.
        raw_data = None
        tok_data = DatasetDict.load_from_disk(
            "/user_all_data/data/input/pajama_tokenized"
        )
        tok_data.set_format("torch")
        # Dummy data processor, only used for its collate_fn.
        dp = SlimPajama(
            tokenizer,
            cache_dir=None,
            num_load_procs=None,
            max_seq_len=config.max_seq_len,
        )
    elif dataset_name == "bookcorpus":
        # Only for evaluation purposes.
        tokenizer = _load_tokenizer(config, gemma_tokenizer_name)
        cache_dir = os.path.join(config.base_dir, "input/bookcorpus")
        dp = BookCorpusLong(
            tokenizer,
            cache_dir=cache_dir,
            num_load_procs=8,
            max_seq_len=config.max_seq_len,
        )
        raw_data = dp.get_raw_data()
        # Honour the caller's flag so that non-main ranks can reuse the output
        # of the main rank instead of re-tokenizing.
        tok_data = dp.tokenize(raw_data, force_tok=force_preproc)
        tok_data.set_format("torch")
    elif dataset_name == "owt":
        # Only for evaluation purposes.
        tokenizer = _load_tokenizer(config, gemma_tokenizer_name)
        cache_dir = os.path.join(config.base_dir, "input", "test_openweb")
        dp = OpenWeb(
            tokenizer=tokenizer,
            cache_dir=cache_dir,
            num_load_procs=8,
            max_seq_len=config.max_seq_len,
        )
        raw_data = dp.get_raw_data()
        # Keep only a small train sample, for compatibility reasons; cap the
        # sample size so that a short split does not raise.
        train_split = raw_data["train"]
        raw_data["train"] = train_split.select(
            np.arange(min(1000, len(train_split)))
        )
        tok_data = dp.tokenize(raw_data, force_tok=force_preproc)
        tok_data.set_format("torch")
    elif dataset_name == "wiki103":
        tokenizer = _load_tokenizer(config, gemma_tokenizer_name)
        cache_dir = os.path.join(config.base_dir, "input", "test_wikitext-103")
        dp = Wiki103DP(
            tokenizer=tokenizer,
            cache_dir=cache_dir,
            max_seq_len=config.max_seq_len,
        )
        raw_data = dp.get_raw_data()
        tok_data = dp.tokenize(
            raw_data, add_special_tokens=False, force_tok=force_preproc
        )
    else:
        raise ValueError(
            f"Unknown dataset_name {dataset_name!r}; expected one of "
            "'tiny-stories', 'pajama', 'bookcorpus', 'owt', 'wiki103'."
        )
    return dp, tokenizer, raw_data, tok_data


def get_from_hugg_config(config: LMTaskConfig):
    """Override the architecture fields of ``config`` from a Hugging Face config.

    The Hugging Face config is loaded from ``config.hugg_chk`` (cached under
    ``<config.base_dir>/input/cache_hugg``) and its values are passed to
    ``config.replace``.

    Args:
        config: Task configuration providing ``hugg_chk``, ``base_dir`` and a
            ``replace`` method.

    Returns:
        The result of ``config.replace(...)`` with the overridden fields
        (presumably an updated copy -- confirm against ``LMTaskConfig``).
    """
    from transformers import AutoConfig

    cache_location = Path(config.base_dir) / "input/cache_hugg"
    hf_cfg = AutoConfig.from_pretrained(config.hugg_chk, cache_dir=cache_location)

    # Task-config field name -> attribute name on the Hugging Face config.
    field_to_hf_attr = {
        "hidden_dim": "hidden_size",
        "intermediate_dim": "intermediate_size",
        "hidden_act": "hidden_act",
        "nlayers": "num_hidden_layers",
        "num_key_value_heads": "num_key_value_heads",
        "pos_embed_max_len": "max_position_embeddings",
        "nheads": "num_attention_heads",
        "head_dim": "head_dim",
        "dropout_att": "attention_dropout",
        "attention_bias": "attention_bias",
        "text_vocab_size": "vocab_size",
    }
    overrides = {}
    for field_name, hf_attr in field_to_hf_attr.items():
        overrides[field_name] = getattr(hf_cfg, hf_attr)
    return config.replace(**overrides)


def get_pretrained_model(config, dp):
    tokenizer = dp.tokenizer
    match config.block_type:
        case "llama":
            model = HuggWrapper(
                LlamaForCausalLM.from_pretrained(
                    config.hugg_chk,
                    cache_dir=Path(config.base_dir) / "input/cache_hugg",
                ),
                pad_id=tokenizer.pad_token_id,
            )
        case "gemma-hugg":
            model = Gemma2Model.from_pretrained(
                config.hugg_chk,
                attn_implementation="eager",
                cache_dir=Path(config.base_dir) / "input/cache_hugg",
                tie_word_embeddings=True,
            )
            model = HuggWrapper2(
                model,
                embed=model.get_input_embeddings(),
                pad_id=tokenizer.pad_token_id,
            )
        case "gemma":
            from latte_trans.models.tasks.lm_pret_pt import Gemma

            weights = Gemma2Model.from_pretrained(
                config.hugg_chk,
                attn_implementation="eager",
                cache_dir=Path(config.base_dir) / "input/cache_hugg",
                tie_word_embeddings=True,
            ).state_dict()
            config = get_from_hugg_config(config)
            model = Gemma(config, pad_id=tokenizer.pad_token_id)
            weights = model._load(weights, model.state_dict())
            model.load_state_dict(weights)
    return model


class LMPretTask(BaseTask):
    """Train or evaluate a pretrained language model described by a task config.

    On construction the output directory
    ``<config.base_dir>/out_latte/<config.name>`` is created, the config is
    dumped there as ``config.json`` and an ``Accelerator`` is instantiated.
    """

    def __init__(self, config) -> None:
        """Store ``config``, prepare the output directory and the accelerator.

        Args:
            config: Task configuration; it must be a dataclass instance because
                it is serialized with ``dataclasses.asdict``.
        """
        self.config = config
        self.report_to = "none"
        self.wandb_run = None

        self.out_dir = os.path.join(self.config.base_dir, "out_latte", self.config.name)
        os.makedirs(self.out_dir, exist_ok=True)
        # Dump the config file in the model dir to ease debugging.
        with open(os.path.join(self.out_dir, "config.json"), "w") as f:
            json.dump(dataclasses.asdict(config), f)

        self._accelerator = Accelerator()
        if self._accelerator.is_main_process:
            print("Config is ", config)

    def safe_data_load(self):
        """Load the data so that the main process goes first.

        The main process runs ``get_dp`` first, forcing pre-processing when
        ``config.disable_cache`` is set. The other processes wait for it and
        then call ``get_dp`` with ``force_preproc=False``, since the data was
        already pre-processed.

        Returns:
            The ``(dp, tokenizer, raw_data, tokenized_data)`` tuple from ``get_dp``.
        """
        is_main = self._accelerator.is_main_process
        force_preproc = self.config.disable_cache if is_main else False
        # Non-main processes block on entry; the main process releases them
        # on exit, after its own get_dp call has finished.
        with self._accelerator.main_process_first():
            return get_dp(self.config, force_preproc=force_preproc)

    def _build_eval_and_trainer(self):
        """Create the evaluator and trainer shared by ``train`` and ``evaluate``.

        Returns:
            A ``(trainer, evaluator, tokenized_data)`` tuple.
        """
        dp, tokenizer, _raw_data, tokenized_data = self.safe_data_load()
        data_collator = dp.get_collate_fn(return_type="torch")
        model = get_pretrained_model(self.config, dp)

        evaluator = LanguageEvaluator(
            model,
            tokenizer,
            tokenized_data["validation"],
            data_collator=data_collator,
            config=self.config,
        )
        trainer = Trainer(
            config=self.config,
            out_dir=self.out_dir,
            model=model,
            train_data=tokenized_data["train"],
            data_collator=data_collator,
            evaluator=evaluator,
            # Keyword spelling ("orded") is the one this file has always passed.
            model_inputs_orded=("input_ids", "labels"),
        )
        return trainer, evaluator, tokenized_data

    def train(self):
        """Train the model, passing ``config.check_path`` to the trainer when set."""
        trainer, _evaluator, _tokenized_data = self._build_eval_and_trainer()
        if self.config.check_path is not None:
            trainer.train(self.config.check_path)
        else:
            trainer.train()

    def evaluate(self):
        """Evaluate the model on the validation split and print the scores."""
        trainer, evaluator, tokenized_data = self._build_eval_and_trainer()
        print(tokenized_data)

        trainer.prepare_train()
        scores = evaluator.evaluate(
            trainer_eval_fn=trainer.eval_step,
            prefix="eval_",
            accelerator=trainer._accelerator,
        )
        print("Scores loaded: ", scores)


def main():
    """Parse the command line, load the task config and run train or evaluate."""
    cli = parse_args()
    config = LMTaskConfig.load(
        yaml_file=cli.config_file,
        base_dir=cli.base_dir,
        name=cli.name,
    )

    # Turn off the `datasets` caching when the config asks for it.
    if config.disable_cache:
        disable_caching()

    task = LMPretTask(config)
    action = task.evaluate if cli.evaluate else task.train
    action()


# Run the task only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
