import argparse
from typing import Dict, Tuple
import functools
import jax
from jax import numpy as jnp
import logging

from transformers import BatchEncoding
from datasets import disable_caching
from latte_trans.preproc.copy import get_tokenizer, get_train_dataset, get_eval_dataset
from latte_trans.trainer.jax_single_host import Trainer
from latte_trans.models.tasks.copy import LMHeadCopy
from latte_trans.experiments.base import BaseTask
from latte_trans.config import CopyTaskConfig
from latte_trans.experiments.utils import parse_args
from latte_trans.evals.copy import CopyEvaluator

# Root logging configuration for this script: timestamp with milliseconds,
# padded level name, and the emitting file/line ahead of the message.
_LOG_FORMAT = (
    "%(asctime)s,%(msecs)d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s"
)
_LOG_DATE_FORMAT = "%Y-%m-%d:%H:%M:%S"

logging.basicConfig(
    level=logging.INFO,
    datefmt=_LOG_DATE_FORMAT,
    format=_LOG_FORMAT,
)

# Module-level logger used throughout this file.
LOG = logging.getLogger(__name__)

# Debugging aid (disabled): uncomment to pin JAX to the CPU platform.
# jax.config.update('jax_platform_name', 'cpu')


def eval_step(
    batchnorm: bool,
    model_rng: jax.Array,
    state,
    batch: Tuple[jax.Array, ...],
) -> Dict[str, jax.Array]:
    """Run a single forward pass of the model in evaluation mode.

    Args:
        batchnorm: When true, ``state.batch_stats`` is passed to the model as
            the ``"batch_stats"`` collection and that collection is declared
            mutable for the call; any updates returned are discarded.
        model_rng: Accepted for interface compatibility only. The forward
            pass is invoked with ``train=False`` and no RNG is handed to
            ``apply_fn``, so this key is not consumed.
        state: Object exposing ``apply_fn``, ``params`` and, when
            ``batchnorm`` is true, ``batch_stats``.
        batch: Positional model inputs, unpacked as ``*batch`` into
            ``state.apply_fn``.

    Returns:
        Whatever ``state.apply_fn`` returns for the batch (the first element
        of the ``(output, updates)`` pair in the batch-norm case).
    """
    # The original derived a dropout key from ``model_rng`` here but never
    # used it; evaluation does not need one.
    del model_rng

    inputs = jax.lax.stop_gradient(batch)
    variables = {"params": state.params}

    if not batchnorm:
        return state.apply_fn(variables, *inputs, train=False)

    variables["batch_stats"] = state.batch_stats
    # With a mutable collection the call returns ``(output, updates)``; the
    # statistics updates are intentionally dropped during evaluation.
    output, _ = state.apply_fn(
        variables,
        *inputs,
        train=False,
        mutable=["batch_stats"],
    )
    return output


class CopyTask(BaseTask):
    """Train and evaluate an ``LMHeadCopy`` model on the copy task."""

    # Batch fields, in the positional order they are handed to the model.
    _MODEL_INPUT_KEYS = ("input_ids", "mask", "labels")

    # Checkpoint directory used by ``evaluate`` when the caller supplies none.
    _DEFAULT_EVAL_CHECKPOINT_DIR = (
        "/home/user/latte_trans/data/out_latte/copy_50_latte_vapor128/checkpoints"
    )

    def get_model(self, tokenizer, ignore_idx):
        """Build the ``LMHeadCopy`` model for this task's configuration.

        Args:
            tokenizer: Tokenizer whose length is used as the vocabulary size.
            ignore_idx: Token id forwarded to the model as ``ignore_index``.

        Returns:
            The constructed ``LMHeadCopy`` instance (``pad_id`` fixed to 0).
        """
        return LMHeadCopy(
            self.config,
            vocab_size=len(tokenizer),
            pad_id=0,
            ignore_index=ignore_idx,
        )

    def train(self, train_rng):
        """Train the model, optionally starting from ``config.check_path``.

        Args:
            train_rng: JAX PRNG key. It is split three ways; the first part
                drives ``Trainer.train`` and the second is given to the
                ``Trainer`` constructor as ``rng``.
        """
        tokenizer, to_token, _ = get_tokenizer(self.config)
        train_dataset = get_train_dataset(self.config, tokenizer)
        # The id of the "*" token is what the model is told to ignore.
        model = self.get_model(tokenizer, ignore_idx=to_token["*"])

        # Keep the three-way split so the derived keys match earlier runs,
        # even though the third key is not needed here.
        train_rng, init_rng, _ = jax.random.split(train_rng, 3)

        evaluator = CopyEvaluator(
            model=model, tokenizer=tokenizer, TO_TOKEN=to_token, config=self.config
        )
        trainer = Trainer(
            config=self.config,
            out_dir=self.out_dir,
            model=model,
            train_data=None,
            train_dl=train_dataset,
            data_collator=None,
            evaluator=evaluator,
            wandb_run=self.wandb_run,
            rng=init_rng,
            model_inputs_orded=self._MODEL_INPUT_KEYS,
        )

        check_path = self.config.check_path
        if check_path is not None:
            trainer.train(train_rng, check_path)
        else:
            trainer.train(train_rng)

    def evaluate(self, train_rng, chk_path=None):
        """Load a checkpoint and score it with ``CopyEvaluator``.

        Args:
            train_rng: JAX PRNG key. It is split three ways; the second part
                initialises the zero state and the third is bound into the
                evaluation callback.
            chk_path: Directory containing the checkpoints to load. Defaults
                to ``_DEFAULT_EVAL_CHECKPOINT_DIR`` when ``None``.

        Returns:
            The scores produced by ``evaluator.evaluate`` (also printed).
        """
        if chk_path is None:
            chk_path = self._DEFAULT_EVAL_CHECKPOINT_DIR

        input_keys = self._MODEL_INPUT_KEYS
        jitted_eval_step = jax.jit(eval_step, static_argnums=(0,))

        def trainer_eval(
            batchnorm,
            state,
            rng: jax.Array,
            batch: Dict[str, jax.Array],
        ) -> Tuple[jax.Array, Dict[str, jax.Array]]:
            """Run the jitted eval step on ``batch`` and fetch results to host.

            Returns the batch labels and the model output, both passed
            through ``jax.device_get``.
            """
            inputs = jax.lax.stop_gradient(tuple(batch[k] for k in input_keys))
            output = jitted_eval_step(batchnorm, rng, state, inputs)
            return jax.device_get(batch["labels"]), jax.device_get(output)

        # Three-way split retained so ``init_rng``/``eval_rng`` are unchanged.
        _, init_rng, eval_rng = jax.random.split(train_rng, 3)

        tokenizer, to_token, _ = get_tokenizer(self.config)
        model = self.get_model(tokenizer, ignore_idx=to_token["*"])
        evaluator = CopyEvaluator(
            model=model, tokenizer=tokenizer, TO_TOKEN=to_token, config=self.config
        )

        # All-zero int32 placeholder inputs, one (2, max_seq_len) array per
        # model input, used only to create the state the checkpoint is
        # restored into.
        dummy_inputs = tuple(
            jnp.zeros((2, self.config.max_seq_len), dtype=jnp.int32)
            for _ in input_keys
        )
        zero_state = Trainer.create_zero_state(
            init_rng, dummy_inputs, self.config, model
        )
        # NOTE(review): step_number=None presumably selects the latest
        # checkpoint -- confirm against Trainer.load_trainer_state.
        state, _ = Trainer.load_trainer_state(
            zero_state, check_dir=chk_path, step_number=None
        )

        eval_fn = functools.partial(
            trainer_eval, self.config.batchnorm, state, eval_rng
        )
        scores = evaluator.evaluate(
            trainer_eval_fn=eval_fn, prefix="eval_", state=state
        )
        print(scores)
        return scores


def main():
    """Parse CLI arguments, load the copy-task config and run training.

    The configuration is loaded through ``CopyTaskConfig.load`` from the
    parsed ``config_file``, ``base_dir`` and ``name`` arguments. When the
    config sets ``disable_cache``, dataset caching is turned off before the
    task is created.
    """
    args = parse_args()
    config = CopyTaskConfig.load(
        yaml_file=args.config_file, base_dir=args.base_dir, name=args.name
    )

    if config.disable_cache:
        LOG.info("Disabling Cache")
        disable_caching()

    seed = 0
    # Only the middle key of the three-way split is used; the split size is
    # kept so the training key stays identical to earlier runs.
    _, train_rng, _ = jax.random.split(jax.random.PRNGKey(seed), 3)

    task = CopyTask(config)
    task.train(train_rng)
    # To score a saved checkpoint instead, call: task.evaluate(train_rng)


if __name__ == "__main__":
    # Example invocation, kept verbatim from the original note.
    # NOTE(review): main() reads args.config_file / args.base_dir / args.name,
    # so the flags below may be stale -- confirm against parse_args.
    #
    #   pdm run python3 -m latte_trans.copy --model "T_rope" --train_task "copy" --eval_task  "copy" --min_train_len 5 --max_train_len 20 --min_eval_len 20 --max_eval_len 20
    main()
