import argparse
import datetime
import json
import os

import numpy as np
import pandas as pd
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
from tqdm import tqdm
from transformers import GPT2TokenizerFast

import sampling
import utils
from load_model import load_model


def _write_jsonl(path, key, items):
    """Write one JSON object ``{key: item}`` per line to ``path``."""
    with open(path, 'w') as file:
        for item in items:
            json.dump({key: item}, file)
            file.write('\n')


def _run(rank, args):
    """Generate conditional samples for this rank's shard of the dataset.

    Loads the model through ``load_model``, wraps the model and the noise
    module in DDP, reads the JSON-lines file at ``args.data_path`` and keeps
    this rank's contiguous slice of rows. Every row's ``truncated_string`` is
    repeated five times; once ``args.batch`` strings are collected they are
    tokenized, cut to 100 tokens, and passed to the sampler returned by
    ``sampling.get_pc_sampler`` with the conditioning tokens written back
    into fixed positions by a projection function.

    Two JSON-lines files are written into ``args.save_file_path`` per rank:
    the decoded samples (key ``context_string``) and the decoded truncated
    references, one per prompt (key ``truncated_string``).

    Args:
        rank: Index of this process; also used as the CUDA device index.
        args: Parsed command-line namespace. Reads ``batch``, ``model_path``,
            ``ngpus``, ``steps``, ``work_dir``, ``type``, ``data_path`` and
            ``save_file_path``.

    Raises:
        ValueError: If the number of rows is not divisible by ``args.ngpus``,
            if ``args.batch`` is not a positive multiple of 5, or if this
            rank's shard is not divisible by ``args.batch``.
    """
    samples_per_prompt = 5    # every prompt appears this many times in a batch
    seq_len = 100             # length of each generated sequence
    standard_prefix_len = 50  # conditioning prefix length in 'standard' mode
    infill_edge_len = 25      # conditioning prefix/suffix length otherwise

    if rank == 0:
        logger = utils.get_logger(os.path.join(args.work_dir, f"sedd_small_conditional_generation/evaluation_{args.type}_{args.steps}_logs"))

    def mprint(msg):
        # Only rank 0 owns a logger; other ranks stay silent.
        if rank == 0:
            logger.info(msg)

    torch.cuda.set_device(rank)
    device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")

    mprint("================================")

    if device.type == "cuda":
        mprint("Found {} CUDA devices.".format(torch.cuda.device_count()))
        for i in range(torch.cuda.device_count()):
            props = torch.cuda.get_device_properties(i)
            mprint(
                "{} \t Memory: {:.2f}GB".format(
                    props.name, props.total_memory / (1024 ** 3)
                )
            )
    else:
        mprint("WARNING: Using device {}".format(device))
    mprint(f"Found {os.cpu_count()} total number of CPUs.")

    mprint("================================")

    for arg_name, arg_value in vars(args).items():
        mprint(f"{arg_name}: {arg_value}")

    mprint("================================")

    model, graph, noise = load_model(args.model_path, device)
    model = DDP(model, device_ids=[rank], static_graph=True)
    noise = DDP(noise, device_ids=[rank], static_graph=True)

    # Left padding with EOS as the pad token, so padded rows are right-aligned.
    tokenizer = GPT2TokenizerFast.from_pretrained('assets/gpt2')
    tokenizer.padding_side = 'left'
    tokenizer.pad_token = tokenizer.eos_token

    data = pd.read_json(args.data_path, lines=True)

    # Explicit checks instead of asserts so they still run under ``python -O``.
    total_rows = data.shape[0]
    if total_rows % args.ngpus != 0:
        raise ValueError(
            f"number of rows ({total_rows}) must be divisible by ngpus ({args.ngpus})"
        )

    # Contiguous, equally sized shard for this rank.
    rows_per_rank = total_rows // args.ngpus
    data = data.iloc[rank * rows_per_rank:(rank + 1) * rows_per_rank]

    if args.batch <= 0 or args.batch % samples_per_prompt != 0:
        raise ValueError(
            f"batch ({args.batch}) must be a positive multiple of {samples_per_prompt}"
        )
    if data.shape[0] % args.batch != 0:
        raise ValueError(
            f"rows per rank ({data.shape[0]}) must be divisible by batch ({args.batch})"
        )

    # Create the output directory up front so a missing path cannot discard
    # the results after sampling has finished. Safe when several ranks race.
    if args.save_file_path:
        os.makedirs(args.save_file_path, exist_ok=True)

    predictions = []
    references_truncated = []

    rows = data.iterrows()
    if rank == 0:
        rows = tqdm(rows, total=len(data.index))

    references = []
    with torch.no_grad():
        for _, row in rows:
            references.extend([row.truncated_string] * samples_per_prompt)

            if len(references) < args.batch:
                continue

            # Holds because batch is a multiple of samples_per_prompt.
            assert len(references) == args.batch

            tokens = tokenizer(references, return_tensors='pt', padding=True)['input_ids'].to(device)
            tokens = tokens[:, :seq_len]

            # Keep one decoded reference per prompt (the copies are identical).
            references_truncated.extend(
                tokenizer.decode(tokens[j])
                for j in range(0, len(references), samples_per_prompt)
            )

            references.clear()

            if args.type == 'standard':
                input_ids = tokens[:, :standard_prefix_len]
                # Derive positions from the actual width so a padded batch
                # narrower than the prefix length still lines up.
                input_locs = list(range(input_ids.shape[1]))
            else:
                prefix_ids = tokens[:, :infill_edge_len]
                suffix_ids = tokens[:, -infill_edge_len:]
                input_ids = torch.cat((prefix_ids, suffix_ids), dim=1)
                input_locs = (
                    list(range(prefix_ids.shape[1]))
                    + list(range(seq_len - suffix_ids.shape[1], seq_len))
                )

            def proj_fun(x):
                # Overwrite the conditioned positions with the given tokens.
                x[:, input_locs] = input_ids
                return x

            sampling_fn = sampling.get_pc_sampler(
                graph, noise, (args.batch, seq_len), 'analytic', args.steps, device=device, proj_fun=proj_fun
            )

            samples = proj_fun(sampling_fn(model))

            predictions.extend(tokenizer.batch_decode(samples))

    _write_jsonl(
        f'{args.save_file_path}/webtext_generated_data_{rank}_{args.type}_{args.steps}steps.jsonl',
        'context_string',
        predictions,
    )
    _write_jsonl(
        f'{args.save_file_path}/webtext_truncated_data_{rank}_{args.type}_{args.steps}steps.jsonl',
        'truncated_string',
        references_truncated,
    )


def run(rank, args, port):
    """Worker entry point: join the process group, run the evaluation, leave.

    Args:
        rank: Index of this process; used as the distributed rank.
        args: Parsed command-line namespace; ``args.ngpus`` is the world size.
        port: Port number exported as ``MASTER_PORT`` for the rendezvous.
    """
    def setup(rank, world_size, port):
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(port)

        # Initialize the process group
        dist.init_process_group(
            "nccl", rank=rank, world_size=world_size, timeout=datetime.timedelta(minutes=30)
        )

    def cleanup():
        # If setup failed there is no group to destroy; calling
        # destroy_process_group then would raise and hide the real error.
        if dist.is_initialized():
            dist.destroy_process_group()

    try:
        setup(rank, args.ngpus, port)
        _run(rank, args)
    finally:
        cleanup()


def main():
    """Parse command-line options and spawn one worker process per GPU.

    Creates ``args.work_dir`` if needed, picks a random rendezvous port in
    ``[10000, 20000)`` and starts ``args.ngpus`` processes running ``run``.

    Raises:
        Exception: Any failure while starting or joining the workers is
            logged at CRITICAL level and then re-raised, so the script ends
            with a non-zero exit status instead of reporting success.
    """
    parser = argparse.ArgumentParser(description="Evaluation On SEDD models")
    parser.add_argument("--batch", type=int, default=125)
    parser.add_argument("--model_path", type=str, default="louaaron/sedd-small")
    parser.add_argument("--ngpus", type=int, default=8)
    parser.add_argument("--steps", type=int, default=1024)
    parser.add_argument("--work_dir", type=str, default="./")
    parser.add_argument("--type", type=str, default="infill")
    parser.add_argument("--data_path", type=str, default="assets/datasets/webtext/webtext_truncated_data_infilling.jsonl")
    parser.add_argument("--save_file_path", type=str, default="sedd_small_conditional_generation")

    args = parser.parse_args()
    os.makedirs(args.work_dir, exist_ok=True)

    # All workers must agree on the port, so it is chosen once here.
    port = int(np.random.randint(10000, 20000))
    logger = utils.get_logger(os.path.join(args.work_dir, f"sedd_small_conditional_generation/evaluation_{args.type}_{args.steps}_logs"))

    try:
        mp.set_start_method("forkserver")
        mp.spawn(run, args=(args, port), nprocs=args.ngpus, join=True)
    except Exception as e:
        logger.critical(e, exc_info=True)
        # Propagate so callers and shell pipelines can detect the failure.
        raise


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()