import argparse
import os
import random
import numpy as np
import torch
import json
import wandb
from transformers import GPT2LMHeadModel, GPT2Config, AutoTokenizer, HfArgumentParser
import dataclasses
import datetime

from loader.data_collator import PermutationExperimentDataCollator

from trainer.permutation_loss_logging_trainer import (
    PermutationLossLoggingTrainer,
    PermutationLossLoggingTrainingArguments,
)
from main_permutation_loss_analysis import ScriptArguments as BaseScriptArguments, TextContinuationDataset
from dataclasses import dataclass, field
from typing import Optional, List, Dict, Any, Tuple

import logging
import math
import itertools

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Process-wide cache of shared experiment state. It starts empty and is filled
# by setup_evaluation_environment() with the keys "script_args", "device",
# "tokenizer", "model_config_params", "target_len", "train_dataset",
# "eval_dataset", "full_training_args_template" and
# "minimal_eval_args_template". train_model_for_generation() and
# evaluate_individual_on_trained_model() read from it, so the setup function
# must have run before either of them is called.
EVAL_SETUP = {}


def setup_evaluation_environment(
    script_args: BaseScriptArguments, training_args: PermutationLossLoggingTrainingArguments
):
    """Build the shared tokenizer, datasets and argument templates exactly once.

    The results are stored in the module-level ``EVAL_SETUP`` dict under the
    keys ``script_args``, ``device``, ``tokenizer``, ``model_config_params``,
    ``target_len``, ``train_dataset``, ``eval_dataset``,
    ``full_training_args_template`` and ``minimal_eval_args_template``.

    All state is assembled in a local dict and copied into ``EVAL_SETUP`` only
    after every step has succeeded. If any step raises (for example a missing
    dataset file), ``EVAL_SETUP`` is left untouched, so a later call retries
    the full setup instead of returning early on a half-initialised cache.

    Side effect: creates ``<training_args.output_dir>/internal_evals_temp``.

    Args:
        script_args: Experiment options. The attributes read here are
            ``max_seq_length``, ``gpt2_n_embd``, ``gpt2_n_layer``,
            ``gpt2_n_head``, ``target_len`` and ``dataset_path_prefix``.
        training_args: Full training arguments. A copy is kept as the template
            for training runs; its device, eval batch size, dataloader worker
            count, ``fp16`` flag and output directory also seed the minimal
            evaluation-only arguments.
    """
    if "model_config_params" in EVAL_SETUP:
        # A previous call completed; reuse the cached state.
        return
    logger.info("Setting up evaluation environment...")

    # Collected locally and published in one step at the end of the function.
    setup = {}
    setup["script_args"] = script_args
    setup["device"] = training_args.device

    # Project-local import, kept function-scoped as in the original code.
    from data.tokenizers import set_tokenizer, set_vocab
    vocab = set_vocab(0, field="ZZ", max_coeff=500, max_degree=1, continuous_coefficient=False, continuous_exponent=False)
    tokenizer = set_tokenizer(vocab)
    if not hasattr(tokenizer, "unk_token") or tokenizer.unk_token is None:
        tokenizer.add_special_tokens({"unk_token": "[UNK]"})
    if tokenizer.pad_token is None:
        # Reuse EOS as the padding token when available, otherwise register "[PAD]".
        tokenizer.pad_token = tokenizer.eos_token if tokenizer.eos_token is not None else "[PAD]"
        if tokenizer.pad_token == "[PAD]":
            tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    setup["tokenizer"] = tokenizer

    # Keyword arguments for GPT2Config; computed after the special tokens were
    # added so that vocab_size reflects the final tokenizer length.
    setup["model_config_params"] = {
        "vocab_size": len(tokenizer),
        "n_positions": script_args.max_seq_length,
        "n_ctx": script_args.max_seq_length,
        "n_embd": script_args.gpt2_n_embd,
        "n_layer": script_args.gpt2_n_layer,
        "n_head": script_args.gpt2_n_head,
        "bos_token_id": tokenizer.bos_token_id,
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.pad_token_id,
    }
    setup["target_len"] = script_args.target_len

    # Train and eval splits live at "<prefix>.train" and "<prefix>.test".
    train_file_path = f"{script_args.dataset_path_prefix}.train"
    setup["train_dataset"] = TextContinuationDataset(
        tokenizer, train_file_path, max_length=script_args.max_seq_length, data_has_colon_separator=True
    )

    eval_file_path = f"{script_args.dataset_path_prefix}.test"
    setup["eval_dataset"] = TextContinuationDataset(
        tokenizer, eval_file_path, max_length=script_args.max_seq_length, data_has_colon_separator=True
    )

    # Copy of the caller's arguments, used as the template for training runs.
    setup["full_training_args_template"] = dataclasses.replace(training_args)

    # Stripped-down arguments for evaluation-only trainers (no reporting).
    minimal_eval_output_dir = os.path.join(training_args.output_dir, "internal_evals_temp")
    os.makedirs(minimal_eval_output_dir, exist_ok=True)
    setup["minimal_eval_args_template"] = PermutationLossLoggingTrainingArguments(
        output_dir=minimal_eval_output_dir,
        per_device_eval_batch_size=training_args.per_device_eval_batch_size,
        dataloader_num_workers=training_args.dataloader_num_workers,
        report_to=[],
        fp16=training_args.fp16,
        remove_unused_columns=False,
        dataloader_pin_memory=False,
    )

    # Publish only now, so a failure above never leaves a partial cache behind.
    EVAL_SETUP.update(setup)
    logger.info("Evaluation environment setup complete.")

def train_model_for_generation(
    generation_permutations_as_lists: List[List[int]], generation_num_str: str
) -> Tuple[GPT2LMHeadModel, int]:
    """Train a freshly initialised GPT-2 model on a set of permutations.

    Requires ``setup_evaluation_environment`` to have populated ``EVAL_SETUP``.
    Duplicate permutations are removed (and the remainder sorted) before they
    are handed to the data collator.

    Args:
        generation_permutations_as_lists: Permutations as lists of indices;
            duplicates are allowed.
        generation_num_str: Label used in log messages and in the name of the
            per-run output directory.

    Returns:
        ``(model, n_unique)``: the trained, unwrapped model moved to the
        trainer's device, and the number of unique permutations used. If the
        input contains no permutations the untrained model and ``0`` are
        returned.

    Raises:
        Exception: Any error raised while training or unwrapping the model is
            logged and re-raised unchanged.
    """
    logger.info(f"Starting model training for step/gen '{generation_num_str}' using {len(generation_permutations_as_lists)} permutations.")
    script_args: BaseScriptArguments = EVAL_SETUP["script_args"]
    config_kwargs = EVAL_SETUP["model_config_params"]
    tokenizer = EVAL_SETUP["tokenizer"]
    train_dataset = EVAL_SETUP["train_dataset"]
    args_template: PermutationLossLoggingTrainingArguments = EVAL_SETUP["full_training_args_template"]
    target_len = EVAL_SETUP["target_len"]
    # NOTE(review): this writes to the shared template stored in EVAL_SETUP,
    # not to a per-run copy.
    args_template.dataloader_pin_memory = False

    # Every call starts from a newly constructed (untrained) model.
    model = GPT2LMHeadModel(GPT2Config(**config_kwargs))
    tokenizer_size = len(tokenizer)
    if model.config.vocab_size != tokenizer_size:
        model.resize_token_embeddings(tokenizer_size)

    # A random suffix keeps output directories of repeated runs apart.
    run_dir = os.path.join(
        args_template.output_dir,
        f"train_step_{generation_num_str}_{random.randint(1000,9999)}",
    )
    os.makedirs(run_dir, exist_ok=True)

    # Per-run copy of the arguments with run-specific paths and no reporting.
    run_args = dataclasses.replace(args_template)
    run_args.output_dir = run_dir
    run_args.logging_dir = os.path.join(run_dir, "logs")
    run_args.report_to = []
    run_args.remove_unused_columns = False

    # De-duplicate through hashable tuples, sort, then convert back to lists.
    distinct_tuples = sorted({tuple(perm) for perm in generation_permutations_as_lists})
    distinct_perms = [list(perm) for perm in distinct_tuples]

    logger.info(f"Step/Gen '{generation_num_str}': Original {len(generation_permutations_as_lists)} perms, using {len(distinct_perms)} unique for training.")

    if len(distinct_perms) == 0:
        logger.warning(f"Step/Gen '{generation_num_str}': No unique permutations to train on. Returning un-trained model.")
        return model, 0

    perm_matrices = perms_to_tensor_list(distinct_perms, target_len)
    collator = PermutationExperimentDataCollator(
        tokenizer=tokenizer,
        permutations_list=perm_matrices,
        input_prefix_len=script_args.input_prefix_len,
        apply_permutation_to_target_only=True,
        per_sample_permutation=len(perm_matrices) > 0,
    )

    trainer = PermutationLossLoggingTrainer(
        model=model,
        args=run_args,
        train_dataset=train_dataset,
        tokenizer=tokenizer,
        data_collator=collator,
    )

    try:
        trainer.train()
        trained_model = trainer.accelerator.unwrap_model(trainer.model)
        trained_model.to(trainer.args.device)
    except Exception as train_error:
        logger.error(f"Error during model training for step/gen '{generation_num_str}': {train_error}", exc_info=True)
        raise
    return trained_model, len(distinct_perms)

def evaluate_permutations(
    model: GPT2LMHeadModel,
    permutations_to_eval: List[List[int]],
    eval_id_prefix: str,
) -> List[Dict[str, Any]]:
    """Score every permutation with ``model`` and rank the results by loss.

    Args:
        model: Model passed to ``evaluate_individual_on_trained_model``.
        permutations_to_eval: Permutations to score, each a list of indices.
        eval_id_prefix: Prefix of the per-permutation id; the position of the
            permutation in the input list is appended as ``_<index>``.

    Returns:
        One ``{"permutation": ..., "loss": ...}`` dict per input permutation,
        sorted by ascending loss.
    """
    scored = [
        {
            "permutation": perm,
            # The evaluator returns a one-element tuple; keep only the loss.
            "loss": evaluate_individual_on_trained_model(
                trained_model=model,
                individual_permutation_as_list=perm,
                individual_id_str=f"{eval_id_prefix}_{idx}",
            )[0],
        }
        for idx, perm in enumerate(permutations_to_eval)
    ]
    return sorted(scored, key=lambda entry: entry["loss"])

def evaluate_individual_on_trained_model(
    trained_model: GPT2LMHeadModel,
    individual_permutation_as_list: List[int],
    individual_id_str: str,
) -> Tuple[float,]:
    """Evaluate ``trained_model`` on the eval dataset under one permutation.

    Requires ``setup_evaluation_environment`` to have populated ``EVAL_SETUP``.
    A dedicated evaluation trainer is built with a collator that receives only
    this permutation, and a per-call output directory is created below the
    minimal evaluation template's output directory.

    Args:
        trained_model: Model to evaluate.
        individual_permutation_as_list: The permutation, as a list of indices.
        individual_id_str: Identifier used in the output directory name, the
            metric key prefix and log messages.

    Returns:
        A one-element tuple ``(loss,)``. The loss is ``float("inf")`` when
        evaluation raises or when the metrics contain no loss entry.
    """
    script_args: BaseScriptArguments = EVAL_SETUP["script_args"]
    tokenizer = EVAL_SETUP["tokenizer"]
    eval_dataset = EVAL_SETUP["eval_dataset"]
    eval_args_template: PermutationLossLoggingTrainingArguments = EVAL_SETUP["minimal_eval_args_template"]
    target_len = EVAL_SETUP["target_len"]

    individual_perm_tensor = perms_to_tensor_list([individual_permutation_as_list], target_len)

    # The collator gets a single-entry permutation list with a fixed index of 0
    # and per-sample permutation switched off.
    eval_data_collator_individual = PermutationExperimentDataCollator(
        tokenizer=tokenizer,
        permutations_list=individual_perm_tensor,
        input_prefix_len=script_args.input_prefix_len,
        apply_permutation_to_target_only=True,
        per_sample_permutation=False,
        fixed_permutation_index=0,
    )

    # Per-call copy of the template with its own output directory; the random
    # suffix keeps directories of repeated evaluations apart.
    current_eval_args = dataclasses.replace(eval_args_template)
    eval_output_dir = os.path.join(current_eval_args.output_dir, f"eval_{individual_id_str}_{random.randint(1000,9999)}")
    os.makedirs(eval_output_dir, exist_ok=True)
    current_eval_args.output_dir = eval_output_dir

    eval_trainer = PermutationLossLoggingTrainer(
        model=trained_model,
        args=current_eval_args,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=eval_data_collator_individual,
    )

    metric_prefix = f"eval_{individual_id_str}"
    try:
        metrics = eval_trainer.evaluate(metric_key_prefix=metric_prefix)
        # Explicit None checks: a legitimate loss of 0.0 must not be treated
        # as "missing", which a plain `or` chain would do.
        loss = metrics.get("eval_loss")
        if loss is None:
            # evaluate() was given a custom prefix, so the loss is expected
            # under "<prefix>_loss"; prefer that exact key over a suffix scan.
            loss = metrics.get(f"{metric_prefix}_loss")
        if loss is None:
            # Last resort: first metric whose name ends in "_loss".
            loss = next((v for k, v in metrics.items() if k.endswith("_loss")), float("inf"))
        return (loss,)
    except Exception as e_eval:
        # Best effort: a failed evaluation is reported as the worst possible loss.
        logger.error(f"Error during evaluation of individual {individual_id_str}: {e_eval}", exc_info=True)
        return (float("inf"),)

def perms_to_tensor_list(perms_list_of_lists: List[List[int]], target_len: int) -> List[torch.Tensor]:
    """Turn index-list permutations into square one-hot float32 matrices.

    For each permutation ``p`` a ``target_len x target_len`` zero matrix is
    created and entry ``[i, p[i]]`` is set to ``1.0`` for every position ``i``
    of ``p``. Nothing is validated: rows at positions beyond ``len(p)`` stay
    all-zero, and the indices are handed to tensor indexing unchanged.

    Args:
        perms_list_of_lists: Permutations, each a list of column indices.
        target_len: Side length of every produced matrix.

    Returns:
        One matrix per input permutation, in input order.
    """

    def _as_matrix(perm: List[int]) -> torch.Tensor:
        # Row = position within the permutation, column = the value stored there.
        one_hot = torch.zeros((target_len, target_len), dtype=torch.float32)
        for row, col in enumerate(perm):
            one_hot[row, col] = 1.0
        return one_hot

    return [_as_matrix(perm) for perm in perms_list_of_lists]
