import argparse
import numpy as np
import yaml, os
import torch
from time import time
import datetime
from zoneinfo import ZoneInfo
from dataclasses import asdict

# Import custom classes
from loader.models import MyGPT2, NselectGPT2
from loader.data import _load_data, SimpleDataCollator, GPTDataCollator
from trainer import MyTrainer, MyTrainingArguments, LogPermutationCallback
from trainer.trainer_utils import (
    compute_metrics,
    preprocess_logits_for_metrics,
    LimitStepsCallback,
)
from utils.utils import count_cuda_devices
from transformers import GPT2Config, HfArgumentParser
from data.tokenizers import set_vocab, set_tokenizer

# Silence two known-noisy warnings, matched by their message text.
import warnings

for _noisy_message in (
    "Was asked to gather along dimension 0",
    "The PyTorch API of nested tensors is in prototype stage",
):
    warnings.filterwarnings("ignore", message=_noisy_message)
del _noisy_message  # do not leave the loop variable in the module namespace


# Process-wide environment switches, set before any training code runs.
os.environ.update(
    {
        "TOKENIZERS_PARALLELISM": "false",
        "CUBLAS_WORKSPACE_CONFIG": ":4096:8",
    }
)

# Fix the random seeds and request deterministic torch algorithms.
seed = 42
torch.use_deterministic_algorithms(True)
torch.manual_seed(seed)
np.random.seed(seed)


def _per_device_batch_size(total_batch_size, gpus):
    """Split a global batch size across ``gpus`` devices using floor division.

    When no GPU is visible (``gpus <= 0``) the global batch size is returned
    unchanged.
    """
    return total_batch_size // gpus if gpus > 0 else total_batch_size


def _apply_dry_run(training_args, train_dataset, eval_dataset):
    """Shrink the datasets and the training schedule for a quick smoke test.

    Mutates ``training_args`` in place (step budget, logging/eval/save cadence,
    ``output_dir`` and ``exp_name``), creates the dry-run output directory, and
    returns the ``(train_dataset, eval_dataset)`` pair to train on.
    """
    print("Dry run mode enabled. Subsetting data and adjusting training steps.")
    try:
        # Keep at most the first 10000 training and 100 evaluation examples.
        # Assumes both datasets are map-style (support len() and indexing).
        train_indices = list(range(min(10000, len(train_dataset))))
        eval_indices = list(range(min(100, len(eval_dataset))))
        train_dataset = torch.utils.data.Subset(train_dataset, train_indices)
        eval_dataset = torch.utils.data.Subset(eval_dataset, eval_indices)
        print(f"Dry run dataset sizes - Train: {len(train_dataset)}, Eval: {len(eval_dataset)}")
    except TypeError as e:
        # Deliberate best-effort: continue with whatever datasets we have.
        print(f"Failed to create dataset subsets for dry run, likely due to dataset type: {e}")

    # Control the dry run by step count instead of epochs.
    training_args.max_steps = 20  # Total steps for dry run (10 for stage1, 10 for stage2)
    training_args.num_train_epochs = None

    # Redirect output to a sibling "dryrun_permutation" directory; strip the
    # suffix first so resuming a dry run does not nest directories.
    original_output_dir = training_args.output_dir.replace("/dryrun_permutation", "")
    training_args.output_dir = os.path.join(os.path.dirname(original_output_dir), "dryrun_permutation")

    # Log every step and evaluate/save frequently so the whole loop is exercised.
    training_args.logging_steps = 1
    training_args.evaluation_strategy = "steps"
    training_args.eval_steps = 5
    training_args.save_steps = 5
    training_args.exp_name = "dryrun_permutation"

    os.makedirs(training_args.output_dir, exist_ok=True)
    return train_dataset, eval_dataset


def main():
    """Train and evaluate an ``NselectGPT2`` model configured from the command line.

    Steps, in order:
      1. Parse ``MyTrainingArguments`` from the command line.
      2. Derive per-device batch sizes from ``num_batch`` / ``test_batch_size``
         when the per-device values are not given.
      3. Load ``<data_path>.train`` and ``<data_path>.test``; optionally shrink
         everything when ``dry_run`` is enabled.
      4. Build the vocabulary, tokenizer, GPT-2 config and model from scratch.
      5. Write ``params.yaml``, train with ``MyTrainer``, evaluate, save the
         model/state, and log the merged metrics to wandb and
         ``all_results.yaml``.
    """
    # Imported here so that merely importing this module does not load wandb.
    import wandb

    # Parse every argument defined on MyTrainingArguments from sys.argv.
    hf_parser = HfArgumentParser((MyTrainingArguments,))
    training_args = hf_parser.parse_args_into_dataclasses()[0]

    # --- Batch sizes ---
    # An explicitly given per-device value always wins over the global one.
    gpus = count_cuda_devices()
    if training_args.num_batch is not None:
        if training_args.per_device_train_batch_size is None:
            print(
                f"Calculating per_device_train_batch_size from num_batch ({training_args.num_batch}) and gpu count ({gpus})"
            )
            training_args.per_device_train_batch_size = _per_device_batch_size(training_args.num_batch, gpus)
        else:
            print("Both num_batch and per_device_train_batch_size specified. Using per_device_train_batch_size.")
    if training_args.test_batch_size is not None:
        if training_args.per_device_eval_batch_size is None:
            print(
                f"Calculating per_device_eval_batch_size from test_batch_size ({training_args.test_batch_size}) and gpu count ({gpus})"
            )
            training_args.per_device_eval_batch_size = _per_device_batch_size(training_args.test_batch_size, gpus)
        else:
            print("Both test_batch_size and per_device_eval_batch_size specified. Using per_device_eval_batch_size.")

    # Ensure output_dir exists
    os.makedirs(training_args.output_dir, exist_ok=True)

    # --- Load Data ---
    print(f"Loading training data from: {training_args.data_path}.train")
    train_dataset = _load_data(f"{training_args.data_path}.train")
    print(f"Loading evaluation data from: {training_args.data_path}.test")
    eval_dataset = _load_data(f"{training_args.data_path}.test")
    print(f"Train dataset size: {len(train_dataset)}, Eval dataset size: {len(eval_dataset)}")

    # Respect a dry_run value parsed from the command line and default to False
    # only when it is absent. (Previously it was unconditionally reset to False,
    # which made the dry-run branch unreachable.)
    training_args.dry_run = bool(getattr(training_args, "dry_run", False))
    if training_args.dry_run:
        train_dataset, eval_dataset = _apply_dry_run(training_args, train_dataset, eval_dataset)

    # --- Tokenizer ---
    # Vocabulary and tokenizer are built from scratch with fixed settings.
    vocab = set_vocab(
        num_vars=0,
        field="ZZ",
        max_coeff=1000,
        max_degree=0,
        continuous_coefficient=False,
        continuous_exponent=False,
    )
    tokenizer = set_tokenizer(vocab)

    # Save tokenizer (path kept consistent with main.py).
    # NOTE(review): save_pretrained presumably treats its argument as a
    # directory, so "tokenizer.json" here would be a directory name — confirm.
    tokenizer.save_pretrained(os.path.join(training_args.output_dir, "tokenizer.json"))

    # --- Model config ---
    config = GPT2Config(
        vocab_size=len(tokenizer),
        n_embd=512,
        n_layer=1,
        n_head=1,
        attn_pdrop=0.0,
        # The context length is set through n_positions. A misspelled
        # `max_posion_embeddings` kwarg used to be passed as well; it matched no
        # config field and has been dropped.
        n_positions=training_args.max_sequence_length,
        n_inner=2048,
        pad_token_id=getattr(tokenizer, "pad_token_id", None),
        bos_token_id=getattr(tokenizer, "bos_token_id", None),
        eos_token_id=getattr(tokenizer, "eos_token_id", None),
    )

    # If the tokenizer has no pad token, make sure the config still gets a pad id.
    if tokenizer.pad_token is None:
        pad_id = getattr(tokenizer, "pad_token_id", None)
        if pad_id is not None:
            config.pad_token_id = pad_id
        else:
            # Fall back to the EOS id (or 1 if the tokenizer has no such attribute).
            config.pad_token_id = getattr(tokenizer, "eos_token_id", 1)
            print(f"Warning: Tokenizer pad_token not set. Using eos_token_id ({config.pad_token_id}) as pad_token_id.")

    # --- Model ---
    # Initialised from the config (no pretrained weights), then resized so the
    # embedding table matches the scratch tokenizer.
    model = NselectGPT2(config)
    model.resize_token_embeddings(len(tokenizer))

    # --- Data collator ---
    dc = GPTDataCollator(tokenizer)

    # --- Save parameters ---
    # Snapshot of the arguments as they stand at this point; also sent to wandb.
    all_params = asdict(training_args)
    with open(os.path.join(training_args.output_dir, "params.yaml"), "w") as f:
        yaml.dump(all_params, f)

    # The run name is always regenerated from exp_id plus a Tokyo-time
    # timestamp, replacing any run_name that was passed in.
    now = datetime.datetime.now(ZoneInfo("Asia/Tokyo"))
    datetime_str = now.strftime("%Y%m%d_%H%M%S")
    run_name = f"{training_args.exp_id}_{datetime_str}"
    training_args.run_name = run_name

    # Keep only the latest checkpoint (same policy as main.py).
    if training_args.save_total_limit is None or training_args.save_total_limit > 1:
        training_args.save_total_limit = 1

    # --- Trainer ---
    trainer = MyTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=dc,
        callbacks=[LogPermutationCallback()],
    )

    # --- Train ---
    wandb.init(project=training_args.exp_name, name=run_name, config=all_params)
    s = time()
    # Per the original note, stage switching is handled by the callback, so
    # train() is called normally. NOTE(review): confirm against the callback.
    train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
    print(f"Training time: [{time()-s:.1f} sec]")

    # --- Evaluate ---
    eval_metrics = trainer.evaluate()
    print(f"Evaluation results: {eval_metrics}")

    # Save final model and trainer state.
    trainer.save_model()
    trainer.save_state()

    # Merge training and evaluation metrics into one dict.
    metrics = train_result.metrics
    metrics.update(eval_metrics)

    # Only the main process logs and writes the final metrics.
    if trainer.is_world_process_zero():
        print("Logging final metrics...")
        wandb.log(metrics)
        try:
            with open(os.path.join(training_args.output_dir, "all_results.yaml"), "w") as f:
                yaml.dump(metrics, f)
        except Exception as e:
            # Best-effort: a failed metrics dump must not fail the finished run.
            print(f"Failed to save metrics to YAML: {e}")

    wandb.finish()


# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
