import datetime
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import os.path
import gc
from itertools import chain

import numpy as np
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.nn.functional as F

import data
import losses
import sampling
import graph_lib
import noise_lib
import utils
from model import SEDD
from model.ema import ExponentialMovingAverage
from transformers import GPT2TokenizerFast, GPT2LMHeadModel
import ot
from model import utils as mutils
import time


# Ask cuDNN to benchmark candidate convolution algorithms and pick the fastest
# one (see the torch.backends.cudnn documentation).
torch.backends.cudnn.benchmark = True
# Debugging aid, disabled by default: uncomment to make autograd report the
# forward operation behind a failing backward pass.
# torch.autograd.set_detect_anomaly(True)


def setup(rank, world_size, port):
    """Join the NCCL process group for this worker process.

    Sets ``MASTER_ADDR`` to ``localhost`` and ``MASTER_PORT`` to ``port`` in
    the environment, then initializes the default ``torch.distributed``
    process group with a 30-minute timeout.

    Args:
        rank: Rank of the calling process within the group.
        world_size: Total number of processes in the group.
        port: Port number exported as ``MASTER_PORT`` (converted with ``str``).
    """
    os.environ.update({"MASTER_ADDR": "localhost", "MASTER_PORT": str(port)})

    # initialize the process group
    group_timeout = datetime.timedelta(minutes=30)
    dist.init_process_group(
        backend="nccl",
        rank=rank,
        world_size=world_size,
        timeout=group_timeout,
    )


def cleanup():
    """Tear down the default process group, if one was created.

    This runs from a ``finally`` clause, so it can be reached even when
    process-group initialization itself failed. Destroying a group that was
    never initialized raises, which would pile a second error on top of the
    real one; the guard turns that case into a no-op.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


def run_multiprocess(rank, world_size, cfg, port):
    """Per-process entry point: set up distributed state, train, tear down.

    Calls ``setup`` and then ``_run``; ``cleanup`` is invoked afterwards
    whether or not either of them raised.

    Args:
        rank: Rank of this process, forwarded to ``setup`` and ``_run``.
        world_size: Total number of processes, forwarded likewise.
        cfg: Training configuration handed to ``_run``.
        port: Port handed to ``setup``.
    """
    try:
        setup(rank=rank, world_size=world_size, port=port)
        _run(rank=rank, world_size=world_size, cfg=cfg)
    finally:
        # Always release the process group, even on failure.
        cleanup()


def _ot_pair_indices(embed_fn, source_batch, data_batch, chunk_size):
    """Match every source row to a data row with exact optimal transport.

    The cost between source row ``i`` and data row ``j`` is the squared
    Euclidean distance between their flattened embeddings, computed through
    the expansion ``|s_i|^2 + |d_j|^2 - 2 * <s_i, d_j>``. Both marginals are
    uniform and the plan is solved on the CPU with ``ot.emd``.

    Args:
        embed_fn: Callable applied to a batch; its output is reshaped to one
            flat vector per row.
        source_batch: Tensor whose first dimension is the number of rows.
        data_batch: Tensor with the same number of rows as ``source_batch``.
        chunk_size: Number of source rows embedded per iteration, so that the
            source embeddings never have to be held in memory all at once.

    Returns:
        Index tensor ``col_idx`` on ``source_batch.device``; source row ``i``
        is paired with data row ``col_idx[i]``.
    """
    num_rows = source_batch.shape[0]
    target_device = source_batch.device

    with torch.inference_mode():
        # 1) Embed and cache the reference batch once.
        data_emb = embed_fn(data_batch).reshape(num_rows, -1)    # (B, D)
        data_sq_norm = data_emb.pow(2).sum(1).view(1, -1)        # (1, B)

        # 2) Allocate the full cost matrix once; chunks are written in place.
        cost = torch.empty(
            num_rows, num_rows, dtype=torch.float32, device=target_device
        )

        # 3) Stream over the source batch chunk by chunk.
        for start in range(0, num_rows, chunk_size):
            stop = min(start + chunk_size, num_rows)

            src_emb = embed_fn(source_batch[start:stop]).reshape(stop - start, -1)  # (c, D)
            src_sq_norm = src_emb.pow(2).sum(1).view(-1, 1)                          # (c, 1)
            sq_norm_sum = src_sq_norm + data_sq_norm                                 # (c, B)

            # cost[start:stop] = sq_norm_sum - 2 * src_emb @ data_emb.T
            torch.addmm(
                sq_norm_sum, src_emb, data_emb.T,
                beta=1.0, alpha=-2.0, out=cost[start:stop],
            )

            # Drop the chunk before the next one is allocated.
            del src_emb, src_sq_norm

        # 4) Uniform marginals; the solver runs on float64 NumPy arrays.
        a_np = np.full(num_rows, 1.0 / num_rows, dtype=np.float64)
        b_np = np.full(num_rows, 1.0 / num_rows, dtype=np.float64)
        cost_np = cost.double().cpu().numpy()

        # 5) Exact discrete OT; the row-wise argmax of the plan is used as the
        #    data row assigned to each source row.
        plan_np = ot.emd(a_np, b_np, cost_np)
        plan = torch.from_numpy(plan_np).to(device=target_device, dtype=torch.float32)
        return plan.argmax(dim=1)


def _run(rank, world_size, cfg):
    """Run the training loop on one rank of the distributed job.

    Builds the token graph, the DDP-wrapped SEDD score model and noise module,
    the EMA, optimizer and gradient scaler, restores any state found at
    ``<work_dir>/checkpoints-meta/checkpoint.pth``, and then consumes training
    batches while ``state['step'] <= cfg.training.n_iters``.

    Every large batch is paired with a source batch drawn from
    ``graph.sample_limit``. When ``cfg.training.use_optimal_transport`` is set,
    the data rows are first re-ordered by an exact OT matching against the
    source rows; if that matching fails, the batch is used in its original
    order. The paired batch is then split into diffusion batches of
    ``cfg.training.diffusion_batch_size`` rows, each of which is split again
    into ``cfg.training.accum`` minibatches fed to the training step.

    Args:
        rank: Rank of this process. Used as the CUDA device index; rank 0
            additionally creates directories, logs, and writes checkpoints.
        world_size: Number of processes; used to average the logged loss.
        cfg: Configuration object. Fields read here: ``work_dir``, ``ngpus``,
            ``data.train``, ``model.length`` and, under ``training``,
            ``ema``, ``accum``, ``batch_size``, ``snapshot_sampling``,
            ``n_iters``, ``use_optimal_transport``, ``chunksize``,
            ``diffusion_batch_size``, ``log_freq`` and
            ``snapshot_freq_for_preemption``.
    """
    torch.cuda.set_device(rank)
    work_dir = cfg.work_dir

    # Create directories for experimental logs
    sample_dir = os.path.join(work_dir, "samples")
    checkpoint_dir = os.path.join(work_dir, "checkpoints")
    # Despite its name this is the path of the checkpoint *file*.
    checkpoint_meta_dir = os.path.join(work_dir, "checkpoints-meta", "checkpoint.pth")
    if rank == 0:
        utils.makedirs(sample_dir)
        utils.makedirs(checkpoint_dir)
        utils.makedirs(os.path.dirname(checkpoint_meta_dir))

    # logging (only rank 0 owns a logger; mprint is a no-op elsewhere)
    if rank == 0:
        logger = utils.get_logger(os.path.join(work_dir, "logs"))

    def mprint(msg):
        if rank == 0:
            logger.info(msg)

    mprint(work_dir)
    mprint(cfg)
    device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
    if device.type == "cuda":
        mprint("Found {} CUDA devices.".format(torch.cuda.device_count()))
        for i in range(torch.cuda.device_count()):
            props = torch.cuda.get_device_properties(i)
            mprint(
                "{} \t Memory: {:.2f}GB".format(
                    props.name, props.total_memory / (1024 ** 3)
                )
            )
    else:
        mprint("WARNING: Using device {}".format(device))
    mprint(f"Found {os.cpu_count()} total number of CPUs.")

    # build token graph
    graph = graph_lib.get_graph(cfg, device)

    # build score model
    score_model = SEDD(cfg).to(device)
    score_model = DDP(score_model, device_ids=[rank], static_graph=True, find_unused_parameters=True)

    num_parameters = sum(p.numel() for p in score_model.parameters())
    mprint(f"Number of parameters in the model: {num_parameters}")

    ema = ExponentialMovingAverage(
        score_model.parameters(), decay=cfg.training.ema)
    mprint(score_model)
    mprint(f"EMA: {ema}")

    # build noise
    noise = noise_lib.get_noise(cfg).to(device)
    noise = DDP(noise, device_ids=[rank], static_graph=True)
    sampling_eps = 1e-5

    # build optimization state (one optimizer over model and noise parameters)
    optimizer = losses.get_optimizer(cfg, chain(score_model.parameters(), noise.parameters()))
    mprint(f"Optimizer: {optimizer}")
    scaler = torch.cuda.amp.GradScaler()
    mprint(f"Scaler: {scaler}")
    state = dict(optimizer=optimizer, scaler=scaler, model=score_model, noise=noise, ema=ema, step=0)

    # load in state
    state = utils.restore_checkpoint(checkpoint_meta_dir, state, device)
    initial_step = int(state['step'])

    # load in tokenizer
    tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')

    # Build data iterators
    train_ds, eval_ds = data.get_dataloaders(cfg)

    train_iter = iter(train_ds)
    eval_iter = iter(eval_ds)

    # Build one-step training and evaluation functions
    optimize_fn = losses.optimization_manager(cfg)
    train_step_fn = losses.get_step_fn(noise, graph, True, optimize_fn, cfg.training.accum, cfg)
    eval_step_fn = losses.get_step_fn(noise, graph, False, optimize_fn, cfg.training.accum, cfg)

    if cfg.training.snapshot_sampling:
        sampling_shape = (cfg.training.batch_size // (cfg.ngpus * cfg.training.accum), cfg.model.length)
        sampling_fn = sampling.get_sampling_fn(cfg, graph, noise, sampling_shape, sampling_eps, device)

    num_train_steps = cfg.training.n_iters
    mprint(f"Starting training loop at step {initial_step}.")

    # NOTE(review): never read in the visible part of this function; kept in
    # case code outside this view relies on it.
    first_flag = True
    while state['step'] < num_train_steps + 1:
        # The text8 loader yields tensors directly; the others yield dicts.
        if cfg.data.train != "text8":
            big_batch = next(train_iter)['input_ids'].to(device)
        else:
            big_batch = next(train_iter).to(device)

        # Unpacking doubles as a check that the batch is two-dimensional.
        B, L = big_batch.shape
        # Source samples with the same shape, dtype and device as the data.
        sourcebatch = graph.sample_limit(big_batch.shape).to(big_batch)

        # Default pairing: source row i goes with data row i.
        databatch = big_batch
        if cfg.training.use_optimal_transport:
            try:
                col_idx = _ot_pair_indices(
                    score_model.module.vocab_embed,
                    sourcebatch,
                    big_batch,
                    cfg.training.chunksize,
                )
                # Source rows stay in place; data rows are re-ordered so that
                # source row i is paired with big_batch[col_idx[i]].
                databatch = big_batch[col_idx]
            except Exception as exc:
                # Best effort: keep training with the unpaired batch, but say
                # what went wrong. Catching Exception (not a bare except)
                # lets KeyboardInterrupt / SystemExit stop the job.
                databatch = big_batch
                print(
                    f"Error calculating optimal transport on rank {rank} "
                    f"({exc!r}), continuing with independent sampling this batch"
                )

        diffusion_batch_size = cfg.training.diffusion_batch_size
        accum = cfg.training.accum
        # Rows beyond the last full diffusion batch are dropped.
        for diff_batch_no in range(databatch.shape[0] // diffusion_batch_size):
            # Step counter before this diffusion batch; compared below to
            # detect that the accumulated update was actually applied.
            step = state['step']
            lo = diff_batch_no * diffusion_batch_size
            hi = lo + diffusion_batch_size
            diffusion_source_batch = sourcebatch[lo:hi]
            diffusion_data_batch = databatch[lo:hi]
            mini_size = diffusion_data_batch.shape[0] // accum

            for minibatch in range(accum):
                mini_lo = minibatch * mini_size
                mini_hi = mini_lo + mini_size
                loss = train_step_fn(
                    state,
                    diffusion_source_batch[mini_lo:mini_hi],
                    diffusion_data_batch[mini_lo:mini_hi],
                )

                # flag to see if there was movement ie a full batch got computed
                if step != state['step']:
                    if step % cfg.training.log_freq == 0:
                        # Collective call: relies on every rank reaching this
                        # branch at the same step.
                        dist.all_reduce(loss)
                        loss /= world_size

                        mprint("step: %d, training_loss: %.5e" % (step, loss.item()))

                    if step % cfg.training.snapshot_freq_for_preemption == 0 and rank == 0:
                        utils.save_checkpoint(checkpoint_meta_dir, state)

















































































