import argparse
import os
import torch
import math
import json
import wandb
from transformers import GPT2LMHeadModel, GPT2Config, AutoTokenizer, HfArgumentParser, TrainingArguments
from torch.utils.data import Dataset
import logging
from dataclasses import dataclass, field
from typing import Optional

# Custom imports
from utils.permutation_utils import get_permutations, generate_all_permutation_matrices, generate_random_permutation
from loader.data_collator import PermutationExperimentDataCollator
from loader.data import _load_data
from trainer.permutation_loss_logging_trainer import (
    PermutationLossLoggingTrainer,
    PermutationLogArgs,
    PermutationLossLoggingTrainingArguments,
)

# Script-wide logging: the root logger is configured at INFO level, and this
# module logs through its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass
class ScriptArguments:
    """
    Experiment-specific command-line arguments, parsed by ``HfArgumentParser`` in ``main``
    together with the trainer arguments.

    Groups: dataset location, the permutation setup applied to the target part of each
    sequence, GPT-2 model size, and Weights & Biases settings.
    """

    # Required arguments (no default). Dataclass fields without defaults must precede
    # fields with defaults, which is why ``permutation_type`` (defaulted) follows them.
    # NOTE(review): dataset_name is not read by main() except through vars(script_args)
    # in the wandb config.
    dataset_name: str = field(metadata={"help": "Name of the dataset (e.g., relu, square_mod19, index)"})
    dataset_path_prefix: str = field(
        metadata={"help": "Path prefix to the .train and .test data files (e.g., data/relu_n50/relu_n50)"}
    )
    target_len: int = field(
        metadata={
            "help": "Length of the target sequence part to be permuted (after tokenization). Used for generating permutations."
        }
    )
    permutation_select_num: int = field(
        metadata={
            "help": (
                "Number of permutations to generate and use: 0, 1, or a power of 2 up to 4096. "
                "Ignored when permutation_type is 'all'."
            ),
            "choices": [0, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096],
        }
    )
    # `choices` is forwarded to argparse by HfArgumentParser, so an unknown type is
    # rejected at parse time instead of after the datasets are loaded.
    permutation_type: str = field(
        default="family",
        metadata={
            "help": (
                "Type of permutations to generate. Options: 'all' (all permutations), "
                "'random' (random permutations), 'random_one' (random permutations whose first entry "
                "is replaced by the first permutation of the family), 'family' (permutation family)."
            ),
            "choices": ["all", "random", "random_one", "family"],
        },
    )

    # Optional arguments (with defaults).
    input_prefix_len: int = field(
        default=20,
        metadata={
            "help": "Length of the prefix part of the input sequence (in tokens). This is used to determine where the target part starts."
        },
    )
    gpt2_n_head: int = field(default=1, metadata={"help": "Number of attention heads for GPT2 model."})
    gpt2_n_layer: int = field(default=1, metadata={"help": "Number of layers for GPT2 model."})
    gpt2_n_embd: int = field(default=512, metadata={"help": "Embedding dimension for GPT2 model."})
    max_seq_length: int = field(default=128, metadata={"help": "Maximum sequence length for tokenizer and model."})
    # NOTE(review): main() builds its tokenizer via data.tokenizers.set_tokenizer and does
    # not read this field.
    tokenizer_name: str = field(default="gpt2", metadata={"help": "Tokenizer name or path."})

    # Weights & Biases arguments. An empty/None project disables wandb in main().
    wandb_project: Optional[str] = field(
        default="permutation_loss_analysis", metadata={"help": "Weights & Biases project name."}
    )
    wandb_entity: Optional[str] = field(default=None, metadata={"help": "Weights & Biases entity name."})
    wandb_run_name: Optional[str] = field(
        default=None, metadata={"help": "Weights & Biases run name. Defaults to a generated name."}
    )

    # NOTE(review): m_param and k_max_layers are not read by main() except through
    # vars(script_args) in the wandb config; their help texts are unverified here.
    m_param: Optional[int] = field(
        default=6,
        metadata={
            "help": "Number of parameters for the model. Used for logging and analysis purposes."
        }
    )
    k_max_layers: Optional[int] = field(
        default=2,
        metadata={
            "help": "Maximum number of layers for the model. Used for logging and analysis purposes."
        }
    )


# Simplified Dataset for text files (one sequence per line)
class TextContinuationDataset(Dataset):
    """Line-oriented text dataset that returns raw, untokenized strings.

    Every non-blank line of ``file_path`` becomes one example. With
    ``data_has_colon_separator`` set, a line ``INPUT:TARGET`` is split on its first
    ``:`` and stored as ``"INPUT : TARGET"``; lines without a ``:`` are stored as-is
    (with a warning). ``__getitem__`` collapses that ``" : "`` separator to a single
    space, so consumers see ``"INPUT TARGET"``. Tokenization is done by the data
    collator; ``tokenizer`` and ``max_length`` are only stored on the instance.
    """

    def __init__(
        self, tokenizer: AutoTokenizer, file_path: str, max_length: int, data_has_colon_separator: bool = True
    ):
        """Read and normalize all non-blank lines of ``file_path``.

        Raises:
            FileNotFoundError: if ``file_path`` does not exist (logged, then re-raised).
            Exception: any other error while reading is logged and re-raised.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data_has_colon_separator = data_has_colon_separator
        self.examples = []  # Full text strings; tokenization happens later in the collator.
        logger.info(f"Loading data from {file_path}")
        try:
            with open(file_path, "r", encoding="utf-8") as f:
                for line_num, line in enumerate(f, start=1):
                    line = line.strip()
                    if not line:
                        continue

                    if not self.data_has_colon_separator:
                        self.examples.append(line)
                        continue

                    # Split on the first ':' only; the target part may itself contain ':'.
                    input_part, sep, target_part = line.partition(":")
                    if sep:
                        # Canonical "INPUT : TARGET" form. Because __getitem__ collapses the
                        # separator, the text handed to the collator is "INPUT TARGET".
                        self.examples.append(f"{input_part.strip()} : {target_part.strip()}")
                    else:
                        logger.warning(
                            f"Line {line_num} in {file_path} does not contain ':' separator. Treating as full line: '{line}'"
                        )
                        # No separator: keep the whole line as a single sequence.
                        self.examples.append(line)
            logger.info(f"Loaded {len(self.examples)} examples from {file_path}.")
        except FileNotFoundError:
            logger.error(f"Data file not found: {file_path}")
            raise
        except Exception as e:
            logger.error(f"Error reading data file {file_path}: {e}")
            raise

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        """Return ``{"text": ...}`` for example ``idx`` with the INPUT/TARGET separator collapsed.

        Only the first ``" : "`` is replaced. For colon-separated lines that occurrence
        is always the separator inserted in ``__init__`` (the input part contains no
        ``:``), so any ``" : "`` inside the target is preserved.
        """
        return {"text": self.examples[idx].replace(" : ", " ", 1)}


def _setup_wandb(script_args, training_args):
    """Start a wandb run when ``wandb_project`` is set and align ``training_args.report_to``.

    With a project, a run is initialised using the script arguments as its config and
    ``"wandb"`` is added to ``report_to``. Without one, ``"wandb"`` is removed from
    ``report_to``. ``training_args`` is mutated in place.
    """
    report_to = training_args.report_to
    # report_to may be None, a single string, or a list. Normalize to a list first so the
    # membership test never runs against None and a str is never iterated char-by-char.
    if report_to is None or report_to == "none":
        report_to = []
    elif isinstance(report_to, str):
        report_to = [report_to]
    else:
        report_to = list(report_to)

    if not script_args.wandb_project:
        logger.info("wandb_project not set, wandb logging disabled.")
        training_args.report_to = [r for r in report_to if r != "wandb"]
        return

    wandb.init(
        project=script_args.wandb_project,
        entity=script_args.wandb_entity,
        name=script_args.wandb_run_name,
        config=vars(script_args),
    )
    if "wandb" not in report_to:
        report_to.append("wandb")
    training_args.report_to = report_to
    logger.info(f"Initialized wandb for project '{script_args.wandb_project}'")


def _build_tokenizer(output_dir):
    """Build the project tokenizer, ensure UNK/PAD tokens exist, and save it under ``output_dir``.

    The vocabulary parameters are fixed here; ``ScriptArguments.tokenizer_name`` is not used.
    Returns the tokenizer.
    """
    # Project-local import, kept function-scoped as in the original script.
    from data.tokenizers import set_tokenizer, set_vocab

    vocab = set_vocab(
        0, field="ZZ", max_coeff=500, max_degree=1, continuous_coefficient=False, continuous_exponent=False
    )
    tokenizer = set_tokenizer(vocab)
    if not hasattr(tokenizer, "unk_token") or tokenizer.unk_token is None:
        tokenizer.add_special_tokens({"unk_token": "[UNK]"})
    if tokenizer.pad_token is None:
        # Reuse EOS for padding when available; otherwise add a dedicated PAD token.
        if tokenizer.eos_token is not None:
            tokenizer.pad_token = tokenizer.eos_token
        else:
            tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    os.makedirs(output_dir, exist_ok=True)
    tokenizer_save_path = os.path.join(output_dir, "tokenizer_config_for_exp")
    tokenizer.save_pretrained(tokenizer_save_path)
    logger.info(f"Tokenizer config saved to {tokenizer_save_path}")
    return tokenizer


def _generate_permutations(script_args):
    """Return the permutation set selected by ``script_args.permutation_type``.

    The result is whatever the permutation helpers return; callers read ``.shape[0]``
    from it and pass ``list(...)`` of it to the collator.

    Raises:
        ValueError: for an unknown ``permutation_type``.
    """
    perm_type = script_args.permutation_type
    target_len = script_args.target_len
    select_num = script_args.permutation_select_num

    if perm_type == "all":
        return generate_all_permutation_matrices(target_len)
    if perm_type == "random":
        return generate_random_permutation(N=target_len, num_samples=select_num)
    if perm_type == "random_one":
        # Random permutations whose first entry is replaced by the first member of the
        # 2-element permutation family.
        permutations = generate_random_permutation(N=target_len, num_samples=select_num)
        one_perm = get_permutations(target_len=target_len, permutation_select_num=2)
        permutations[0] = one_perm[0]
        return permutations
    if perm_type == "family":
        return get_permutations(target_len=target_len, permutation_select_num=select_num)
    raise ValueError(f"Unsupported permutation type: {script_args.permutation_type}")


def _evaluate_all_permutations(trainer, eval_dataset, output_dir):
    """Evaluate every fixed permutation on ``eval_dataset`` and save the losses as JSON.

    Best effort: any exception is logged (with traceback) and swallowed so the script
    can still finish. Losses are also logged to wandb when a run is active.
    """
    logger.info("Starting evaluation of all permutations on the evaluation dataset...")
    try:
        all_perms_losses = trainer.evaluate_all_permutations(
            eval_dataset=eval_dataset, metric_key_prefix="final_eval_all_perms"
        )
        if not all_perms_losses:
            logger.warning("evaluate_all_permutations returned no losses.")
            return

        output_eval_file = os.path.join(output_dir, "all_permutations_eval_losses.json")
        with open(output_eval_file, "w") as writer:
            json.dump(all_perms_losses, writer, indent=4)
        logger.info(f"Evaluation losses for all permutations saved to {output_eval_file}")

        if wandb.run is not None:
            wandb.log({"all_permutations_final_losses": all_perms_losses})
    except Exception as e:
        logger.error(f"An error occurred during the evaluation of all permutations: {e}", exc_info=True)


def main():
    """Parse arguments, train a GPT-2 model on permuted targets, and evaluate it.

    Steps: configure wandb, build the tokenizer, load ``<prefix>.train`` / ``<prefix>.test``,
    generate permutations, train (``--do_train``), then run standard evaluation and
    per-permutation evaluation (``--do_eval``). Artifacts go to ``training_args.output_dir``.
    """
    parser = HfArgumentParser((PermutationLossLoggingTrainingArguments, PermutationLogArgs, ScriptArguments))
    # PermutationLogArgs is parsed so its CLI flags are accepted; this script does not use it.
    training_args, _perm_log_args, script_args = parser.parse_args_into_dataclasses()

    _setup_wandb(script_args, training_args)

    # 1. Initialize Tokenizer
    tokenizer = _build_tokenizer(training_args.output_dir)
    logger.info(f"Using input_prefix_len (token count): {script_args.input_prefix_len}")

    # 2. Load Datasets (TextContinuationDataset raises if a file is missing, so both exist here)
    train_file_path = f"{script_args.dataset_path_prefix}.train"
    eval_file_path = f"{script_args.dataset_path_prefix}.test"
    train_dataset = TextContinuationDataset(
        tokenizer, train_file_path, max_length=script_args.max_seq_length, data_has_colon_separator=True
    )
    eval_dataset = TextContinuationDataset(
        tokenizer, eval_file_path, max_length=script_args.max_seq_length, data_has_colon_separator=True
    )
    if len(train_dataset) == 0:
        logger.error("Training dataset is empty. Exiting.")
        # Close the wandb run explicitly and mark it failed rather than leaving it open.
        if wandb.run is not None:
            wandb.finish(exit_code=1)
        return
    if len(eval_dataset) == 0:
        logger.warning(
            "Evaluation dataset is empty. Final permutation evaluation might not run or produce meaningful results."
        )

    # 3. Generate Permutations
    permutations = _generate_permutations(script_args)
    logger.info(f"Generated {permutations.shape[0]} permutations for target length {script_args.target_len}.")

    # 4. Initialize Data Collator for Training (with per-sample permutation)
    train_data_collator = PermutationExperimentDataCollator(
        tokenizer=tokenizer,
        permutations_list=list(permutations),
        input_prefix_len=script_args.input_prefix_len,
        apply_permutation_to_target_only=True,
        per_sample_permutation=True,  # Apply different permutation to each sample in batch during training
    )
    # NOTE(review): no fixed-permutation evaluation collators are built in this script;
    # presumably trainer.evaluate_all_permutations creates them — confirm.

    # 5. Initialize Model (GPT-2), trained from a fresh config
    config = GPT2Config(
        vocab_size=len(tokenizer),
        n_positions=script_args.max_seq_length,
        n_ctx=script_args.max_seq_length,
        n_embd=script_args.gpt2_n_embd,
        n_layer=script_args.gpt2_n_layer,
        n_head=script_args.gpt2_n_head,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )
    model = GPT2LMHeadModel(config)
    if model.config.vocab_size != len(tokenizer):
        model.resize_token_embeddings(len(tokenizer))

    # 6. Initialize Trainer
    # pin_memory is force-disabled here, overriding any --dataloader_pin_memory value.
    training_args.dataloader_pin_memory = False
    trainer = PermutationLossLoggingTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=train_data_collator,
    )

    # 7. Start Training
    if training_args.do_train:
        logger.info("Starting training...")
        try:
            train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
            trainer.save_model()  # Save the final model
            trainer.log_metrics("train", train_result.metrics)
            trainer.save_metrics("train", train_result.metrics)
            trainer.save_state()
        except Exception as e:
            logger.error(f"An error occurred during training: {e}", exc_info=True)
            raise

    # 8. Standard evaluation (best effort: failures are logged so step 9 can still run)
    if training_args.do_eval:
        logger.info("Starting standard evaluation...")
        try:
            # NOTE(review): the collator used here depends on PermutationLossLoggingTrainer,
            # which was constructed with train_data_collator — confirm.
            eval_metrics = trainer.evaluate()
            trainer.log_metrics("eval", eval_metrics)
            trainer.save_metrics("eval", eval_metrics)
        except Exception as e:
            logger.error(f"An error occurred during standard evaluation: {e}", exc_info=True)

    # 9. Evaluate loss for all permutations and save to JSON
    if not training_args.do_eval:
        logger.info("Skipping evaluation of all permutations because --do_eval is not set.")
    elif len(eval_dataset) == 0:
        logger.info("Skipping evaluation of all permutations because eval_dataset is empty or not available.")
    else:
        _evaluate_all_permutations(trainer, eval_dataset, training_args.output_dir)

    logger.info("Experiment finished.")
    if script_args.wandb_project:
        logger.info(f"Permutation loss data logged to wandb project: {script_args.wandb_project}")
    logger.info(f"Final model and metrics saved in: {training_args.output_dir}")
    if wandb.run is not None:
        wandb.finish()


# Script entry point: run the experiment only when executed directly, not on import.
if __name__ == "__main__":
    main()
