import argparse
import numpy as np
import yaml, os
import torch
from time import time
import datetime
from zoneinfo import ZoneInfo

# from trainer import Trainer
from loader.data import _load_data, SimpleDataCollator, PolynomialDataCollator, GPTDataCollator
from loader.model import load_model
from trainer.trainer_utils import compute_metrics, preprocess_logits_for_metrics, LimitStepsCallback
from utils.utils import count_cuda_devices


import warnings

# Suppress two specific library warnings (matched on the start of their message)
# so they do not flood the training logs.
for _message in (
    "Was asked to gather along dimension 0",
    "The PyTorch API of nested tensors is in prototype stage",
):
    warnings.filterwarnings("ignore", message=_message)
del _message  # keep the loop variable out of the module namespace


import os

# Process-wide environment switches, set at import time before training starts:
#   - TOKENIZERS_PARALLELISM="false" turns tokenizer-side parallelism off.
#   - CUBLAS_WORKSPACE_CONFIG is the cuBLAS workspace setting that PyTorch's
#     deterministic mode asks for on CUDA.
os.environ.update(
    {
        "TOKENIZERS_PARALLELISM": "false",
        "CUBLAS_WORKSPACE_CONFIG": ":4096:8",
    }
)

# Fix random seed: one shared seed for NumPy and PyTorch, with PyTorch restricted
# to deterministic algorithm implementations.
seed = 42
torch.use_deterministic_algorithms(True)
np.random.seed(seed)
torch.manual_seed(seed)


def _str2bool(value):
    """
    Convert a command-line string to a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True``
    (any non-empty string is truthy), so a flag could never be switched off
    from the command line. This converter accepts the usual spellings instead.
    The empty string maps to ``False``, which matches what ``bool("")`` gave.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"true", "t", "yes", "y", "1", "on"}:
        return True
    if normalized in {"false", "f", "no", "n", "0", "off", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value (true/false), got {value!r}")


def get_parser() -> argparse.ArgumentParser:
    """
    Generate a parameters parser.

    Returns:
        An ``argparse.ArgumentParser`` with every command-line option of the
        training script registered (paths, model size, vocabulary, training
        schedule, device selection and development switches). Parsing is left
        to the caller.
    """
    # parse parameters
    parser = argparse.ArgumentParser(description="Language transfer")

    # main parameters
    parser.add_argument(
        "--data_path",
        type=str,
        default="./data/data_sum",
        help="Dataset path prefix ('.train' and '.test' are appended)",
    )
    parser.add_argument("--data_encoding", type=str, default="infix")
    parser.add_argument("--save_path", type=str, default="./dumped", help="Experiment dump path")
    parser.add_argument("--save_periodic", type=int, default=0, help="Save the model periodically (0 to disable)")
    parser.add_argument("--exp_name", type=str, default="debug", help="Experiment name")
    parser.add_argument("--exp_id", type=str, default="", help="Experiment ID")
    parser.add_argument("--task", type=str, default="sum", help="Task name")

    # float16 / AMP API
    # NOTE: booleans go through _str2bool so that "--fp16 False" really yields False.
    parser.add_argument("--fp16", type=_str2bool, default=True, help="Run model with float16")
    parser.add_argument(
        "--amp",
        type=int,
        default=2,
        help="Use AMP wrapper for float16 / distributed / gradient accumulation. Level of optimization. -1 to disable.",
    )

    # model parameters
    parser.add_argument("--model", type=str, default="bart", help="Model architecture name (e.g., bart, gpt2)")
    parser.add_argument("--d_model", type=int, default=512, help="Embedding layer size")
    parser.add_argument("--dim_feedforward", type=int, default=2048, help="feedforward layer size")
    parser.add_argument("--num_encoder_layers", type=int, default=6, help="Number of Transformer layers in the encoder")
    parser.add_argument("--num_decoder_layers", type=int, default=6, help="Number of Transformer layers in the decoder")
    parser.add_argument("--nhead", type=int, default=8, help="Number of Transformer heads")
    parser.add_argument("--dropout", type=float, default=0, help="Dropout")
    parser.add_argument("--attention_dropout", type=float, default=0, help="Dropout in the attention layer")
    # parser.add_argument("--continuous_embedding_model", type=str, default='ffn')
    parser.add_argument("--encoding_method", type=str, default="standard")
    parser.add_argument("--token_register_size", type=int, default=2)
    parser.add_argument("--positional_encoding", type=str, default="sinusoidal", choices=["sinusoidal", "embedding"])

    # vocab and tokenizer parameters
    parser.add_argument("--num_variables", type=int, default=2)
    parser.add_argument("--field", type=str, default="QQ", help="QQ or GFP with some integer P (e.g., GF7).")
    parser.add_argument("--max_coefficient", type=int, default=1000, help="The maximum coefficients")
    parser.add_argument("--max_degree", type=int, default=10, help="The maximum degree")
    parser.add_argument("--gaussian_encoding_upper_bound", type=int, default=10, help="For Gaussian embedding.")

    # training parameters
    parser.add_argument("--max_sequence_length", type=int, default=4096, help="Maximum sequences length")
    parser.add_argument("--num_batch", type=int, default=128, help="Number of sentences per batch")
    parser.add_argument(
        "--test_batch_size", type=int, default=256, help="Number of sentences per evaluation batch"
    )
    parser.add_argument(
        "--optimizer", type=str, default="adamw_apex_fused", help="Optimizer (SGD / RMSprop / Adam, etc.)"
    )
    parser.add_argument("--learning_rate", type=float, default=0.0001, help="learning rate (default 0.0001)")
    parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay (default 0)")
    parser.add_argument("--clip_grad_norm", type=float, default=5, help="Clip gradients norm (0 to disable)")
    parser.add_argument("--epochs", type=int, default=5, help="number of epochs")
    parser.add_argument("--max_steps_per_epoch", type=int, default=-1, help="Maximum epoch size")
    parser.add_argument("--warmup_ratio", type=float, default=0.0)
    parser.add_argument(
        "--stopping_criterion",
        type=str,
        default="",
        help="Stopping criterion, and number of non-increase before stopping the experiment",
    )
    parser.add_argument("--validation_metrics", type=str, default="", help="Validation metrics")
    parser.add_argument(
        "--accumulate_gradients",
        type=int,
        default=1,
        help="Accumulate model gradients over N iterations (N times larger batch sizes)",
    )
    parser.add_argument("--num_workers", type=int, default=8, help="Number of CPU workers for DataLoader")
    parser.add_argument("--deepspeed", type=str, default="")
    # environment parameters
    parser.add_argument("--env_name", type=str, default="char_sp", help="Environment name")

    parser.add_argument("--resume_from_checkpoint", action="store_true", default=False)

    # CPU / multi-gpu / multi-node
    parser.add_argument("--cpu", type=_str2bool, default=False, help="Run on CPU")
    parser.add_argument("--local_rank", type=int, default=-1, help="Multi-GPU - Local rank")
    parser.add_argument("--master_port", type=int, default=-1, help="Master port (for multi-node SLURM jobs)")

    parser.add_argument("--dryrun", action="store_true", default=False)

    ## Dev
    parser.add_argument("--regression_weights", nargs="*", type=float)
    parser.add_argument("--continuous_coefficient", action="store_true", default=False)
    parser.add_argument("--continuous_exponent", action="store_true", default=False)
    parser.add_argument("--support_learning", action="store_true", default=False)
    parser.add_argument("--num_memory_tokens", type=int, default=16, help="Number of memory tokens")
    parser.add_argument("--use_memory_transformer", action="store_true", default=False)
    parser.add_argument("--use_fix_decoder_self_attn", action="store_true", default=False)
    parser.add_argument("--sparsity_lambda", type=float, default=0.0)

    return parser


def main():
    """
    Train and evaluate a model from the command-line configuration.

    Steps, in order:
      1. Parse the command line and load ``<data_path>.train`` / ``<data_path>.test``.
      2. With ``--dryrun``, truncate both datasets and redirect the output to a
         sibling ``dryrun`` directory / experiment name.
      3. Build the model and data collator for either the "standard" (token
         vocabulary) or the polynomial embedding, chosen by ``--encoding_method``.
      4. Dump the effective parameters to ``params.yaml`` in the save directory.
      5. Train with the project's custom trainer while logging to wandb.
      6. Run generation-based evaluation, write ``generated.csv`` and save the
         merged train/test metrics.

    NOTE(review): several parsed options (e.g. ``--learning_rate``,
    ``--weight_decay``, ``--warmup_ratio``, ``--test_batch_size``) are not
    forwarded to the training arguments built here — confirm that is intended.
    """
    # Imported lazily so that importing this module does not pull in wandb or
    # the trainer package.
    import wandb

    # from transformers import TrainingArguments
    from trainer.trainer import CustomTrainer as Trainer
    from trainer.trainer import CustomTrainingArguments as TrainingArguments

    parser = get_parser()
    params = parser.parse_args()

    ## Load data
    trainset = _load_data(f"{params.data_path}.train")
    testset = _load_data(f"{params.data_path}.test")
    # testset.input = testset.input[:16]
    # testset.target = testset.target[:16]

    if params.dryrun:
        # Tiny subsets so a full train/eval cycle finishes quickly.
        trainset.input = trainset.input[:100]
        trainset.target = trainset.target[:100]
        testset.input = testset.input[:10]
        testset.target = testset.target[:10]
        params.epochs = 10
        params.save_path = os.path.join(os.path.dirname(params.save_path), "dryrun")
        params.exp_name = "dryrun"

    # Create the output directory only now: the dry-run branch above may have
    # redirected save_path, and params.yaml is written into the final location.
    os.makedirs(params.save_path, exist_ok=True)

    ## Load model
    ### standard embedding
    if "standard" in params.encoding_method:
        from data.tokenizers import set_tokenizer, set_vocab

        vocab = set_vocab(
            params.num_variables,
            field=params.field,
            max_coeff=params.max_coefficient,
            max_degree=params.max_degree,
            continuous_coefficient=False,
            continuous_exponent=False,
        )
        tokenizer = set_tokenizer(vocab)
        model = load_model(params, vocab=vocab, tokenizer=tokenizer)
        # GPT-2 gets its own collator; every other model uses the simple one.
        dc = GPTDataCollator(tokenizer) if params.model == "gpt2" else SimpleDataCollator(tokenizer)
        label_names = ["labels"]

        tokenizer.save_pretrained(os.path.join(params.save_path, "tokenizer.json"))

    else:
        ### polynomial embedding
        # No tokenizer exists on this path. Bind the name explicitly so the
        # evaluation call below does not fail with UnboundLocalError.
        # NOTE(review): confirm that evaluate_gpt2 accepts tokenizer=None.
        tokenizer = None
        vocab_map = {"pad_token_id": 1, "bos_token_id": 2, "eos_token_id": 3, "sep_token_id": 4, "number_token_id": 0}
        model = load_model(params, vocab=vocab_map, tokenizer=None)
        dc = PolynomialDataCollator(
            num_variables=params.num_variables, method="monom-wise", vocab_map=vocab_map, tokenizer=None
        )
        label_names = ["labels", "labels_for_regression"]

    ## Save parameters
    with open(os.path.join(params.save_path, "params.yaml"), "w") as f:
        yaml.dump(vars(params), f)

    # Run name = experiment id + wall-clock timestamp (Tokyo time).
    now = datetime.datetime.now(ZoneInfo("Asia/Tokyo"))
    datetime_str = now.strftime("%Y%m%d_%H%M%S")
    run_name = f"{params.exp_id}_{datetime_str}"

    # --num_batch is split evenly across the visible CUDA devices. Clamp both
    # the device count and the result to at least 1 so that a device-less host
    # (count of 0) or more devices than samples cannot produce a division by
    # zero or a zero batch size.
    num_devices = max(1, count_cuda_devices())
    per_device_batch_size = max(1, params.num_batch // num_devices)

    ## Set up trainer
    trainer_config = TrainingArguments(
        output_dir=params.save_path,
        num_train_epochs=params.epochs,
        max_steps_per_epoch=params.max_steps_per_epoch,
        logging_steps=50,
        save_total_limit=1,
        dataloader_pin_memory=True,
        bf16=True,
        eval_steps=100,
        label_names=label_names,
        remove_unused_columns=False,
        per_device_train_batch_size=per_device_batch_size,
        eval_strategy="steps",
        report_to="wandb",
        disable_tqdm=True,
        run_name=run_name,
    )

    limit_steps_callback = LimitStepsCallback(max_steps_per_epoch=params.max_steps_per_epoch)

    # NOTE: token-level metrics are currently disabled. To re-enable them, pass
    #   compute_metrics=lambda x: compute_metrics(x, ignore_index=tokenizer.pad_token_id)
    #   preprocess_logits_for_metrics=preprocess_logits_for_metrics
    # to the Trainer below (requires a tokenizer, i.e. the standard embedding).
    trainer = Trainer(
        args=trainer_config,
        model=model,
        train_dataset=trainset,
        eval_dataset=testset,
        data_collator=dc,
        callbacks=[limit_steps_callback],
    )

    ## Run training
    wandb.init(project=params.exp_name, name=run_name, config=trainer_config)
    s = time()
    train_result = trainer.train()
    print(f"training time: [{time()-s:.1f} sec]")

    ## Evaluate
    # acc, df_gen = trainer.evaluate_test_greedy(tokenizer=tokenizer)
    acc, df_gen = trainer.evaluate_gpt2(tokenizer)
    df_gen.to_csv(os.path.join(params.save_path, "generated.csv"), index=False)

    # trainer.save_model()

    # Merge training metrics, test-set metrics and the generation accuracy
    # into a single record.
    metrics = train_result.metrics
    dataset_metrics = trainer.evaluate(metric_key_prefix="test")
    metrics.update(dataset_metrics)
    metrics["test/accuracy"] = acc

    trainer.save_metrics("all", metrics)
    wandb.finish()


# Script entry point: run the pipeline exactly once when executed directly.
if __name__ == "__main__":
    main()
