from typing import Dict, Tuple, Callable
from pathlib import Path
from functools import partial
import math
import os
import json
import logging
import wandb
import jax
from jax import numpy as jnp
from torch.utils.data import DataLoader
import numpy as np
from datasets import disable_caching
from transformers import (
    AutoTokenizer,
    AutoConfig,
)

from latte_trans.trainer.dfsp_jax import (
    DFSDPTrainer,
    gather_params,
    shard_params,
    shard_module_params,
    BatchNormTrainState,
    TrainState,
)
from latte_trans.experiments.utils import parse_args
from latte_trans.experiments.base import BaseTask
from latte_trans.preproc.scrolls import ScrollsDP
from latte_trans.config import LMTaskConfig
from latte_trans.evals.scrolls_eval import ScrollsEvaluator
from latte_trans.models.tasks.scrolls import DecoderOnlyScrolls

from flax import linen as nn
import optax
from flax.training import train_state
import flax

# Passed as ``min_weight_size`` to ``shard_module_params`` in ScrollTask.get_model.
# Presumably parameters smaller than this are left unsharded -- confirm against
# latte_trans.trainer.dfsp_jax.
MIN_SHARD_WEIGHT = 1 << 12  # == 4096

# Default length limits per SCROLLS sub-dataset, handed to ScrollsDP as
# ``max_seq_len`` (input side) and ``max_tgt_len`` (target side).
# ``max_seq_len`` may be overridden by a positive ``config.max_seq_len`` in get_dp.
# NOTE(review): units are presumably tokens -- confirm in ScrollsDP.
DATA_LENGTH = {
    "gov_report": {"max_seq_len": 8_000, "max_tgt_len": 500},
    "summ_screen_fd": {"max_seq_len": 6_000, "max_tgt_len": 128},
    "qmsum": {"max_seq_len": 9_500, "max_tgt_len": 100},
    "narrative_qa": {"max_seq_len": 52_000, "max_tgt_len": 16},
    "qasper": {"max_seq_len": 4_000, "max_tgt_len": 16},
    "quality": {"max_seq_len": 4_500, "max_tgt_len": 16},
    "contract_nli": {"max_seq_len": 2_000, "max_tgt_len": 8},
}

def _resolve_seq_lengths(dataset_name, max_seq_len_override, table=None):
    """Return ``(max_seq_len, max_tgt_len)`` for *dataset_name* without mutating *table*.

    Args:
        dataset_name: key into *table*.
        max_seq_len_override: when positive, replaces the table's default
            ``max_seq_len``; ``None`` or a non-positive value keeps the default.
        table: mapping of dataset name -> {"max_seq_len", "max_tgt_len"};
            defaults to the module-level ``DATA_LENGTH``.

    Raises:
        KeyError: if *dataset_name* is not in *table*.
    """
    if table is None:
        table = DATA_LENGTH
    limits = table[dataset_name]
    max_seq_len = limits["max_seq_len"]
    if max_seq_len_override is not None and max_seq_len_override > 0:
        max_seq_len = max_seq_len_override
    return max_seq_len, limits["max_tgt_len"]


def get_dp(config):
    """Build the tokenizer and SCROLLS data processor, then load and tokenize the data.

    The tokenizer comes from ``config.hugg_chk`` (or ``"google/gemma-2-2b"`` when
    that is empty), with right-side truncation and padding. Processed data is
    cached under ``<base_dir>/input/scrolls/<dataset_name>``; a positive
    ``config.max_seq_len`` overrides the dataset's default input length.

    Args:
        config: task config providing ``dataset_name``, ``hugg_chk``,
            ``base_dir``, ``max_seq_len`` and ``disable_cache``.

    Returns:
        ``(dp, tokenizer, raw_data, tok_data)``, where ``tok_data`` is the
        tokenized data formatted as numpy. The caller indexes ``raw_data`` and
        ``tok_data`` by split name (e.g. "train", "validation").
    """
    assert config.dataset_name in DATA_LENGTH, f"Dataset must be in one of {DATA_LENGTH.keys()}"
    path_name = config.hugg_chk if config.hugg_chk else "google/gemma-2-2b"
    print("Path Name is: ", path_name)
    tokenizer = AutoTokenizer.from_pretrained(
        path_name,
        cache_dir=Path(config.base_dir) / "input/cache_hugg",
        truncation_side="right",
        padding_side="right",
    )

    cache_dir = os.path.join(config.base_dir, "input", "scrolls", config.dataset_name)
    # Resolve the lengths locally rather than writing the override back into the
    # module-level DATA_LENGTH table, so one call's truncation cannot leak into
    # later calls made with a different config.
    max_seq_len, max_tgt_len = _resolve_seq_lengths(config.dataset_name, config.max_seq_len)

    dp = ScrollsDP(
        tokenizer=tokenizer,
        cache_dir=cache_dir,
        max_seq_len=max_seq_len,
        max_tgt_len=max_tgt_len,
        # NOTE(review): the original expression min(1, 4) always evaluated to 1;
        # perhaps min(os.cpu_count(), 4) was intended -- confirm before changing.
        num_load_procs=1,
    )
    # NOTE(review): sample=100 is hard-coded; if it limits how many examples are
    # loaded, it looks like a debugging leftover -- confirm in ScrollsDP.get_raw_data.
    raw_data = dp.get_raw_data(subdataset=config.dataset_name, sample=100)
    # disable_cache doubles as "force re-preprocessing".
    tok_data = dp.tokenize(raw_data, force_preproc=config.disable_cache).with_format("np")
    print("TOK data is: ", tok_data)
    return dp, tokenizer, raw_data, tok_data


def get_from_hugg_config(config: LMTaskConfig):
    """Copy architecture hyper-parameters from a Hugging Face config onto *config*.

    Loads ``AutoConfig.from_pretrained(config.hugg_chk)`` (cached under
    ``<base_dir>/input/cache_hugg``). It then copies these values into the
    task config:

    - hidden and intermediate sizes
    - activation
    - layer, KV-head and attention-head counts
    - head dim
    - maximum position embeddings
    - attention dropout and bias
    - vocabulary size

    Args:
        config: task config whose ``hugg_chk`` names the checkpoint to read.

    Returns:
        The config produced by ``config.replace(...)``.
    """
    hugg_config = AutoConfig.from_pretrained(
        config.hugg_chk,
        cache_dir=Path(config.base_dir) / "input/cache_hugg",
    )

    # Not every architecture's config defines ``head_dim`` (or it may be None);
    # fall back to the conventional hidden_size // num_attention_heads split.
    head_dim = getattr(hugg_config, "head_dim", None) or (
        hugg_config.hidden_size // hugg_config.num_attention_heads
    )

    config = config.replace(
        hidden_dim=hugg_config.hidden_size,
        intermediate_dim=hugg_config.intermediate_size,
        hidden_act=hugg_config.hidden_act,
        nlayers=hugg_config.num_hidden_layers,
        num_key_value_heads=hugg_config.num_key_value_heads,
        pos_embed_max_len=hugg_config.max_position_embeddings,
        nheads=hugg_config.num_attention_heads,
        head_dim=head_dim,
        dropout_att=hugg_config.attention_dropout,
        attention_bias=hugg_config.attention_bias,
        text_vocab_size=hugg_config.vocab_size,
    )
    # ``hidden_activation`` exists only on some configs (e.g. Gemma's). Read it
    # defensively so this diagnostic print cannot raise AttributeError for
    # other model families.
    print("ACT: ", getattr(hugg_config, "hidden_activation", None))
    return config

def update_config(config, vocab_size):
    """Set the vocabulary size and per-head dimension for a model built from scratch.

    Args:
        config: task config exposing ``hidden_dim``, ``nheads`` and ``replace``.
        vocab_size: the tokenizer's vocabulary size.

    Returns:
        ``config.replace(...)`` with ``text_vocab_size = vocab_size + 2`` and
        ``head_dim = hidden_dim // nheads`` (integer division).
    """
    # Two slots beyond the tokenizer vocabulary.
    # NOTE(review): the reason for +2 is not visible here -- confirm (extra special ids?).
    padded_vocab = vocab_size + 2
    per_head_dim = config.hidden_dim // config.nheads
    return config.replace(text_vocab_size=padded_vocab, head_dim=per_head_dim)


class ScrollTask(BaseTask):
    """Train a decoder-only model on one SCROLLS sub-dataset with DFSDPTrainer."""

    def get_model(self, dp: ScrollsDP, sharded: bool = True):
        """Build a ``DecoderOnlyScrolls`` model, optionally with sharded parameters.

        Args:
            dp: data processor; only its ``tokenizer`` is used (vocabulary size
                and padding id).
            sharded: when True, the model class is wrapped with
                ``shard_module_params`` on axis name ``"B"`` with
                ``min_weight_size=MIN_SHARD_WEIGHT``; the flag is also forwarded
                to the model constructor.

        Returns:
            The constructed model instance.
        """
        tokenizer = dp.tokenizer
        if self.config.hugg_chk:
            # Architecture comes from the pretrained checkpoint's config.
            config = get_from_hugg_config(self.config)
        else:
            # From-scratch model: size the vocabulary from the tokenizer.
            config = update_config(self.config, tokenizer.vocab_size)

        # Only "bf16" selects reduced precision; any other value means float32.
        dtype = jnp.bfloat16 if self.config.mixed_precision == "bf16" else jnp.float32

        if sharded:
            model_cls = shard_module_params(
                DecoderOnlyScrolls, axis_name="B", min_weight_size=MIN_SHARD_WEIGHT
            )
        else:
            model_cls = DecoderOnlyScrolls

        model = model_cls(
            config,
            pad_id=tokenizer.pad_token_id,
            dtype=dtype,
            sharded=sharded,
        )

        print(model)
        return model

    def train(self, train_rng):
        """Load the data, build a sharded model and run DFSDPTrainer.

        Evaluation uses ``ScrollsEvaluator`` on the validation split. Training
        resumes from ``config.check_path`` when that is set.

        Args:
            train_rng: JAX PRNG key, split here into train/init/eval keys.
        """
        dp, tokenizer, raw_data, tokenized_data = get_dp(self.config)
        data_collator = dp.get_collate_fn(return_type="np")

        model = self.get_model(dp, sharded=True)
        train_rng, init_rng, eval_rng = jax.random.split(train_rng, 3)

        # NOTE(review): the evaluator and trainer receive self.config, not the
        # model config derived in get_model -- confirm that is intended.
        evaluator = ScrollsEvaluator(
            tokenizer,
            tokenized_data["validation"],
            raw_data=raw_data["validation"],
            data_collator=data_collator,
            config=self.config,
            rng=eval_rng,
        )
        trainer = DFSDPTrainer(
            config=self.config,
            out_dir=self.out_dir,
            model=model,
            train_data=tokenized_data["train"],
            data_collator=data_collator,
            evaluator=evaluator,
            wandb_run=self.wandb_run,
            rng=init_rng,
            model_inputs_orded=("input_ids", "labels"),
        )
        if self.config.check_path is not None:
            trainer.train(train_rng, self.config.check_path)
        else:
            trainer.train(train_rng)

def main():
    """Script entry point: parse CLI args, load the task config, then train or evaluate.

    All PRNG keys derive from the fixed seed 0. With ``--evaluate`` the task's
    ``evaluate`` is called with a sampling key; otherwise ``train`` runs.
    """
    root_key = jax.random.PRNGKey(0)
    _, train_rng, sample_rng = jax.random.split(root_key, 3)

    args = parse_args()
    config = LMTaskConfig.load(
        yaml_file=args.config_file,
        base_dir=args.base_dir,
        name=args.name,
    )

    if config.disable_cache:
        print("Disabling Cache")
        # Turn off the datasets library's caching for this process.
        disable_caching()

    task = ScrollTask(config)
    if not args.evaluate:
        task.train(train_rng)
        return
    task.evaluate(sample_rng)

# Run the training/evaluation entry point only when executed as a script.
if __name__ == "__main__":
    main()
