import gc
import os
import sys
from argparse import ArgumentParser
from pathlib import Path

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

sys.path.append(str(Path(__file__).parent.parent))
from settings.subsets import SUBSETS
from settings.text_transformations import TEXT_TRANSFORMATIONS
from utils.utils import find_empty, print_trainable_parameters, raw_values_batch


def data_processing(batch, tokenizer, model, suffix, calculate_loss, transformation):
    """Generate model continuations for one batch of prefixes.

    Intended as a batched ``datasets.Dataset.map`` callback.

    Args:
        batch: Column dict holding a ``"suffix"`` column plus either ``"prefix"``
            (when ``transformation`` is None) or ``"prefix_<transformation>"``.
        tokenizer: Tokenizer used both to encode the prefixes (padded into one
            tensor batch) and to decode the generated tokens; must expose
            ``pad_token_id``.
        model: Causal LM exposing ``generate``.
        suffix: Number of new tokens to generate (passed as ``max_new_tokens``).
        calculate_loss: When true, also score ``prefix + suffix`` texts with
            ``raw_values_batch`` and store the result under ``"loss"``.
        transformation: Name of the prefix variant column to read, or None for
            the plain ``"prefix"`` column.

    Returns:
        dict with ``"prefix"`` (the prefixes actually fed to the model),
        ``"suffix"``, ``"predic"`` (decoded continuations) and, optionally,
        ``"loss"``.
    """
    target = "cuda" if torch.cuda.is_available() else "cpu"

    # The data already stores prefix and suffix as separate fields.
    column = "prefix" if transformation is None else "prefix_" + transformation
    prefix_texts = batch[column]
    suffix_texts = batch["suffix"]

    encoded = tokenizer(prefix_texts, return_tensors="pt", padding=True).to(target)

    with torch.no_grad():
        result = model.generate(
            **encoded,
            max_new_tokens=suffix,
            pad_token_id=tokenizer.pad_token_id,
            return_dict_in_generate=True,
            use_cache=True,
        )

        # generate() returns prompt + continuation; drop the (padded) prompt part.
        prompt_len = encoded["input_ids"].size(1)
        continuation_ids = [row[prompt_len:] for row in result.sequences]
        predictions = tokenizer.batch_decode(continuation_ids, skip_special_tokens=True)

        output = {
            "prefix": prefix_texts,
            "suffix": suffix_texts,
            "predic": predictions,
        }
        if calculate_loss:
            joined = [p + s for p, s in zip(prefix_texts, suffix_texts)]
            output["loss"] = raw_values_batch(model, tokenizer, joined)

    return output


def main(args):
    print("\nLoading model and tokenizer...", flush=True)
    if torch.cuda.is_available():
        free_mem = torch.cuda.mem_get_info()[0] / (1024**2)  # in MB
        total_mem = torch.cuda.mem_get_info()[1] / (1024**2)  # in MB
        print(f"CUDA memory: {free_mem:.2f} MB free / {total_mem:.2f} MB total")

    device = "cuda" if torch.cuda.is_available() else "cpu"

    if getattr(args, "use_quantization", False):
        quant_config = BitsAndBytesConfig(
            load_in_8bit=True,
        )
        model = AutoModelForCausalLM.from_pretrained(
            args.model_name,
            device_map=(
                {"": torch.cuda.current_device()} if device.startswith("cuda") else None
            ),
            trust_remote_code=True,
            quantization_config=quant_config,
        ).eval()
    else:
        model = AutoModelForCausalLM.from_pretrained(
            args.model_name,
            device_map=(
                {"": torch.cuda.current_device()} if device.startswith("cuda") else None
            ),
            trust_remote_code=True,
        ).eval()

    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name, use_fast=True, trust_remote_code=True, padding_side="left"
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    print_trainable_parameters(model)

    safe_model_name = args.model_name.replace("/", "_")
    base_path = os.path.join(args.data_dir, safe_model_name, f"{args.subset_name}")

    # Select which files to process based on --max-val-train
    if getattr(args, "max_val_train", False):
        split_files = {
            "train": os.path.join(base_path, "train_full.jsonl"),
            "val": os.path.join(base_path, "train_val.jsonl"),
        }
    else:
        split_files = {
            "train": os.path.join(base_path, "train.jsonl"),
            "test": os.path.join(base_path, "test.jsonl"),
            "val": os.path.join(base_path, "val.jsonl"),
        }

    # Load only existing files
    splits = {k: v for k, v in split_files.items() if os.path.exists(v)}
    print(f"Found split files: {split_files}")

    if not splits:
        print(f"No split files found for {args.subset_name}")
        return

    print("\nFetching data...", flush=True)
    datasets = {}
    for split, path in splits.items():
        datasets[split] = load_dataset("json", data_files={split: path}, split=split)
        print(f"{split.capitalize()} set size: {len(datasets[split])}")

    print("\nProcessing data...")
    if getattr(args, "use_transformations", False):
        if args.text_transformation is None:
            text_transformation_str = "original"
        else:
            text_transformation_str = args.text_transformation
        save_path_pref = os.path.join(
            args.data_dir,
            "processed_predic",
            safe_model_name,
            f"{args.subset_name}",
            text_transformation_str,
        )
    else:
        save_path_pref = os.path.join(
            args.data_dir, "processed_predic", safe_model_name, f"{args.subset_name}"
        )

    sample_size = min([len(ds) for ds in datasets.values()] + [2000])
    print("Sample size: ", sample_size, flush=True)

    if getattr(args, "calculate_loss", False):
        calculate_loss = True
    else:
        calculate_loss = False

    for split, ds in datasets.items():
        ds = ds.select(range(sample_size)).map(
            lambda batch: data_processing(
                batch,
                tokenizer,
                model,
                suffix=args.n_suff,
                calculate_loss=calculate_loss,
                transformation=args.text_transformation,
            ),
            batched=True,
            batch_size=args.batch_size,
            remove_columns=ds.column_names,
        )
        find_empty(ds, message=f"{args.subset_name} {split} set")
        save_path = os.path.join(
            save_path_pref, f"{split}_{args.n_suff}suff_predic.jsonl"
        )
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        ds.to_json(save_path)
        print(f"{split.capitalize()} split saved to: {save_path}")


if __name__ == "__main__":
    parser = ArgumentParser(
        description="Generate model continuations (and optionally losses) for dataset splits."
    )

    parser.add_argument(
        "--data_dir",
        type=str,
        required=True,
        help="Directory containing the data files",
    )
    parser.add_argument(
        "--model_name", type=str, default="EleutherAI/pythia-2.8b-deduped"
    )
    parser.add_argument("--n_suff", type=int, default=64)  # 64, 32
    parser.add_argument(
        "--batch_size", type=int, default=32, help="Batch size for processing"
    )
    parser.add_argument(
        "--max-val-train",
        action="store_true",
        help="If set, process files generated by split_dataset.py with --max-val-train: train_full.jsonl (as the train split) and train_val.jsonl (as the val split). Otherwise, process train.jsonl, val.jsonl, test.jsonl.",
    )
    parser.add_argument(
        "--calculate-loss", action="store_true", help="Should loss be calculated?"
    )
    parser.add_argument(
        "--use-transformations",
        action="store_true",
        help="Should transformed prefixes be used?",
    )
    parser.add_argument(
        "--transformation-type",
        type=str,
        choices=TEXT_TRANSFORMATIONS,
        default=None,
        help="Type of text transformation to apply if --use-transformations is set.",
    )
    parser.add_argument(
        "--use-quantization",
        action="store_true",
        help="Load the model in 8-bit precision via bitsandbytes.",
    )
    parser.add_argument(
        "--subset-name",
        type=str,
        default="all",
        help="Specific subset to process. If 'all', processes all subsets from SUBSETS list.",
    )

    args = parser.parse_args()
    print(args, flush=True)

    # Transformations to run when --use-transformations is set; None stands for the
    # untransformed prefixes. A new list is built so the imported settings list
    # (also used as argparse choices) is not mutated.
    if args.transformation_type is not None:
        transformations = [args.transformation_type]
    else:
        transformations = [*TEXT_TRANSFORMATIONS, None]

    subsets = SUBSETS if args.subset_name == "all" else [args.subset_name]
    use_transformations = getattr(args, "use_transformations", False)

    for i, subset in enumerate(subsets):
        args.subset_name = subset
        for transformation in transformations if use_transformations else [None]:
            args.text_transformation = transformation
            print("-" * 64)
            if use_transformations:
                print(
                    f"Processing subset: {subset} with text transformation {transformation}",
                    flush=True,
                )
            else:
                print(f"Processing subset: {subset}", flush=True)
            main(args)

            # main() loads a fresh model on every call; release it before the next run.
            gc.collect()
            torch.cuda.empty_cache()

        print("\n", flush=True)
        print(f"{subset:<20} DONE | ({(i+1)/len(subsets):.1%})", flush=True)
