import argparse
import numpy as np
import yaml, os
import torch
from time import time
import datetime
from zoneinfo import ZoneInfo
import json  # Added for saving metrics and printing accuracy
from utils.permutation_utils import (
    make_perm_family,
    generate_all_permutation_matrices,
    generate_random_permutation,
)  # Added for permutation

# from trainer import Trainer
from loader.data import _load_data
from loader.data_collator import PermutationExperimentDataCollator
from loader.model import load_model
from trainer.trainer_utils import compute_metrics, preprocess_logits_for_metrics, LimitStepsCallback
from utils.utils import count_cuda_devices
import sys
from transformers import AutoTokenizer
from torch.utils.data import Dataset


import warnings
import logging

# Module-level logger; basicConfig makes INFO-level messages (e.g. dataset loading) visible.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Silence two noisy warnings, matched by message prefix. The first appears to be the
# one torch emits when gathering scalar outputs across GPUs (e.g. under DataParallel);
# the second is PyTorch's "nested tensors are a prototype" notice.
warnings.filterwarnings("ignore", message="Was asked to gather along dimension 0")
warnings.filterwarnings("ignore", message="The PyTorch API of nested tensors is in prototype stage")


# NOTE(review): duplicates the `os` import at the top of the file; harmless but redundant.
import os

# Disable HuggingFace tokenizers' internal parallelism (avoids its warning/deadlock
# risk when DataLoader worker processes are forked).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Needed for deterministic cuBLAS when torch.use_deterministic_algorithms(True) is on
# (CUDA >= 10.2); it must be set before cuBLAS is first used.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

# Fix random seeds for reproducibility.
# With deterministic algorithms enabled, ops lacking a deterministic implementation raise an error.
torch.use_deterministic_algorithms(True)
seed = 42
np.random.seed(seed)
# torch.manual_seed seeds the RNG on all devices.
# NOTE(review): Python's built-in `random` module is not seeded here.
torch.manual_seed(seed)


# Simplified Dataset for text files (one sequence per line)
class TextContinuationDataset(Dataset):
    """Line-oriented text dataset that yields raw strings for a data collator.

    Every non-blank line of ``file_path`` (after stripping) becomes one example.
    When ``data_has_colon_separator`` is true, a line is split on its *first*
    colon and stored as ``"INPUT : TARGET"`` with both halves stripped. Lines
    without any colon are kept verbatim after a warning. Tokenization is left
    to the collator, so ``tokenizer`` and ``max_length`` are only stored here.

    Raises:
        FileNotFoundError: if ``file_path`` does not exist (logged, then re-raised).
        Exception: any other read/decode error is logged and re-raised unchanged.
    """

    def __init__(
        self, tokenizer: AutoTokenizer, file_path: str, max_length: int, data_has_colon_separator: bool = True
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data_has_colon_separator = data_has_colon_separator
        # Normalized full-text strings, one per non-blank line of the file.
        self.examples = []
        logger.info(f"Loading data from {file_path}")
        try:
            with open(file_path, "r", encoding="utf-8") as handle:
                for line_num, raw_line in enumerate(handle):
                    line = raw_line.strip()
                    if line:
                        self.examples.append(self._normalize_line(line, line_num, file_path))
            logger.info(f"Loaded {len(self.examples)} examples from {file_path}.")
        except Exception as e:
            # A missing file gets its own message; everything else is reported generically.
            if isinstance(e, FileNotFoundError):
                logger.error(f"Data file not found: {file_path}")
            else:
                logger.error(f"Error reading data file {file_path}: {e}")
            raise

    def _normalize_line(self, line, line_num, file_path):
        """Return the example string stored for one stripped, non-blank line."""
        if not self.data_has_colon_separator:
            return line
        head, sep, tail = line.partition(":")
        if not sep:
            # No separator: keep the whole line as a single sequence.
            logger.warning(
                f"Line {line_num+1} in {file_path} does not contain ':' separator. Treating as full line: '{line}'"
            )
            return line
        # Canonical "INPUT : TARGET" form.
        input_part, target_part = head.strip(), tail.strip()
        return f"{input_part} : {target_part}"

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        # NOTE(review): this collapses *every* " : " occurrence (not only the
        # INPUT/TARGET separator built in __init__) into a single space, so the
        # collator never sees " : ". Confirm this matches how the collator's
        # input_prefix_len is meant to be measured.
        return {"text": self.examples[idx].replace(" : ", " ")}


def _str2bool(value):
    """Convert a command-line string to ``bool`` for use as an argparse ``type=``.

    ``type=bool`` is an argparse pitfall: ``bool("False")`` and ``bool("0")`` are
    both ``True``, so any non-empty value would enable the flag. This accepts
    common spellings case-insensitively. The empty string maps to ``False``,
    matching ``bool("")``.

    Raises:
        argparse.ArgumentTypeError: for unrecognized values; argparse reports it
            as a usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}")


def get_parser():
    """
    Generate a parameters parser.

    Returns:
        argparse.ArgumentParser configured with every option read by ``main``.
    """
    # parse parameters
    parser = argparse.ArgumentParser(description="Language transfer with ranked permutations")  # Modified description

    # main parameters
    parser.add_argument(
        "--data_path",
        type=str,
        default="./data/data_sum",
        help="Dataset path prefix; '.train' and '.test' are appended to locate the split files",
    )
    parser.add_argument("--data_encoding", type=str, default="infix")
    parser.add_argument("--save_path", type=str, default="./dumped", help="Experiment dump path")
    parser.add_argument("--save_periodic", type=int, default=0, help="Save the model periodically (0 to disable)")
    parser.add_argument("--exp_name", type=str, default="debug", help="Experiment name")
    parser.add_argument("--exp_id", type=str, default="", help="Experiment ID")
    parser.add_argument("--task", type=str, default="sum", help="Task name")

    # float16 / AMP API
    # _str2bool instead of type=bool: bool("False") would be True.
    parser.add_argument("--fp16", type=_str2bool, default=True, help="Run model with float16")
    parser.add_argument(
        "--amp",
        type=int,
        default=2,
        help="Use AMP wrapper for float16 / distributed / gradient accumulation. Level of optimization. -1 to disable.",
    )

    # model parameters
    parser.add_argument("--model", type=str, default="gpt2", help="Model type passed to load_model (e.g. gpt2)")
    parser.add_argument("--d_model", type=int, default=512, help="Embedding layer size")
    parser.add_argument("--dim_feedforward", type=int, default=2048, help="feedforward layer size")
    parser.add_argument("--num_encoder_layers", type=int, default=6, help="Number of Transformer layers in the encoder")
    parser.add_argument("--num_decoder_layers", type=int, default=6, help="Number of Transformer layers in the decoder")
    parser.add_argument("--nhead", type=int, default=8, help="Number of Transformer heads")
    parser.add_argument("--dropout", type=float, default=0.1, help="Dropout")
    parser.add_argument("--attention_dropout", type=float, default=0.1, help="Dropout in the attention layer")
    parser.add_argument("--encoding_method", type=str, default="standard")
    parser.add_argument("--token_register_size", type=int, default=2)
    parser.add_argument("--positional_encoding", type=str, default="sinusoidal", choices=["sinusoidal", "embedding"])

    # vocab and tokenizer parameters
    parser.add_argument("--num_variables", type=int, default=0)
    parser.add_argument("--field", type=str, default="ZZ", help="QQ or GFP with some integer P (e.g., GF7).")
    parser.add_argument("--max_coefficient", type=int, default=500, help="The maximum coefficients")
    parser.add_argument("--max_degree", type=int, default=0, help="The maximum degree")
    parser.add_argument("--gaussian_encoding_upper_bound", type=int, default=10, help="For Gaussian embedding.")

    # training parameters
    parser.add_argument("--max_sequence_length", type=int, default=128, help="Maximum sequences length")
    parser.add_argument("--num_batch", type=int, default=128, help="Number of sentences per batch")
    parser.add_argument("--test_batch_size", type=int, default=256, help="Number of sentences per evaluation batch")
    parser.add_argument(
        "--optimizer", type=str, default="adamw_apex_fused", help="Optimizer (SGD / RMSprop / Adam, etc.)"
    )
    parser.add_argument("--learning_rate", type=float, default=0.0001, help="learning rate (default 0.0001)")
    parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay (default 0)")
    parser.add_argument("--clip_grad_norm", type=float, default=5, help="Clip gradients norm (0 to disable)")
    parser.add_argument("--epochs", type=int, default=5, help="number of epochs")
    parser.add_argument("--max_steps_per_epoch", type=int, default=-1, help="Maximum epoch size")
    parser.add_argument("--warmup_ratio", type=float, default=0.0)
    parser.add_argument(
        "--stopping_criterion",
        type=str,
        default="",
        help="Stopping criterion, and number of non-increase before stopping the experiment",
    )
    parser.add_argument("--validation_metrics", type=str, default="", help="Validation metrics")
    parser.add_argument(
        "--accumulate_gradients",
        type=int,
        default=1,
        help="Accumulate model gradients over N iterations (N times larger batch sizes)",
    )
    parser.add_argument("--num_workers", type=int, default=8, help="Number of CPU workers for DataLoader")
    parser.add_argument("--deepspeed", type=str, default="")
    parser.add_argument("--env_name", type=str, default="char_sp", help="Environment name")
    parser.add_argument("--resume_from_checkpoint", action="store_true", default=False)

    # CPU / multi-gpu / multi-node
    parser.add_argument("--cpu", type=_str2bool, default=False, help="Run on CPU")
    parser.add_argument("--local_rank", type=int, default=-1, help="Multi-GPU - Local rank")
    parser.add_argument("--master_port", type=int, default=-1, help="Master port (for multi-node SLURM jobs)")

    parser.add_argument("--dryrun", action="store_true", default=False)

    ## Dev
    parser.add_argument("--regression_weights", nargs="*", type=float)
    parser.add_argument("--continuous_coefficient", action="store_true", default=False)
    parser.add_argument("--continuous_exponent", action="store_true", default=False)
    parser.add_argument("--support_learning", action="store_true", default=False)
    parser.add_argument("--num_memory_tokens", type=int, default=16, help="Number of memory tokens")
    parser.add_argument("--use_memory_transformer", action="store_true", default=False)
    parser.add_argument("--use_fix_decoder_self_attn", action="store_true", default=False)
    parser.add_argument("--sparsity_lambda", type=float, default=0.0)

    # Added arguments for permutation experiment
    parser.add_argument(
        "--permutation_id",
        type=int,
        default=None,
        help="ID of the permutation to use for the target sequence. If None, original order is used.",
    )
    parser.add_argument(
        "--permutation_num",
        type=int,
        default=None,
        help="Number of permutations to generate; required when --permutation_id is set. "
        "For 'family', floor(log2(permutation_num)) is used as the family exponent.",
    )
    parser.add_argument(
        "--permutation_type",
        type=str,
        default="family",
        choices=["all", "family", "random", "random_one"],
        help="Type of permutation to generate. 'all' for all permutations, 'family' for a family of permutations, 'random' for random permutations, 'random_one' for a single random permutation.",
    )
    parser.add_argument(
        "--target_len",
        type=int,
        default=50,
        help="Length of the target sequence for generating permutations. This should match the length of the target sequences in your dataset.",
    )

    return parser


def _build_permutations(params, L_for_perm):
    """Generate the permutation list selected by ``--permutation_type``.

    Mirrors the script's CLI error handling: prints an ``Error: ...`` line and
    exits with status 1 in these cases:
    - generation raises;
    - the type is unknown;
    - ``params.permutation_id`` is not a valid non-negative index into the result.

    Args:
        params: Parsed CLI namespace. Reads ``permutation_type``,
            ``permutation_num`` and ``permutation_id``.
        L_for_perm: Length of the sequence the permutations act on.

    Returns:
        Whatever the ``utils.permutation_utils`` generator returns. It must
        support ``len()`` and integer indexing.
    """
    import math

    try:
        if params.permutation_type == "all":
            all_permutations = generate_all_permutation_matrices(L_for_perm)
        elif params.permutation_type == "family":
            n_exp = int(math.log2(params.permutation_num))
            if 2**n_exp != params.permutation_num:
                # log2 is floored, so a non-power-of-two count is silently rounded down.
                print(
                    f"Warning: permutation_num={params.permutation_num} is not a power of two; "
                    f"make_perm_family will be called with n_exponent={n_exp}."
                )
            all_permutations = make_perm_family(L=L_for_perm, n_exponent=n_exp)
        elif params.permutation_type == "random":
            all_permutations = generate_random_permutation(L_for_perm, num_samples=params.permutation_num)
        elif params.permutation_type == "random_one":
            all_permutations = generate_random_permutation(L_for_perm, num_samples=params.permutation_num)
            # Slot 0 is replaced by a fixed member of the structured family.
            family_perms = make_perm_family(L=L_for_perm, n_exponent=3)
            all_permutations[0] = family_perms[2]
        else:
            # SystemExit is not an Exception subclass, so the handler below does not catch it.
            print(f"Error: Unknown permutation type {params.permutation_type}.")
            sys.exit(1)
    except Exception as e:
        print(f"Error generating permutations: {e}")
        sys.exit(1)

    n_perms = len(all_permutations)
    # Reject negative ids too: Python indexing would otherwise silently wrap around.
    if not 0 <= params.permutation_id < n_perms:
        print(
            f"Error: permutation_id {params.permutation_id} is out of range for {n_perms} permutations generated with L={L_for_perm}, num={params.permutation_num}."
        )
        sys.exit(1)
    return all_permutations


def main():
    """Train and evaluate a model on (optionally permuted) target sequences.

    Steps:
    1. Parse CLI arguments (``get_parser``).
    2. Load ``<data_path>.train`` / ``<data_path>.test``.
    3. Build the model and the permutation data collator.
    4. Train with the project's ``CustomTrainer`` (reporting to wandb), then evaluate.
    5. Write generated outputs and metrics under ``--save_path``.
    6. On the main process, print ``Final Test Accuracy: <value>`` for wrapper shell scripts.
    """
    import wandb

    from trainer.trainer import CustomTrainer as Trainer
    from trainer.trainer import CustomTrainingArguments as TrainingArguments

    parser = get_parser()
    params = parser.parse_args()

    os.makedirs(params.save_path, exist_ok=True)

    ## Permutations
    # Length of the sequence the permutations act on (see --target_len). It is
    # also forwarded to the collator as `input_prefix_len` below.
    L_for_perm = params.target_len
    # Stays None when --permutation_id is not given (original target order).
    # It is bound unconditionally so building the collator cannot hit an
    # unbound name in that case.
    # NOTE(review): confirm PermutationExperimentDataCollator accepts
    # permutations_list=None together with fixed_permutation_index=None.
    all_permutations = None
    if params.permutation_id is not None:
        if params.permutation_num is None:
            print("Error: permutation_num must be provided if permutation_id is set.")
            sys.exit(1)
        all_permutations = _build_permutations(params, L_for_perm)

    ## Load tokenizer, data and model
    from data.tokenizers import set_tokenizer, set_vocab

    vocab = set_vocab(
        params.num_variables,
        field=params.field,
        max_coeff=params.max_coefficient,
        max_degree=params.max_degree,
        continuous_coefficient=False,
        continuous_exponent=False,
    )
    tokenizer = set_tokenizer(vocab)
    trainset = TextContinuationDataset(
        tokenizer=tokenizer, file_path=f"{params.data_path}.train", max_length=params.max_sequence_length
    )
    testset = TextContinuationDataset(
        tokenizer=tokenizer, file_path=f"{params.data_path}.test", max_length=params.max_sequence_length
    )
    model = load_model(params, vocab=vocab, tokenizer=tokenizer)
    dc = PermutationExperimentDataCollator(
        tokenizer=tokenizer,
        permutations_list=all_permutations,
        input_prefix_len=L_for_perm,
        apply_permutation_to_target_only=True,
        fixed_permutation_index=params.permutation_id,
        per_sample_permutation=False,
    )

    ## Save parameters
    with open(os.path.join(params.save_path, "params.yaml"), "w") as f:
        yaml.dump(vars(params), f)

    now = datetime.datetime.now(ZoneInfo("Asia/Tokyo"))
    datetime_str = now.strftime("%Y%m%d_%H%M%S")
    # Include the permutation id in the run name when one is used.
    run_name_suffix = f"_perm{params.permutation_id}" if params.permutation_id is not None else ""
    run_name = f"{params.exp_id}_{datetime_str}{run_name_suffix}"
    # Tag used in per-run output file names (generated CSV, metrics JSON).
    perm_tag = params.permutation_id if params.permutation_id is not None else "orig"

    ## Set up trainer
    # Global batch sizes are divided evenly across visible GPUs.
    n_gpus = count_cuda_devices()
    train_batch_size = params.num_batch // n_gpus if n_gpus > 0 else params.num_batch
    eval_batch_size = params.test_batch_size // n_gpus if n_gpus > 0 else params.test_batch_size
    trainer_config = TrainingArguments(
        output_dir=params.save_path,
        num_train_epochs=params.epochs,
        max_steps_per_epoch=params.max_steps_per_epoch,
        logging_steps=50,
        save_total_limit=1,
        dataloader_pin_memory=False,
        bf16=True,  # NOTE(review): hard-coded; --fp16 is not consulted here. Requires bf16-capable hardware.
        eval_steps=100,
        label_names=["labels"],
        remove_unused_columns=False,
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=eval_batch_size,
        eval_strategy="steps",
        report_to="wandb",
        disable_tqdm=True,
        run_name=run_name,
    )

    limit_steps_callback = LimitStepsCallback(max_steps_per_epoch=params.max_steps_per_epoch)

    # compute_metrics is deliberately not wired in. If it is, it would need the tokenizer, e.g.:
    # _compute_metrics = lambda x: compute_metrics(x, ignore_index=tokenizer.pad_token_id if tokenizer else -100)

    trainer = Trainer(
        args=trainer_config,
        model=model,
        train_dataset=trainset,
        eval_dataset=testset,
        data_collator=dc,
        # compute_metrics=_compute_metrics, # Uncomment if needed
        # preprocess_logits_for_metrics=preprocess_logits_for_metrics, # Uncomment if needed
        callbacks=[limit_steps_callback],
    )

    # Rank of this process on its node. torchrun exports LOCAL_RANK instead of
    # passing --local_rank, so fall back to the environment variable. Plain
    # single-process runs resolve to -1.
    local_rank = params.local_rank if params.local_rank >= 0 else int(os.environ.get("LOCAL_RANK", -1))
    is_main_process = local_rank <= 0

    ## Run training
    if is_main_process:  # wandb.init only on the main process under DDP
        wandb.init(project=params.exp_name, name=run_name, config=vars(params))  # Log all params

    s = time()
    train_result = trainer.train(resume_from_checkpoint=params.resume_from_checkpoint)
    print(f"training time: [{time()-s:.1f} sec]")

    ## Evaluate
    # Model-specific generation-based evaluation; falls back to trainer.evaluate() metrics below.
    acc = "N/A"  # Default accuracy
    df_gen = None

    if params.model == "gpt2":
        if tokenizer is None:
            print("Error: Tokenizer is required for GPT-2 evaluation but it's None.")
        else:
            acc, df_gen = trainer.evaluate_gpt2(tokenizer)
    elif hasattr(trainer, "evaluate_test_greedy"):  # Fallback or other model types
        if tokenizer is None:
            print(
                "Warning: Tokenizer is None, evaluate_test_greedy might fail or produce incorrect results if it relies on it."
            )
        acc, df_gen = trainer.evaluate_test_greedy(tokenizer=tokenizer)
    else:
        print(
            "Warning: No specific evaluation method found for model type or evaluate_test_greedy not available. Using trainer.evaluate()."
        )

    if df_gen is not None:
        df_gen.to_csv(os.path.join(params.save_path, f"generated_perm{perm_tag}.csv"), index=False)

    metrics = train_result.metrics if train_result else {}
    # Always run evaluate (on every rank) to get the latest metrics on the test set.
    dataset_metrics = trainer.evaluate(metric_key_prefix="test")
    metrics.update(dataset_metrics)

    # Accuracy from the model-specific evaluation wins. Otherwise use a common evaluate() key.
    if acc == "N/A" and "test_accuracy" in metrics:
        acc = metrics["test_accuracy"]
    elif acc == "N/A" and "eval_accuracy" in metrics:
        acc = metrics["eval_accuracy"]

    metrics["test/final_accuracy_reported"] = acc  # Store the accuracy value we intend to report

    if is_main_process:  # Only the main process saves metrics and finishes wandb
        trainer.save_metrics("all", metrics)
        metrics_save_path = os.path.join(params.save_path, f"all_metrics_perm{perm_tag}.json")
        with open(metrics_save_path, "w") as f:
            json.dump(metrics, f, indent=4, default=str)  # default=str for non-JSON values (tensors, numpy scalars)
        print(f"Saved all metrics to {metrics_save_path}")

        wandb.finish()

        # Printed once so shell scripts can capture it.
        # torch tensors and numpy scalars both expose .item(); plain values do not.
        final_accuracy_to_print = acc.item() if hasattr(acc, "item") else acc
        print(f"Final Test Accuracy: {final_accuracy_to_print}")


if __name__ == "__main__":
    main()
