import argparse
import torch
from dataloader import JsonlDataset, JsonDataset
from feature_config import get_feature_config
import os
import time
from tqdm import tqdm
from util_funcs import SaeFeature


def parse(argv=None):
    """Define and parse the command-line options for feature extraction.

    Args:
        argv: Optional sequence of argument strings. ``None`` (the default)
            lets ``argparse`` read ``sys.argv[1:]``, so existing ``parse()``
            calls behave exactly as before; passing a list allows
            programmatic use and testing.

    Returns:
        argparse.Namespace holding every option defined below.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="gpt2",
        help="The model name or path",
    )

    parser.add_argument(
        "--sae_model_name_or_paths",
        type=str,
        nargs="+",
        default=[],
        help="One or more SAE model names or paths",
    )

    parser.add_argument(
        "--topk_for_token_agg",
        type=int,
        default=-1,
        help="Top-K for token aggregation",
    )

    parser.add_argument(
        "--input_files",
        type=str,
        nargs="+",
        help="Input file(s) to process, one after another",
    )

    parser.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Batch size for the DataLoader",
    )

    parser.add_argument(
        "--heldout_data",
        action="store_true",
        help=(
            "Aggregate each input file into a single held-out feature and save "
            "them all to --output_path (a file) instead of writing blocks to a "
            "directory"
        ),
    )

    parser.add_argument(
        "--method",
        type=str,
        help="Method to compute semantic representations",
    )

    parser.add_argument(
        "--max_length",
        type=int,
        default=1024,
    )

    parser.add_argument(
        "--avg_level",
        type=str,
        default="sample",
    )

    parser.add_argument(
        "--block_size",
        type=float,
        default=1.0,
        help="Size (GB) of accumulated computed features before a block is written to disk",
    )

    parser.add_argument(
        "--log_interval",
        type=int,
        default=1,
        help="Print progress every N batches",
    )

    parser.add_argument(
        "--data_format",
        type=str,
        default="jsonl",
        help="Input format: 'jsonl' uses JsonlDataset, any other value uses JsonDataset",
    )

    parser.add_argument(
        "--chat_template",
        type=str,
        default=None,
        help="Apply chat template to the input data",
    )

    parser.add_argument(
        "--output_path",
        type=str,
        default=None,
        help="Output directory for blocks, or output file path with --heldout_data",
    )
    parser.add_argument(
        "--dialog_format",
        type=str,
        default="messages,system,user,assistant,role,content",
        help="Key to extract dialog from the input data",
    )

    parser.add_argument(
        "--start_line",
        type=int,
        default=0,
        help="Start line passed to JsonDataset; also part of block file names",
    )

    parser.add_argument(
        "--end_line",
        type=int,
        default=-1,
        help="End line passed to JsonDataset; also part of block file names",
    )

    return parser.parse_args(argv)


def tensor_memory_size(tensor):
    """Return how much memory ``tensor``'s elements occupy, in GB.

    The byte count is ``numel() * element_size()`` and is converted to
    gigabytes by dividing by 1024**3.
    """
    bytes_per_gb = 1024**3
    return tensor.numel() * tensor.element_size() / bytes_per_gb


def batch_memory_size(batch):
    """Return the memory footprint (GB) of one batch of computed features.

    Args:
        batch: Either a ``torch.Tensor`` or a list whose elements are
            ``SaeFeature`` objects (both ``acts`` and ``indices`` are counted)
            or ``torch.Tensor``s. List elements of any other type are ignored.

    Returns:
        Size in GB (1024**3 bytes); 0 for an empty list.

    Raises:
        TypeError: if ``batch`` is neither a list nor a tensor. Previously this
            surfaced as an ``UnboundLocalError``.
    """
    if isinstance(batch, torch.Tensor):
        return tensor_memory_size(batch)
    if not isinstance(batch, list):
        raise TypeError(
            f"Unsupported batch type for size computation: {type(batch).__name__}"
        )

    size = 0
    for example in batch:
        if isinstance(example, SaeFeature):
            size += tensor_memory_size(example.acts)
            size += tensor_memory_size(example.indices)
        elif isinstance(example, torch.Tensor):
            # Tensors delivered inside a list must also count toward the
            # --block_size threshold, otherwise such blocks never get flushed.
            size += tensor_memory_size(example)
    return size


# def write_similarity(f, similarity):
#     if isinstance(similarity, torch.Tensor):
#         similarity = similarity.tolist()
#     doc_num = len(similarity)
#     for sim in similarity:
#         f.write(f"{' '.join(map(str, sim))}\n")

#     return doc_num


def save_block(block, block_index, args):
    """Write one block of computed features to ``args.output_path``.

    The file is named ``block_lines_{start_line}-{end_line}_{block_index}.pt``.
    The payload format is chosen from the type of ``block[0]``; the block is
    assumed to be homogeneous (NOTE(review): not checked beyond the first item):

    * ``SaeFeature`` -> ``{"acts": [...], "indices": [...]}`` lists.
    * ``torch.Tensor`` -> all tensors concatenated along dim 0.

    Args:
        block: Non-empty list of ``SaeFeature`` objects or tensors.
        block_index: Sequence number embedded in the file name.
        args: Parsed CLI namespace; reads ``output_path``, ``start_line``,
            ``end_line``.

    Raises:
        ValueError: if ``block`` is empty.
        TypeError: if the element type is not supported. Previously this
            surfaced as an ``UnboundLocalError``.
    """
    if not block:
        raise ValueError("save_block called with an empty block")

    save_file = os.path.join(
        args.output_path,
        f"block_lines_{args.start_line}-{args.end_line}_{block_index}.pt",
    )

    first = block[0]
    if isinstance(first, SaeFeature):
        save_data = {
            "acts": [sae_feature.acts for sae_feature in block],
            "indices": [sae_feature.indices for sae_feature in block],
        }
    elif isinstance(first, torch.Tensor):
        save_data = torch.cat(block, dim=0)
    else:
        raise TypeError(f"Unsupported block element type: {type(first).__name__}")

    torch.save(save_data, save_file)
    print(f"Saving block data to {save_file}")


def main():
    """Compute features for every ``--input_files`` entry and save them.

    Two modes, selected by ``--heldout_data``:

    * Held-out mode (flag set): each file's batch features are folded together
      with ``ds_config.heldout_data_feature_agg_func``. The per-file results,
      keyed by file basename, are combined with
      ``ds_config.task_heldout_data_feature_agg_func``. The combined result is
      written with ``torch.save`` to the single file ``--output_path``.
    * Block mode (default): per-batch features are accumulated and written via
      ``save_block`` each time their size reaches ``--block_size`` GB. A final
      partial block is written per input file. ``--output_path`` is a
      directory here.

    Raises:
        FileNotFoundError: if an input file does not exist.
    """
    args = parse()

    if not args.heldout_data:
        os.makedirs(args.output_path, exist_ok=True)
    else:
        # output_path is a file in this mode. A bare file name has dirname ""
        # and os.makedirs("") raises, so only create a parent that is named.
        output_dir = os.path.dirname(args.output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

    ds_config = get_feature_config(**vars(args))

    heldout_data_feature = {}
    # Kept across input files: save_block's file name does not include the
    # input file, so restarting at 0 per file would overwrite earlier blocks.
    block_index = 0

    for input_file in args.input_files:
        if not os.path.exists(input_file):
            raise FileNotFoundError(f"{input_file} not found")
        print(f"Processing file: \033[94m{input_file}\033[0m")

        # build dataloader
        dataloader = torch.utils.data.DataLoader(
            (
                JsonlDataset([input_file])
                if args.data_format == "jsonl"
                else JsonDataset(
                    [input_file], start_line=args.start_line, end_line=args.end_line
                )
            ),
            batch_size=args.batch_size,
            collate_fn=ds_config.collate_func,
            shuffle=False,
            pin_memory=True,
            num_workers=1 if args.data_format == "jsonl" else 8,
        )

        if args.heldout_data:
            cur_task_heldout_data_feature = None
            total_samples = 0
            for batch in tqdm(dataloader):
                batch_heldout_data_feature = ds_config.compute_feature(batch)
                cur_task_heldout_data_feature, total_samples = (
                    ds_config.heldout_data_feature_agg_func(
                        cur_task_heldout_data_feature,
                        batch_heldout_data_feature,
                        total_samples,
                    )
                )

            heldout_data_feature[os.path.basename(input_file)] = (
                cur_task_heldout_data_feature
            )
        else:
            step = 0
            doc_num = 0
            time_flag = time.time()
            total_doc_num = len(dataloader.dataset)
            block = []
            block_size = 0
            for batch in dataloader:
                batch_feature = ds_config.compute_feature(batch)
                if isinstance(batch_feature, list):
                    block += batch_feature
                elif isinstance(batch_feature, torch.Tensor):
                    block.append(batch_feature)
                block_size += batch_memory_size(batch_feature)

                if block_size >= args.block_size:
                    save_block(block, block_index, args)
                    block = []
                    block_size = 0
                    block_index += 1

                step += 1
                if isinstance(batch_feature, list):
                    doc_num += len(batch_feature)
                elif isinstance(batch_feature, torch.Tensor):
                    doc_num += batch_feature.size(0)

                if step % args.log_interval == 0:
                    elapsed = time.time() - time_flag
                    # Guard both divisions: elapsed can be 0 with a coarse
                    # clock, and speed is 0 until a document has been counted.
                    speed = doc_num / elapsed if elapsed > 0 else 0.0
                    left_time = (
                        int((total_doc_num - doc_num) / speed) if speed > 0 else 0
                    )

                    # transform left_time to hours, minutes, seconds
                    hours, remainder = divmod(left_time, 3600)
                    minutes, seconds = divmod(remainder, 60)

                    print(
                        f"Processed \033[94m{doc_num}\033[0m documents, block size: {block_size:.2f}, speed: \033[92m{speed:.2f}\033[0m docs/s, left time: \033[93m{hours}h {minutes}m {seconds}s\033[0m",
                        flush=True,
                    )

            if block:
                save_block(block, block_index, args)
                # Advance so the next input file does not reuse this index.
                block_index += 1

    if args.heldout_data:
        heldout_data_feature = ds_config.task_heldout_data_feature_agg_func(
            heldout_data_feature
        )
        # Convert SaeFeature objects to plain dicts before saving.
        for task in heldout_data_feature:
            if isinstance(heldout_data_feature[task], SaeFeature):
                heldout_data_feature[task] = {
                    "acts": heldout_data_feature[task].acts,
                    "indices": heldout_data_feature[task].indices,
                }
        torch.save(heldout_data_feature, args.output_path)


if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not when imported.
    main()
