import os
from trainers import CharInpaintModelWrapper, CharInpaintTrainer, OCRAccLogger
from notebooks.notebook_utils import *
import torch
from argparse import ArgumentParser
from trainers import OCRAccLogger, CharInpaintTrainer
from mydatasets import CharInpaintDataset, char_inpaint_collate_fn
from torch.utils.data import DataLoader, Dataset
import torch.distributed as torchdist
import torch.multiprocessing as mp
from functools import partial
from omegaconf import OmegaConf
from mydatasets import CharInpaintDatasetTest
from pytorch_lightning import seed_everything
from datasets import concatenate_datasets
import socket


def find_free_network_port() -> int:
    """Return a TCP port number that the OS reported as free.

    Binds a throwaway IPv4 socket to port 0 so the OS picks an unused port,
    reads the chosen port back and releases the socket. The port is not
    reserved after this returns, so another process could take it before the
    caller binds it.
    """
    # Context manager guarantees the socket is closed even if bind() fails.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("", 0))
        return sock.getsockname()[1]


def create_parser():
    """Build the command-line parser for the evaluation script.

    ``--todo`` takes one or more split names; each name is later read from
    ``testdata/<name>.json``. Without ``nargs`` a single string would be
    stored and the caller would iterate it character by character.
    """
    parser = ArgumentParser()
    parser.add_argument("--path", type=str, required=True,
                        help="checkpoint file to evaluate")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--todo", type=str, nargs="+", required=True,
                        help="one or more test split names (testdata/<name>.json)")
    parser.add_argument("--log_dir", default="tmp_evaluate/")
    parser.add_argument("--project", type=str, default="eval_charinpaint")
    parser.add_argument("--dataset", required=True, type=str)
    parser.add_argument(
        "--base", type=str, nargs="*", metavar="config.yaml", default=list()
    )
    return parser


def init_process(mainport, rank, size, fn, backend='nccl'):
    """Join a local process group, synchronise, then run ``fn``.

    Points the rendezvous at 127.0.0.1:``mainport`` through the
    ``MASTER_ADDR`` / ``MASTER_PORT`` environment variables. It then calls
    ``init_process_group`` with the given backend, rank and world size, and
    waits at a barrier for all ranks. Finally it calls
    ``fn(rank=..., worldsize=...)``.
    """
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    print(f"Rank {rank} set up main port {mainport}")
    os.environ['MASTER_PORT'] = str(mainport)

    print(f"Start initialize at rank {rank}")
    torchdist.init_process_group(backend, rank=rank, world_size=size)
    print(f"Initialize at rank {rank} done")

    # Make sure every rank has joined before any of them starts working.
    torchdist.barrier()
    worker_kwargs = {"rank": rank, "worldsize": size}
    fn(**worker_kwargs)


def load_model(path, device="cpu"):
    """Restore a ``CharInpaintTrainer`` from ``path`` ready for inference.

    The checkpoint is always deserialised onto the CPU first
    (``map_location="cpu"``). The module is then moved to ``device`` and
    switched to eval mode.
    """
    restored = CharInpaintTrainer.load_from_checkpoint(path, map_location="cpu")
    restored = restored.to(device)
    restored.eval()
    return restored


def cleanup(rank):
    """Destroy the default process group, then report that ``rank`` finished."""
    torchdist.destroy_process_group()
    finished_msg = f"Rank {rank} is done."
    print(finished_msg)


def setup_for_distributed(is_master):
    """Replace ``builtins.print`` with a version gated on ``is_master``.

    The replacement forwards to the ``print`` that was installed when this
    function ran, but only if ``is_master`` is truthy or the call passes
    ``force=True``. The ``force`` keyword is always stripped before
    forwarding. The patch is process-wide.
    """
    import builtins

    previous_print = builtins.print

    def gated_print(*args, **kwargs):
        # Pop first so 'force' never reaches the real print().
        if kwargs.pop('force', False) or is_master:
            previous_print(*args, **kwargs)

    builtins.print = gated_print


# Forwarded as ``rand_label`` to every CharInpaintDatasetTest built during
# evaluation. NOTE(review): presumably makes the test set substitute random
# target words; confirm against CharInpaintDatasetTest.
REPLACE_WORDS: bool = True


def _clear_eval_dataconfs(config):
    """Empty, in place, the ``dataconfs`` of both OCR-logger eval configs.

    Entries under ``val_eval_conf`` and ``train_eval_conf`` are removed (in
    that order). The logger is then built without its configured datasets;
    the evaluation data is attached separately.
    """
    params = config['lightning']['callbacks']['ocracc_logger']['params']
    for eval_conf in ('val_eval_conf', 'train_eval_conf'):
        dataconfs = params[eval_conf]['dataconfs']
        # Snapshot the keys: popping while iterating the mapping is unsafe.
        for key in list(dataconfs.keys()):
            dataconfs.pop(key)


def _collect_val_dataset(todos, max_samples=1000, keep_full=("ICDAR13",)):
    """Load and concatenate the test splits named in ``todos``.

    Each name is loaded from ``testdata/<name>.json``. Every split not listed
    in ``keep_full`` is truncated to its first ``max_samples`` rows. The first
    split's dataset object is reused, and later splits are concatenated onto
    its ``.data``.

    Raises:
        ValueError: if ``todos`` is empty or None.
    """
    if isinstance(todos, str):
        # A bare string would otherwise be iterated character by character.
        todos = [todos]
    if not todos:
        raise ValueError("No evaluation splits given (--todo is empty).")

    val_dataset = None
    for todo in todos:
        tempdata = CharInpaintDatasetTest(
            "testdata", f"testdata/{todo}.json", rand_label=REPLACE_WORDS)
        if todo not in keep_full and len(tempdata) >= max_samples:
            tempdata.data = tempdata.data.select(range(max_samples))

        if val_dataset is None:
            val_dataset = tempdata
        else:
            val_dataset.data = concatenate_datasets(
                [val_dataset.data, tempdata.data])
    return val_dataset


def run_evaluate(opt, config, device, rank, worldsize):
    """Generate images for the requested test splits on this rank, then score them.

    Args:
        opt: parsed CLI namespace; reads ``todo``, ``path`` and ``seed``.
        config: merged config; the OCR-logger eval ``dataconfs`` are cleared
            in place, and ``base_log_dir`` selects the output location.
        device: CUDA device index for this worker.
        rank: process rank. Rank 0 alone runs the OCR evaluation.
        worldsize: number of ranks (not used directly here).

    Every rank generates images into ``<base_log_dir>/<split>``. After a
    barrier, rank 0 writes results to ``<base_log_dir>/<split>_res.json``.
    The process group is destroyed at the end.
    """
    # NOTE(review): torch.cuda.set_device(rank) was disabled in the original;
    # confirm collectives bind to the intended GPU.
    torch.cuda.empty_cache()
    cuda_device = torch.device(f"cuda:{device}")

    _clear_eval_dataconfs(config)

    # collect data
    ocracclogger = OCRAccLogger(
        **config['lightning']['callbacks']['ocracc_logger']['params'])
    val_dataset = _collect_val_dataset(opt.todo)
    print(f"TODO number: {len(val_dataset)}")

    ocracclogger.val_eval = val_dataset
    torchdist.barrier()  # wait for others to collect data

    model = load_model(opt.path, cuda_device)
    torchdist.barrier()  # wait for others to load model

    seed_everything(opt.seed)
    # distributed generate
    for split in ('val',):
        toeval = ocracclogger.train_eval if split == 'train' else ocracclogger.val_eval
        out_dir = os.path.join(config['base_log_dir'], split)
        res_path = os.path.join(config['base_log_dir'], f"{split}_res.json")
        ocracclogger.generate_images(model, toeval, out_dir)
        torchdist.barrier()  # all ranks must finish generating first
        if rank == 0:  # evaluation
            ocracclogger.raw_ocr_eval(cuda_device, out_dir, res_path)
        torchdist.barrier()  # wait for main to evaluate
    cleanup(rank)


if __name__ == "__main__":
    parser = create_parser()
    opt = parser.parse_args()
    if not opt.base:
        # OmegaConf.merge() needs at least one config; fail with a CLI error.
        parser.error("at least one --base config.yaml is required")

    seed_everything(opt.seed)
    configs = [OmegaConf.load(b) for b in opt.base]
    for conf in configs:
        # Injected before resolve(); presumably so configs can interpolate
        # ${base_log_dir}. run_evaluate also reads it directly.
        conf["base_log_dir"] = opt.log_dir
        OmegaConf.resolve(conf)
    config = OmegaConf.merge(*configs)

    # One worker process (rank) per GPU listed here.
    devices = [0, 1]
    print(devices)
    worldsize = len(devices)
    # Spawned children inherit this environment variable.
    os.environ['CUDA_VISIBLE_DEVICES'] = ",".join(str(d) for d in devices)

    mp.set_start_method("spawn")
    mainport = find_free_network_port()
    processes = []
    for rank in range(worldsize):
        fn = partial(run_evaluate, opt=opt, config=config, device=rank)
        p = mp.Process(target=init_process, args=(
            mainport, rank, worldsize, fn))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

    # Without this check a crashed rank would still be reported as success.
    failed = [rank for rank, p in enumerate(processes) if p.exitcode != 0]
    if failed:
        raise SystemExit(f"Evaluation failed on rank(s) {failed}")

    print("Evaluation Done")
