import argparse
import numpy as np
import yaml, os
import torch
from time import time
import datetime
from zoneinfo import ZoneInfo

# from trainer import Trainer
from loader.data_collator import PermutationExperimentDataCollator, DynamicPrefixTargetPermutationCollator
from loader.model import load_model
from trainer.trainer_utils import compute_metrics, preprocess_logits_for_metrics, LimitStepsCallback
from utils.utils import count_cuda_devices
from torch.utils.data import Dataset
from transformers import AutoTokenizer
from data.tokenizers import set_vocab, set_tokenizer
from typing import List, Dict, Union


import warnings
import logging

# Module-level setup: silence two known-noisy warnings, configure logging,
# set process environment variables, and seed the random number generators.
# All of this runs as a side effect of importing/executing this file.

# Suppress warnings whose message starts with these texts.
warnings.filterwarnings("ignore", message="Was asked to gather along dimension 0")
warnings.filterwarnings("ignore", message="The PyTorch API of nested tensors is in prototype stage")
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

import os

# Turn off the tokenizers library's own parallelism
# (presumably to avoid clashes with DataLoader worker processes -- TODO confirm).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# cuBLAS workspace setting; the PyTorch reproducibility notes list this value as
# one of the options needed for deterministic cuBLAS behaviour on CUDA.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

# Fix the random seeds for reproducibility.
torch.use_deterministic_algorithms(True)
seed = 42
np.random.seed(seed)
torch.manual_seed(seed)


class TextContinuationDataset(Dataset):
    """Dataset of raw text lines read from a UTF-8 file.

    Every non-blank line of ``file_path`` becomes one example. When
    ``data_has_colon_separator`` is true, a line is split at its first ``:``
    and stored as ``"<input> : <target>"`` with both halves stripped; lines
    lacking a colon are logged and stored unchanged. Tokenization is left to
    the data collator, so items are returned as plain text.
    """

    def __init__(
        self, tokenizer: AutoTokenizer, file_path: str, max_length: int, data_has_colon_separator: bool = True
    ):
        """Read ``file_path`` and collect its examples.

        Args:
            tokenizer: Stored on the instance; not used while loading.
            file_path: Path of the text file to read.
            max_length: Stored on the instance; not used while loading.
            data_has_colon_separator: Whether lines follow the
                ``input : target`` layout.

        Raises:
            FileNotFoundError: If ``file_path`` does not exist (logged first).
            Exception: Any other read error is logged and re-raised.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data_has_colon_separator = data_has_colon_separator
        # Full text strings; tokenization happens later in the collator.
        self.examples = []
        logger.info(f"Loading data from {file_path}")
        try:
            with open(file_path, "r", encoding="utf-8") as handle:
                for line_no, raw in enumerate(handle, start=1):
                    text = raw.strip()
                    if text:
                        self.examples.append(self._normalise(text, line_no, file_path))
            logger.info(f"Loaded {len(self.examples)} examples from {file_path}.")
        except FileNotFoundError:
            logger.error(f"Data file not found: {file_path}")
            raise
        except Exception as e:
            logger.error(f"Error reading data file {file_path}: {e}")
            raise

    def _normalise(self, text, line_no, file_path):
        """Return the stored form of one stripped, non-empty line."""
        if not self.data_has_colon_separator:
            return text

        # Split on the first colon only, so the target part may contain colons.
        head, colon, tail = text.partition(":")
        if colon:
            return f"{head.strip()} : {tail.strip()}"

        logger.warning(
            f"Line {line_no} in {file_path} does not contain ':' separator. Treating as full line: '{text}'"
        )
        # Without a colon the whole line is kept as a single sequence.
        return text

    def __len__(self):
        """Number of loaded examples."""
        return len(self.examples)

    def __getitem__(self, idx):
        """Return ``{"text": ...}`` for example ``idx`` with each " : " collapsed to one space."""
        stored = self.examples[idx]
        text = stored.replace(" : ", " ") if " : " in stored else stored
        return {"text": text}


def _parse_bool_flag(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` treats every non-empty string (including
    ``"False"`` and ``"0"``) as ``True``; this converter interprets the usual
    spellings instead.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in {"true", "t", "yes", "y", "1", "on"}:
        return True
    # The empty string stays falsy, as it was under ``bool("")``.
    if token in {"false", "f", "no", "n", "0", "off", ""}:
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def get_parser():
    """
    Generate a parameters parser.

    Returns:
        argparse.ArgumentParser: Parser declaring every command-line option of
        this script. ``--permutation_str`` defaults to ``None`` (meaning the
        identity permutation), and ``--fp16`` / ``--cpu`` accept explicit
        boolean spellings such as ``true`` / ``false``.
    """
    # parse parameters
    parser = argparse.ArgumentParser(description="Language transfer")

    # main parameters
    parser.add_argument(
        "--data_path",
        type=str,
        default="./data/data_sum",
        help="Path prefix of the dataset files ('.train' and '.test' are appended)",
    )
    parser.add_argument("--data_encoding", type=str, default="infix")
    parser.add_argument("--save_path", type=str, default="./dumped", help="Experiment dump path")
    parser.add_argument("--save_periodic", type=int, default=0, help="Save the model periodically (0 to disable)")
    parser.add_argument("--exp_name", type=str, default="debug", help="Experiment name")
    parser.add_argument("--exp_id", type=str, default="", help="Experiment ID")
    parser.add_argument("--task", type=str, default="sum", help="Task name")

    # float16 / AMP API
    parser.add_argument("--fp16", type=_parse_bool_flag, default=True, help="Run model with float16")
    parser.add_argument(
        "--amp",
        type=int,
        default=2,
        help="Use AMP wrapper for float16 / distributed / gradient accumulation. Level of optimization. -1 to disable.",
    )

    # model parameters
    parser.add_argument("--model", type=str, default="gpt2", help="Model identifier (e.g. gpt2)")
    parser.add_argument("--d_model", type=int, default=512, help="Embedding layer size")
    parser.add_argument("--dim_feedforward", type=int, default=2048, help="feedforward layer size")
    parser.add_argument("--num_encoder_layers", type=int, default=6, help="Number of Transformer layers in the encoder")
    parser.add_argument("--num_decoder_layers", type=int, default=6, help="Number of Transformer layers in the decoder")
    parser.add_argument("--nhead", type=int, default=8, help="Number of Transformer heads")
    parser.add_argument("--dropout", type=float, default=0, help="Dropout")
    parser.add_argument("--attention_dropout", type=float, default=0, help="Dropout in the attention layer")
    parser.add_argument("--encoding_method", type=str, default="standard")
    parser.add_argument("--token_register_size", type=int, default=2)
    parser.add_argument("--positional_encoding", type=str, default="sinusoidal", choices=["sinusoidal", "embedding"])

    # vocab and tokenizer parameters
    parser.add_argument("--num_variables", type=int, default=2)
    parser.add_argument("--field", type=str, default="QQ", help="QQ or GFP with some integer P (e.g., GF7).")
    parser.add_argument("--max_coefficient", type=int, default=1000, help="The maximum coefficients")
    parser.add_argument("--max_degree", type=int, default=10, help="The maximum degree")
    parser.add_argument("--gaussian_encoding_upper_bound", type=int, default=10, help="For Gaussian embedding.")

    # training parameters
    parser.add_argument("--max_sequence_length", type=int, default=4096, help="Maximum sequences length")
    parser.add_argument("--num_batch", type=int, default=128, help="Number of sentences per batch")
    parser.add_argument("--test_batch_size", type=int, default=256, help="Number of sentences per batch")
    parser.add_argument(
        "--optimizer", type=str, default="adamw_apex_fused", help="Optimizer (SGD / RMSprop / Adam, etc.)"
    )
    parser.add_argument("--learning_rate", type=float, default=0.0001, help="learning rate (default 0.0001)")
    parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay (default 0)")
    parser.add_argument("--clip_grad_norm", type=float, default=5, help="Clip gradients norm (0 to disable)")
    parser.add_argument("--epochs", type=int, default=5, help="number of epochs")
    parser.add_argument("--max_steps_per_epoch", type=int, default=-1, help="Maximum epoch size")
    parser.add_argument("--warmup_ratio", type=float, default=0.0)
    parser.add_argument(
        "--stopping_criterion",
        type=str,
        default="",
        help="Stopping criterion, and number of non-increase before stopping the experiment",
    )
    parser.add_argument("--validation_metrics", type=str, default="", help="Validation metrics")
    parser.add_argument(
        "--accumulate_gradients",
        type=int,
        default=1,
        help="Accumulate model gradients over N iterations (N times larger batch sizes)",
    )
    parser.add_argument("--num_workers", type=int, default=8, help="Number of CPU workers for DataLoader")
    parser.add_argument("--deepspeed", type=str, default="")
    # environment parameters
    parser.add_argument("--env_name", type=str, default="char_sp", help="Environment name")

    parser.add_argument("--resume_from_checkpoint", action="store_true", default=False)

    # CPU / multi-gpu / multi-node
    parser.add_argument("--cpu", type=_parse_bool_flag, default=False, help="Run on CPU")
    parser.add_argument("--local_rank", type=int, default=-1, help="Multi-GPU - Local rank")
    parser.add_argument("--master_port", type=int, default=-1, help="Master port (for multi-node SLURM jobs)")

    parser.add_argument("--dryrun", action="store_true", default=False)

    ## Dev
    parser.add_argument("--regression_weights", nargs="*", type=float)
    parser.add_argument("--continuous_coefficient", action="store_true", default=False)
    parser.add_argument("--continuous_exponent", action="store_true", default=False)
    parser.add_argument("--support_learning", action="store_true", default=False)
    parser.add_argument("--num_memory_tokens", type=int, default=16, help="Number of memory tokens")
    parser.add_argument("--use_memory_transformer", action="store_true", default=False)
    parser.add_argument("--use_fix_decoder_self_attn", action="store_true", default=False)
    parser.add_argument("--sparsity_lambda", type=float, default=0.0)
    parser.add_argument("--target_len", type=int, default=0, help="Target length for the permutation experiment")
    # The default must be None (not the ``str`` type) so that an omitted flag
    # is recognised as "no permutation given".
    parser.add_argument(
        "--permutation_str",
        type=str,
        default=None,
        help="Comma-separated permutation of target positions (e.g. '2,0,1'); identity permutation when omitted",
    )

    return parser


def perms_to_tensor_list(perms_list_of_lists: List[List[int]], target_len: int) -> List[torch.Tensor]:
    """Turn index permutations into square 0/1 matrices.

    For each permutation ``p`` a ``(target_len, target_len)`` float32 matrix
    is built whose entry ``[i, p[i]]`` is 1.0 and every other entry is 0.0.

    Args:
        perms_list_of_lists: Permutations, each a list of column indices.
        target_len: Side length of every produced matrix.

    Returns:
        One matrix per input permutation, in input order.
    """

    def _as_matrix(perm: List[int]) -> torch.Tensor:
        # Start from all zeros and mark one column per row.
        mat = torch.zeros((target_len, target_len), dtype=torch.float32)
        for row, col in enumerate(perm):
            mat[row, col] = 1.0
        return mat

    return [_as_matrix(perm) for perm in perms_list_of_lists]


def main():
    """Train and evaluate a model for the permutation experiment.

    Steps, in order:
      1. Parse the command line and create ``save_path``.
      2. Build the vocabulary, tokenizer and model.
      3. Load ``<data_path>.train`` and ``<data_path>.test``.
      4. Build a single permutation matrix (from ``--permutation_str``, or the
         identity of size ``--target_len``) and the data collator.
      5. Dump the parameters to ``params.yaml`` inside ``save_path``.
      6. Train with the project's ``CustomTrainer`` while logging to wandb.
      7. Run ``evaluate_gpt2``, write ``generated.csv``, and save the merged
         train/test metrics.
    """
    import wandb

    from trainer.trainer import CustomTrainer as Trainer
    from trainer.trainer import CustomTrainingArguments as TrainingArguments

    parser = get_parser()
    params = parser.parse_args()

    os.makedirs(params.save_path, exist_ok=True)

    ## Load data
    # NOTE(review): the field is hard-coded to "ZZ"; ``params.field`` is not
    # consulted here -- confirm this is intended.
    vocab = set_vocab(
        params.num_variables,
        "ZZ",
        params.max_coefficient,
        params.max_degree,
    )
    tokenizer = set_tokenizer(vocab, params.max_sequence_length)
    model = load_model(params, vocab=vocab, tokenizer=tokenizer)

    trainset = TextContinuationDataset(tokenizer, f"{params.data_path}.train", params.max_sequence_length)
    testset = TextContinuationDataset(tokenizer, f"{params.data_path}.test", params.max_sequence_length)

    # Use the permutation given on the command line, otherwise the identity.
    # Anything that is not a non-blank string (None, or a non-string default)
    # counts as "not given", so the fallback cannot crash on ``.split``.
    perm_spec = params.permutation_str
    if isinstance(perm_spec, str) and perm_spec.strip():
        permutation = [int(token) for token in perm_spec.split(",")]
    else:
        permutation = list(range(params.target_len))
    permutations_list = perms_to_tensor_list([permutation], params.target_len)

    dc = PermutationExperimentDataCollator(
        tokenizer=tokenizer,
        permutations_list=permutations_list,
        input_prefix_len=params.target_len,
        apply_permutation_to_target_only=True,
        fixed_permutation_index=0,
        per_sample_permutation=False,
    )

    ## Save parameters
    with open(os.path.join(params.save_path, "params.yaml"), "w") as f:
        yaml.dump(vars(params), f)

    now = datetime.datetime.now(ZoneInfo("Asia/Tokyo"))
    datetime_str = now.strftime("%Y%m%d_%H%M%S")
    run_name = f"{params.exp_id}_{datetime_str}"

    # Guard the division below against a device count of zero
    # (assumes count_cuda_devices() can return 0 without CUDA -- TODO confirm).
    num_devices = max(1, count_cuda_devices())

    ## Set up trainer
    trainer_config = TrainingArguments(
        output_dir=params.save_path,
        num_train_epochs=params.epochs,
        max_steps_per_epoch=params.max_steps_per_epoch,
        logging_steps=50,
        save_total_limit=1,
        dataloader_pin_memory=False,
        bf16=True,
        eval_steps=100,
        label_names=["labels"],
        remove_unused_columns=False,
        per_device_train_batch_size=params.num_batch // num_devices,
        eval_strategy="steps",
        report_to="wandb",
        disable_tqdm=True,
        run_name=run_name,
    )

    limit_steps_callback = LimitStepsCallback(max_steps_per_epoch=params.max_steps_per_epoch)

    # No compute_metrics callable is handed to the trainer; accuracy comes
    # from the explicit evaluate_gpt2 call after training.
    trainer = Trainer(
        args=trainer_config,
        model=model,
        train_dataset=trainset,
        eval_dataset=testset,
        data_collator=dc,
        callbacks=[limit_steps_callback],
    )

    ## Run training
    wandb.init(project=params.exp_name, name=run_name, config=trainer_config)
    s = time()
    train_result = trainer.train()
    print(f"training time: [{time()-s:.1f} sec]")

    ## Evaluate
    acc, df_gen = trainer.evaluate_gpt2(tokenizer)
    df_gen.to_csv(os.path.join(params.save_path, "generated.csv"), index=False)

    # Merge training metrics, dataset-level test metrics and the accuracy
    # from the evaluation above into one record.
    metrics = train_result.metrics
    dataset_metrics = trainer.evaluate(metric_key_prefix="test")
    metrics.update(dataset_metrics)
    metrics["test/accuracy"] = acc

    trainer.save_metrics("all", metrics)
    wandb.finish()

if __name__ == "__main__":
    main()
