import argparse
import os
import torch
import json
from tqdm import tqdm
from transformers import GPT2LMHeadModel, TrainingArguments, Trainer, GPT2Config
from functools import partial
from loader.data import _load_data
from torch.utils.data import DataLoader

from utils.eval import collate_fn as default_original_collate_fn
from loader.checkpoint import load_tokenizer


# Device used for the manual evaluation pass: the GPU when CUDA is usable, otherwise the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def permutation_collate_fn(batch, tokenizer, permutation_matrix, input_len, original_collate_fn):
    """Collate a batch and reorder every sample's target tokens with one fixed permutation.

    The batch is first collated by ``original_collate_fn(batch, tokenizer)``, whose
    result must provide 2-D ``"input_ids"`` and ``"attention_mask"`` tensors. Each
    row is then treated as ``[prefix | target | last token]``: the first
    ``input_len`` tokens and the final token stay where they are, and the tokens in
    between are reordered so that output position ``i`` holds the target token at
    index ``argmax(permutation_matrix[i])``.

    Args:
        batch: Raw samples, forwarded unchanged to ``original_collate_fn``.
        tokenizer: Tokenizer, forwarded unchanged to ``original_collate_fn``.
        permutation_matrix: 2-D tensor whose row-wise argmax defines the new order.
            If its shape differs from ``target_len x target_len`` it is cropped to
            its top-left ``target_len x target_len`` corner.
        input_len: Number of leading tokens that are kept in place and excluded
            from the loss (their labels are set to -100).
        original_collate_fn: Callable ``(batch, tokenizer) -> dict`` that produces
            the un-permuted tensors.

    Returns:
        A dict with ``"input_ids"``, ``"attention_mask"`` and ``"labels"``.
        When the permutation cannot be applied (no target tokens, or the cropped
        matrix has fewer rows than there are target tokens) the un-permuted
        ``input_ids`` are returned and ``labels`` is a plain copy of them, i.e. the
        prefix is *not* masked in that fallback.
    """
    processed_batch = original_collate_fn(batch, tokenizer)

    input_ids_full = processed_batch["input_ids"]
    attention_mask_full = processed_batch["attention_mask"]

    def _unpermuted():
        # Shared fallback: hand the collated batch through untouched.
        return {
            "input_ids": input_ids_full,
            "attention_mask": attention_mask_full,
            "labels": input_ids_full.clone(),
        }

    # Unpacking (rather than indexing) keeps the implicit "must be 2-D" check.
    _, full_len = input_ids_full.shape

    # Nothing lies between the prefix and the final token: no targets to permute.
    if input_len >= full_len - 1:
        return _unpermuted()

    prefix = input_ids_full[:, :input_len]
    target_tokens = input_ids_full[:, input_len:-1]  # [B, target_len]
    eos_tokens = input_ids_full[:, -1:]  # [B, 1]

    target_len = target_tokens.shape[1]
    if target_len <= 0:
        return _unpermuted()

    current_permutation_matrix = permutation_matrix
    if permutation_matrix.shape[0] != target_len or permutation_matrix.shape[1] != target_len:
        # NOTE(review): the cropped corner of a larger permutation matrix is not
        # necessarily itself a permutation (all-zero rows argmax to index 0).
        current_permutation_matrix = permutation_matrix[:target_len, :target_len]

    if current_permutation_matrix.device != target_tokens.device:
        # On a device mismatch, pull CUDA tensors back to the CPU before indexing.
        if current_permutation_matrix.is_cuda:
            current_permutation_matrix = current_permutation_matrix.cpu()
        if target_tokens.is_cuda:
            target_tokens = target_tokens.cpu()
            prefix = prefix.cpu()
            eos_tokens = eos_tokens.cpu()
        current_permutation_matrix = current_permutation_matrix.to(target_tokens.device)

    order_idx = current_permutation_matrix.argmax(dim=-1)  # [target_len]

    # A matrix with fewer rows than target tokens cannot describe a full reordering.
    if order_idx.shape[0] != target_len:
        return _unpermuted()

    if not torch.all(order_idx < target_len):
        return _unpermuted()

    # One gather applies the same column order to every row of the batch.
    permuted_target_tokens = target_tokens[:, order_idx]

    permuted_input_ids = torch.cat([prefix, permuted_target_tokens, eos_tokens], dim=1)
    labels = permuted_input_ids.clone()
    labels[:, :input_len] = -100  # the prefix is conditioning only, never a loss target
    return {"input_ids": permuted_input_ids, "attention_mask": attention_mask_full, "labels": labels}


def _build_arg_parser():
    """Build the command-line parser for the permutation training script."""
    parser = argparse.ArgumentParser(
        description="Train a model with different target permutations based on sparsity ranking using Hugging Face Trainer."
    )
    parser.add_argument(
        "--permutation_results_path",
        type=str,
        required=True,
        help="Path to the .pt file containing permutation results.",
    )
    parser.add_argument("--base_model_path", type=str, required=True, help="Path to the base GPT-2 model directory.")
    parser.add_argument("--dataset_path", type=str, required=True, help="Path to the training dataset (e.g., .jsonl).")
    parser.add_argument("--model_save_dir", type=str, required=True, help="Directory to save trained models.")
    parser.add_argument(
        "--input_len",
        type=int,
        required=True,
        help="Length of the input prefix (not permuted). Must match the one used for generating permutations.",
    )

    # Optimisation / Trainer settings.
    parser.add_argument("--num_epochs", type=int, default=3, help="Number of training epochs for each permutation.")
    parser.add_argument("--batch_size", type=int, default=4, help="Training batch size per device.")
    parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate.")
    parser.add_argument(
        "--warmup_steps", type=int, default=0, help="Number of warmup steps for learning rate scheduler."
    )
    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay.")
    parser.add_argument("--logging_steps", type=int, default=50, help="Log every X updates steps.")
    parser.add_argument(
        "--save_steps",
        type=int,
        default=500,
        help="Save checkpoint every X updates steps (used for save_total_limit). Not saving per step here, but after epochs.",
    )
    parser.add_argument(
        "--save_total_limit",
        type=int,
        default=1,
        help="Limit the total amount of checkpoints. Deletes the older checkpoints.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument("--fp16", action="store_true", help="Whether to use 16-bit (mixed) precision training.")

    # Generation parameters (for model.generate)
    parser.add_argument(
        "--generation_max_length",
        type=int,
        default=None,
        help="Maximum length for generation. Overrides model config if set.",
    )
    return parser


def _run_generation_eval(model, tokenizer, eval_dataset, collate_fn, args, rank):
    """Evaluate ``model`` on ``eval_dataset`` with teacher-forced loss and ``model.generate``.

    Every batch is built with ``collate_fn`` (the same collator used for training),
    so the reference targets are the ones the collator produced. For each batch the
    function records the forward-pass loss, generates a continuation of the first
    ``args.input_len`` tokens, and compares it with the reference both as decoded
    text (exact match after ``strip()``) and token by token.

    Args:
        model: Trained model; switched to eval mode here.
        tokenizer: Tokenizer used for decoding and for pad/EOS ids.
        eval_dataset: Dataset iterated with a non-shuffling ``DataLoader``.
        collate_fn: Batch collator returning ``input_ids``/``attention_mask``/``labels``.
        args: Parsed CLI namespace (``batch_size``, ``input_len``, ``generation_max_length``).
        rank: Identifier of the permutation being evaluated (used in messages/output).

    Returns:
        ``(summary, records)`` where ``summary`` is the JSON-serialisable metrics
        dict and ``records`` is a list of ``(prompt, reference, generated)`` strings
        in dataset order.
    """
    print(f"Starting evaluation for rank {rank} using model.generate()...")
    model.eval()

    eval_dataloader = DataLoader(
        eval_dataset,
        batch_size=args.batch_size,  # evaluation reuses the training batch size
        collate_fn=collate_fn,
        shuffle=False,
    )

    # GPT-2 has a fixed-size position-embedding table, so generating past
    # n_positions would index outside it. Cap the requested length accordingly.
    requested_max_length = args.generation_max_length if args.generation_max_length else 50
    max_positions = getattr(model.config, "n_positions", None)
    max_length = requested_max_length
    if max_positions and requested_max_length > max_positions:
        print(
            f"Rank {rank} - Capping generation max_length from {requested_max_length} "
            f"to the model's n_positions ({max_positions})."
        )
        max_length = max_positions

    generation_config_kwargs = {
        "max_length": max_length,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }

    prompts_text = []
    decoded_references = []
    decoded_predictions = []
    eval_losses = []
    correct_predictions = 0  # exact-match hits
    total_predictions = 0
    total_correct_tokens = 0
    total_valid_tokens = 0

    with torch.no_grad():
        for batch in tqdm(eval_dataloader, desc=f"Evaluating Rank {rank}"):
            # Move every tensor to the device once and reuse it below.
            batch_on_device = {k: v.to(DEVICE) for k, v in batch.items() if hasattr(v, "to")}

            # 1. Teacher-forced loss on the collated batch.
            outputs = model(
                input_ids=batch_on_device["input_ids"],
                attention_mask=batch_on_device["attention_mask"],
                labels=batch_on_device["labels"],
            )
            if outputs.loss is not None:
                eval_losses.append(outputs.loss.item())

            # 2. Free-running generation from the prompt prefix.
            input_ids_prompt = batch_on_device["input_ids"][:, : args.input_len]
            attention_mask_prompt = batch_on_device["attention_mask"][:, : args.input_len]
            prompts_text.extend(tokenizer.batch_decode(input_ids_prompt, skip_special_tokens=True))

            # -100 is not a real token id; swap it for the pad id before decoding.
            reference_label_ids = batch["labels"].clone()
            reference_label_ids[reference_label_ids == -100] = tokenizer.pad_token_id
            decoded_labels_batch = tokenizer.batch_decode(
                reference_label_ids[:, args.input_len :], skip_special_tokens=True
            )
            decoded_references.extend(decoded_labels_batch)

            generated_ids_batch = model.generate(
                input_ids_prompt, attention_mask=attention_mask_prompt, **generation_config_kwargs
            )
            # generate() returns prompt + continuation; only the continuation is scored.
            generated_target_tokens = generated_ids_batch[:, args.input_len :]
            decoded_preds_batch = tokenizer.batch_decode(generated_target_tokens, skip_special_tokens=True)
            decoded_predictions.extend(decoded_preds_batch)

            total_predictions += len(decoded_preds_batch)
            correct_predictions += sum(
                pred.strip() == label.strip() for pred, label in zip(decoded_preds_batch, decoded_labels_batch)
            )

            # 3. Token-level accuracy over the overlapping part of both sequences.
            # NOTE(review): positions beyond the shorter sequence are not counted at
            # all, so an early-stopping generation is not penalised for missing tokens.
            reference_target_tokens = batch_on_device["labels"][:, args.input_len :]
            compare_len = min(generated_target_tokens.shape[1], reference_target_tokens.shape[1])
            gen_compare = generated_target_tokens[:, :compare_len]
            ref_compare = reference_target_tokens[:, :compare_len]

            # Ignore reference positions that are padding or loss-masked (-100).
            valid_token_mask = (ref_compare != tokenizer.pad_token_id) & (ref_compare != -100)
            total_correct_tokens += ((gen_compare == ref_compare) & valid_token_mask).sum().item()
            total_valid_tokens += valid_token_mask.sum().item()

    avg_eval_loss = sum(eval_losses) / len(eval_losses) if eval_losses else 0
    exact_match_accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0
    token_accuracy = total_correct_tokens / total_valid_tokens if total_valid_tokens > 0 else 0

    print(f"Rank {rank} - Average Eval Loss: {avg_eval_loss:.4f}")
    print(
        f"Rank {rank} - Exact Match Accuracy: {exact_match_accuracy:.4f} ({correct_predictions}/{total_predictions})"
    )
    print(f"Rank {rank} - Token Accuracy: {token_accuracy:.4f} ({total_correct_tokens}/{total_valid_tokens})")

    summary = {
        "rank": rank,
        "exact_match_accuracy": exact_match_accuracy,
        "correct_predictions": correct_predictions,
        "total_predictions": total_predictions,
        "token_accuracy": token_accuracy,
        "total_correct_tokens": total_correct_tokens,
        "total_valid_tokens": total_valid_tokens,
        "average_eval_loss": avg_eval_loss,
        "generation_config": generation_config_kwargs,
    }
    records = list(zip(prompts_text, decoded_references, decoded_predictions))
    return summary, records


def _save_eval_outputs(save_dir_rank, rank, summary, records):
    """Write per-sample generations (JSONL) and the metrics summary (JSON) for one rank."""
    if records:
        results_save_path = os.path.join(save_dir_rank, f"eval_results_rank_{rank}_generate.jsonl")
        with open(results_save_path, "w", encoding="utf-8") as f_out:
            for i, (prompt_text, reference_target, generated_text) in enumerate(records):
                f_out.write(
                    json.dumps(
                        {
                            "rank": rank,
                            "id_in_dataset": i,  # running index over the whole eval set
                            "prompt": prompt_text,
                            "reference_target": reference_target,
                            "generated_text": generated_text,
                        }
                    )
                    + "\n"
                )
        print(f"Evaluation results for rank {rank} saved to {results_save_path}")

    summary_save_path = os.path.join(save_dir_rank, f"eval_summary_rank_{rank}.json")
    with open(summary_save_path, "w", encoding="utf-8") as f_summary:
        json.dump(summary, f_summary, indent=4)
    print(f"Evaluation summary for rank {rank} saved to {summary_save_path}")


def main():
    """Train and evaluate one freshly initialised GPT-2 model per ranked permutation.

    Steps:
      1. Parse CLI arguments and load the permutation results file, which must
         contain a ``"sparsity_info"`` list of dicts with ``"rank"`` and
         ``"permutation_matrix"`` entries.
      2. Load ``<dataset_path>.train`` and ``<dataset_path>.test``.
      3. For every permutation (ascending rank): build a new model, train it with
         the Hugging Face ``Trainer`` using ``permutation_collate_fn``, evaluate it
         with ``model.generate``, and save model, tokenizer and evaluation files
         under ``<model_save_dir>/model_rank_<rank>``.

    Failures while loading inputs abort the run with a message; failures while
    building, training or evaluating a single rank are reported and the loop
    continues with the next rank.
    """
    args = _build_arg_parser().parse_args()

    print(f"Using device: {DEVICE}")
    os.makedirs(args.model_save_dir, exist_ok=True)

    try:
        # NOTE(security): torch.load unpickles the file; only use trusted result files.
        perm_results = torch.load(args.permutation_results_path, map_location="cpu")
        sparsity_info_list = perm_results["sparsity_info"]
    except Exception as e:
        print(f"Error loading permutation results from {args.permutation_results_path}: {e}")
        return

    sparsity_info_list.sort(key=lambda x: x["rank"])

    print(f"Loading training dataset from {args.dataset_path}...")
    try:
        train_dataset = _load_data(f"{args.dataset_path}.train")
        eval_dataset = _load_data(f"{args.dataset_path}.test")
        if not train_dataset:
            print("Dataset is empty. Exiting.")
            return
    except Exception as e:
        print(f"Error loading dataset: {e}")
        return

    for perm_info in tqdm(sparsity_info_list, desc="Overall Permutation Ranks Progress"):
        rank = perm_info["rank"]
        permutation_matrix = perm_info["permutation_matrix"]

        print(f"\n--- Starting training for rank {rank} ---")
        print(f"Using permutation matrix of shape: {permutation_matrix.shape}")

        try:
            tokenizer = load_tokenizer(args.base_model_path)
            # The model is built from scratch; only the tokenizer comes from
            # base_model_path.
            # NOTE(review): n_positions presumably covers prefix + equally long
            # target + EOS - confirm against the dataset's sequence length.
            config = GPT2Config(
                vocab_size=tokenizer.vocab_size,
                n_positions=args.input_len * 2 + 1,
                n_head=1,
                n_layer=6,
                n_embd=512,
                n_inner=2048,
                bos_token_id=tokenizer.bos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
            model = GPT2LMHeadModel(config)

            print(f"Initialized a new GPT2 model using config from {args.base_model_path} for rank {rank}.")

        except Exception as e:
            print(f"Error initializing model or loading tokenizer for rank {rank}: {e}")
            continue

        # Tokenizers without a pad token fall back to EOS for padding.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
            model.config.pad_token_id = tokenizer.eos_token_id

        model.to(DEVICE)

        current_collate_fn = partial(
            permutation_collate_fn,
            tokenizer=tokenizer,
            permutation_matrix=permutation_matrix,
            input_len=args.input_len,
            original_collate_fn=default_original_collate_fn,
        )

        save_dir_rank = os.path.join(args.model_save_dir, f"model_rank_{rank}")
        os.makedirs(save_dir_rank, exist_ok=True)

        # --save_steps is accepted for CLI compatibility but has no effect because
        # checkpoints are written once per epoch.
        training_args = TrainingArguments(
            output_dir=save_dir_rank,
            num_train_epochs=args.num_epochs,
            per_device_train_batch_size=args.batch_size,
            learning_rate=args.learning_rate,
            warmup_steps=args.warmup_steps,
            weight_decay=args.weight_decay,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            logging_dir=os.path.join(save_dir_rank, "logs"),
            logging_steps=args.logging_steps,
            save_strategy="epoch",
            save_total_limit=args.save_total_limit,
            fp16=args.fp16,
            report_to="none",
            remove_unused_columns=False,  # the collator needs the raw dataset fields
            label_names=["labels"],
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            data_collator=current_collate_fn,
        )

        print(f"Starting training with Trainer for rank {rank}...")
        try:
            trainer.train()
            trainer.save_model()  # final model, written to output_dir
            print(f"Finished training with Trainer for rank {rank}.")

            if eval_dataset:
                summary, records = _run_generation_eval(
                    model, tokenizer, eval_dataset, current_collate_fn, args, rank
                )
                _save_eval_outputs(save_dir_rank, rank, summary, records)

            # Store the tokenizer next to the model so the directory is self-contained.
            if tokenizer:
                tokenizer.save_pretrained(save_dir_rank)
            print(f"Model and tokenizer for rank {rank} saved to {save_dir_rank}")

        except Exception as e:
            # One failing rank must not abort the remaining ranks.
            print(f"Error during training for rank {rank}: {e}")
            import traceback

            traceback.print_exc()

        print(f"--- Finished processing rank {rank} ---")

    print("\nAll permutation-based trainings finished.")


# Run the training pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()