import datetime
import os
import os.path
import gc
from itertools import chain

import numpy as np
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.nn.functional as F

import data
import losses
import sampling
import graph_lib
import noise_lib
import utils
from model import SEDD
from model.ema import ExponentialMovingAverage
from transformers import GPT2TokenizerFast, GPT2LMHeadModel
from datasets import load_dataset
from torch.utils.data import DataLoader
import OT_expectation
import time


torch.backends.cudnn.benchmark = True
# torch.autograd.set_detect_anomaly(True)


def setup(rank, world_size, port):
    """Join this process to a single-host NCCL process group.

    Args:
        rank: index of the calling process within the group.
        world_size: total number of processes in the group.
        port: TCP port for the rendezvous on localhost; must match across ranks.
    """
    # Rendezvous is via environment variables: all ranks point at localhost:port.
    os.environ.update(MASTER_ADDR="localhost", MASTER_PORT=str(port))

    # Collectives that block longer than 30 minutes time out.
    group_timeout = datetime.timedelta(minutes=30)
    dist.init_process_group(
        backend="nccl",
        rank=rank,
        world_size=world_size,
        timeout=group_timeout,
    )


def cleanup():
    """Destroy the default ``torch.distributed`` process group, if one exists.

    Safe to call when the group was never created (e.g. from a ``finally``
    clause after initialization failed). In that case it does nothing instead
    of raising, which would otherwise mask the original error.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


def run_multiprocess(rank, world_size, cfg, port):
    """Run one rank of the distributed job, tearing down the process group afterwards.

    ``setup`` is called before the ``try``. If it fails, there is no group to
    destroy, and its exception propagates instead of being replaced by an error
    from ``cleanup``. Once setup succeeds, ``cleanup`` always runs, even if
    ``_run`` raises.

    Args:
        rank: index of this process (also the CUDA device index used by ``_run``).
        world_size: total number of processes.
        cfg: experiment configuration passed through to ``_run``.
        port: rendezvous port passed through to ``setup``.
    """
    setup(rank, world_size, port)
    try:
        _run(rank, world_size, cfg)
    finally:
        cleanup()


def _run(rank, world_size, cfg):
    """Per-rank job: build model/state/data, then collect ``OT_expectation.loss_fn`` stats.

    The function builds the usual training scaffolding: score model, EMA,
    noise, optimizer, restored checkpoint state, step functions and an
    optional sampler. The final loop does not train, though. It draws 12000
    training batches, evaluates ``OT_expectation.loss_fn`` on each, prints the
    running mean/std of the results, and finally prints the wall-clock time
    measured from iteration 20 onward.

    Args:
        rank: index of this process; also used as the CUDA device index.
        world_size: total number of processes (not otherwise used here).
        cfg: experiment config. Fields read include ``work_dir``,
            ``training.*``, ``model.length``, ``data.cache_dir``,
            ``data.train`` and ``ngpus``.
    """
    torch.cuda.set_device(rank)
    work_dir = cfg.work_dir

    # Create directories for experimental logs (rank 0 only).
    sample_dir = os.path.join(work_dir, "samples")
    checkpoint_dir = os.path.join(work_dir, "checkpoints")
    checkpoint_meta_dir = os.path.join(work_dir, "checkpoints-meta", "checkpoint.pth")
    if rank == 0:
        utils.makedirs(sample_dir)
        utils.makedirs(checkpoint_dir)
        utils.makedirs(os.path.dirname(checkpoint_meta_dir))

    # logging: only rank 0 owns a logger, so mprint is a no-op on other ranks.
    logger = utils.get_logger(os.path.join(work_dir, "logs")) if rank == 0 else None

    def mprint(msg):
        if rank == 0:
            logger.info(msg)

    mprint(work_dir)
    mprint(cfg)
    device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
    if device.type == "cuda":
        mprint("Found {} CUDA devices.".format(torch.cuda.device_count()))
        for i in range(torch.cuda.device_count()):
            props = torch.cuda.get_device_properties(i)
            mprint(
                "{} \t Memory: {:.2f}GB".format(
                    props.name, props.total_memory / (1024 ** 3)
                )
            )
    else:
        mprint("WARNING: Using device {}".format(device))
    mprint(f"Found {os.cpu_count()} total number of CPUs.")

    # build token graph
    graph = graph_lib.get_graph(cfg, device)

    # build score model
    score_model = SEDD(cfg).to(device)
    score_model = DDP(score_model, device_ids=[rank], static_graph=True, find_unused_parameters=True)

    num_parameters = sum(p.numel() for p in score_model.parameters())
    mprint(f"Number of parameters in the model: {num_parameters}")

    ema = ExponentialMovingAverage(
        score_model.parameters(), decay=cfg.training.ema)
    mprint(score_model)
    mprint(f"EMA: {ema}")

    # build noise
    noise = noise_lib.get_noise(cfg).to(device)
    noise = DDP(noise, device_ids=[rank], static_graph=True)
    sampling_eps = 1e-5

    # build optimization state
    optimizer = losses.get_optimizer(cfg, chain(score_model.parameters(), noise.parameters()))
    mprint(f"Optimizer: {optimizer}")
    scaler = torch.cuda.amp.GradScaler()
    mprint(f"Scaler: {scaler}")
    state = dict(optimizer=optimizer, scaler=scaler, model=score_model, noise=noise, ema=ema, step=0)

    # load in state
    state = utils.restore_checkpoint(checkpoint_meta_dir, state, device)
    initial_step = int(state['step'])

    # Load the tokenizer: a character-level WordLevel vocabulary built from
    # the fixed character list below, plus <pad> and <unk>.
    import json
    from tokenizers import Tokenizer, models, pre_tokenizers
    from transformers import PreTrainedTokenizerFast

    custom_chars = [
        ' ', 'e', 't', 'o', 'a', 'h', 'n', 's', 'r', 'i', 'l', 'd', '\n', 'u', 'm',
        'y', ',', '.', 'w', 'f', 'c', 'g', 'I', 'p', 'b', 'A', 'E', 'T', 'v', 'S',
        'O', "'", 'k', 'R', 'N', 'L', 'C', 'H', ';', 'W', 'M', 'B', 'D', 'U', 'F',
        'G', 'P', '?', 'Y', '!', '-', 'K', 'x', 'V', 'j', 'q', '[', ']', 'J', ':',
        'Q', 'z', '9', '1', '(', ')', 'Z', 'X', '<', '"', '>', '2', '3', '0', '4',
        '5', '_', '6', '7', '8', '|', '&', '}', '`'
    ]

    vocab_dict = {char: idx for idx, char in enumerate(custom_chars)}

    vocab_dict["<pad>"] = len(vocab_dict)
    vocab_dict["<unk>"] = len(vocab_dict)

    # Persist the vocabulary once. Every rank writing the same relative path
    # concurrently would race on the file.
    if rank == 0:
        with open("char_vocab.json", "w", encoding="utf-8") as f:
            json.dump(vocab_dict, f)

    tokenizer = PreTrainedTokenizerFast(
        tokenizer_object=Tokenizer(models.WordLevel(vocab_dict, unk_token="<unk>")),
        unk_token="<unk>",
        pad_token="<pad>",
    )

    # Split on the empty pattern so every character becomes its own token.
    tokenizer._tokenizer.pre_tokenizer = pre_tokenizers.Split("", behavior="isolated")

    # Sanity check: round-trip the full character set through encode/decode.
    encoding = tokenizer(''.join(custom_chars), add_special_tokens=False)
    print(encoding.input_ids)

    id_to_char = {v: k for k, v in vocab_dict.items()}

    def custom_decoder(token_ids):
        return ''.join(id_to_char.get(token_id, '<unk>') for token_id in token_ids)

    tokenizer.custom_decoder = custom_decoder
    print(tokenizer.custom_decoder(encoding.input_ids))

    def tokenize_function(examples):
        # NOTE(review): each line is truncated to cfg.model.length *before*
        # concatenation, so characters past that length in a long line are dropped.
        return tokenizer(
            examples["text"],
            add_special_tokens=False,
            truncation=True,
            max_length=cfg.model.length,
            return_attention_mask=False,
            return_token_type_ids=False
        )

    def group_texts(examples, block_size):
        # Concatenate all token lists in the batch and cut them into
        # block_size chunks, discarding the trailing remainder.
        concatenated = []
        for tokens in examples["input_ids"]:
            concatenated.extend(tokens)
        total_length = (len(concatenated) // block_size) * block_size
        result = {"input_ids": [concatenated[i:i + block_size] for i in range(0, total_length, block_size)]}
        return result

    def get_dataset(file_path, split, cache_dir, block_size):
        # `split` is accepted for readability at call sites but is not used.
        # NOTE(review): the HF "text" builder yields one example per line; confirm
        # whether line breaks are kept, otherwise the '\n' token never appears.
        dataset = load_dataset("text", data_files={"train": file_path}, cache_dir=cache_dir)["train"]

        tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])

        grouped_dataset = tokenized_dataset.map(lambda x: group_texts(x, block_size), batched=True)
        grouped_dataset = grouped_dataset.flatten()
        return grouped_dataset

    train_set = get_dataset('shakespeare_train.txt', "train", cfg.data.cache_dir, cfg.model.length)
    valid_set = get_dataset('shakespeare_test.txt', "validation", cfg.data.cache_dir, cfg.model.length)
    train_set.set_format(type="torch", columns=["input_ids"])
    valid_set.set_format(type="torch", columns=["input_ids"])

    train_loader = DataLoader(train_set, batch_size=cfg.training.batch_size, shuffle=True)
    valid_loader = DataLoader(valid_set, batch_size=cfg.training.batch_size, shuffle=False)

    for batch in train_loader:
        print("Batch input_ids shape:", len(batch["input_ids"]), "x", len(batch["input_ids"][0]))
        break

    train_iter = iter(train_loader)
    eval_iter = iter(valid_loader)

    def next_train_batch():
        """Return the next training batch, starting a new epoch when the loader is exhausted."""
        nonlocal train_iter
        try:
            return next(train_iter)
        except StopIteration:
            # The loader is finite. Re-iterate (reshuffling, since shuffle=True)
            # instead of letting StopIteration abort the measurement loop.
            train_iter = iter(train_loader)
            return next(train_iter)

    # Build one-step training and evaluation functions
    optimize_fn = losses.optimization_manager(cfg)
    train_step_fn = losses.get_step_fn(noise, graph, True, optimize_fn, cfg.training.accum, cfg)
    eval_step_fn = losses.get_step_fn(noise, graph, False, optimize_fn, cfg.training.accum, cfg)

    if cfg.training.snapshot_sampling:
        sampling_shape = (cfg.training.batch_size // (cfg.ngpus * cfg.training.accum), cfg.model.length)
        sampling_fn = sampling.get_sampling_fn(cfg, graph, noise, sampling_shape, sampling_eps, device)

    num_train_steps = cfg.training.n_iters
    mprint(f"Starting training loop at step {initial_step}.")

    num_iters = 12000
    timing_start_iter = 20  # iterations before this one are excluded from timing (warm-up)
    arr = []
    st = time.time()  # fallback so the final timing print never sees an unbound name
    for ijk in range(1, num_iters + 1):
        print(ijk)
        if ijk == timing_start_iter:
            st = time.time()

        if cfg.data.train != "text8":
            batch = next_train_batch()['input_ids'].to(device)
        else:
            batch = next_train_batch().to(device)
        M_mean = OT_expectation.loss_fn(state['model'], batch, graph)
        # Assumes loss_fn returns a Python/CPU scalar that np.array can hold.
        # TODO confirm (a CUDA tensor would need .item()).
        arr.append(M_mean)
        values = np.array(arr)
        print('mean:', values.mean())
        print('std:', values.std())
    print(time.time() - st)
