import datetime
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import os.path
import gc
from itertools import chain

import numpy as np
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.nn.functional as F

import data
import losses
import sampling
import graph_lib
import noise_lib
import utils
from model import SEDD
from model.ema import ExponentialMovingAverage
from transformers import GPT2TokenizerFast, GPT2LMHeadModel
import ot
from model import utils as mutils
import time


# Let cuDNN benchmark several convolution algorithms and pick the fastest for
# the input shapes it encounters.
torch.backends.cudnn.benchmark = True
# Debug aid: uncomment to have autograd report the forward-op traceback behind a
# failing backward pass (slows execution; use only while debugging).
# torch.autograd.set_detect_anomaly(True)


def setup(rank, world_size, port):
    """Join the NCCL process group for single-host multi-process training.

    Args:
        rank: index of this process within the group.
        world_size: total number of participating processes.
        port: TCP port on localhost used for the rendezvous.
    """
    # init_process_group reads these through its default env:// init method.
    os.environ.update(MASTER_ADDR="localhost", MASTER_PORT=str(port))

    init_timeout = datetime.timedelta(minutes=30)
    dist.init_process_group(
        "nccl",
        rank=rank,
        world_size=world_size,
        timeout=init_timeout,
    )


def cleanup():
    """Tear down the default process group if one is active.

    Safe to call when ``setup`` failed or never ran. Destroying a process
    group that was never initialized raises. When this runs from a
    ``finally`` clause, that error would replace the original exception.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


def run_multiprocess(rank, world_size, cfg, port):
    """Entry point for one worker process: join the group, train, then tear down.

    Args:
        rank: index of this process; forwarded to ``setup`` and ``_run``.
        world_size: total number of processes.
        cfg: training configuration forwarded to ``_run``.
        port: rendezvous port forwarded to ``setup``.
    """
    try:
        setup(rank=rank, world_size=world_size, port=port)
        _run(rank=rank, world_size=world_size, cfg=cfg)
    finally:
        # Always release the process group, even if setup or training raised.
        cleanup()


def _run(rank, world_size, cfg):
    """Per-process training worker: build model, optimizer and data, then train.

    Runs inside a process that ``run_multiprocess`` has already joined to the
    distributed group. Rank 0 alone creates output directories, owns the
    logger, and writes preemption checkpoints.

    Each outer iteration draws a "big batch" of data and an equally shaped
    batch from ``graph.sample_limit``. When
    ``cfg.training.use_optimal_transport`` is set, the data rows are reordered
    by an exact optimal-transport assignment. The cost is the negative dot
    product of EMA input embeddings. After the reorder, ``sourcebatch[i]`` is
    paired with ``databatch[i]`` under a minimum-total-cost assignment. The
    paired batches are then cut into ``cfg.training.diffusion_batch_size``
    chunks, and each chunk into ``cfg.training.accum`` minibatches for
    ``train_step_fn``.

    Args:
        rank: index of this process; also used as the CUDA device index.
        world_size: number of processes; used to average the logged loss.
        cfg: attribute-style config (reads ``cfg.work_dir``,
            ``cfg.training.*``, ``cfg.data.train``, ...). Presumably an
            OmegaConf/Hydra object; confirm.
    """
    torch.cuda.set_device(rank)
    work_dir = cfg.work_dir

    # Create directories for experimental logs
    sample_dir = os.path.join(work_dir, "samples")
    checkpoint_dir = os.path.join(work_dir, "checkpoints")
    # Despite the "_dir" suffix, this is the full path of the preemption
    # checkpoint file; only its parent directory is created below.
    checkpoint_meta_dir = os.path.join(work_dir, "checkpoints-meta", "checkpoint.pth")
    if rank == 0:
        utils.makedirs(sample_dir)
        utils.makedirs(checkpoint_dir)
        utils.makedirs(os.path.dirname(checkpoint_meta_dir))

    # logging
    # Only rank 0 creates `logger`. mprint is a no-op on other ranks, so
    # `logger` is never read where it is undefined.
    if rank == 0:
        logger = utils.get_logger(os.path.join(work_dir, "logs"))
    def mprint(msg):
        if rank == 0:
            logger.info(msg)

    mprint(work_dir)
    mprint(cfg)
    # NOTE(review): torch.cuda.set_device(rank) already ran above, so the CPU
    # fallback here appears unreachable in practice.
    device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
    if device.type == "cuda":
        mprint("Found {} CUDA devices.".format(torch.cuda.device_count()))
        for i in range(torch.cuda.device_count()):
            props = torch.cuda.get_device_properties(i)
            mprint(
                "{} \t Memory: {:.2f}GB".format(
                    props.name, props.total_memory / (1024 ** 3)
                )
            )
    else:
        mprint("WARNING: Using device {}".format(device))
    mprint(f"Found {os.cpu_count()} total number of CPUs.")

    # build token graph
    graph = graph_lib.get_graph(cfg, device)
    
    # build score model
    score_model = SEDD(cfg).to(device)
    score_model = DDP(score_model, device_ids=[rank], static_graph=True, find_unused_parameters=True)

    num_parameters = sum(p.numel() for p in score_model.parameters())
    mprint(f"Number of parameters in the model: {num_parameters}")

    # EMA shadow copy of the DDP-wrapped model's parameters. It is swapped
    # into the live model temporarily for the OT embedding lookup below.
    ema = ExponentialMovingAverage(
        score_model.parameters(), decay=cfg.training.ema)
    mprint(score_model)
    mprint(f"EMA: {ema}")

    # build noise
    # noise = noise_lib.get_noise(cfg).to(device)
    # noise = DDP(noise, device_ids=[rank], static_graph=True)
    # Not referenced in the code visible here; presumably used by later
    # sampling/eval code.
    sampling_eps = 1e-5


    # build optimization state
    # chain() over a single iterable is a pass-through; presumably left over
    # from when the (now commented-out) noise parameters were also optimized.
    optimizer = losses.get_optimizer(cfg, chain(score_model.parameters()))
    mprint(f"Optimizer: {optimizer}")
    # NOTE(review): torch.cuda.amp.GradScaler is deprecated in recent PyTorch
    # in favor of torch.amp.GradScaler("cuda").
    scaler = torch.cuda.amp.GradScaler()
    mprint(f"Scaler: {scaler}")
    state = dict(optimizer=optimizer, scaler=scaler, model=score_model, ema=ema, step=0) 


    # load in state
    # Presumably returns `state` unchanged when no checkpoint exists at
    # checkpoint_meta_dir; confirm in utils.restore_checkpoint.
    state = utils.restore_checkpoint(checkpoint_meta_dir, state, device)
    initial_step = int(state['step'])

    
    # load in tokenizer
    tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')

    # Build data iterators
    train_ds, eval_ds = data.get_dataloaders(cfg)

    # mprint(f"Length of datasets: {len(train_ds)}, {len(eval_ds)}")

    train_iter = iter(train_ds)
    eval_iter = iter(eval_ds)

    # Build one-step training and evaluation functions
    optimize_fn = losses.optimization_manager(cfg)
    train_step_fn = losses.get_step_fn(graph, True, optimize_fn, cfg.training.accum, cfg)
    eval_step_fn = losses.get_step_fn(graph, False, optimize_fn, cfg.training.accum, cfg)


    # if cfg.training.snapshot_sampling:
    #     sampling_shape = (cfg.training.batch_size // (cfg.ngpus * cfg.training.accum), cfg.model.length)
    #     sampling_fn = sampling.get_sampling_fn(cfg, graph, noise, sampling_shape, sampling_eps, device)

    num_train_steps = cfg.training.n_iters
    mprint(f"Starting training loop at step {initial_step}.")

    first_flag = True
    # Runs while step <= n_iters, i.e. the final step n_iters is included.
    while state['step'] < num_train_steps + 1:
        
        # text8 batches arrive as tensors; other datasets as dicts with 'input_ids'.
        if cfg.data.train != "text8":
            big_batch = next(train_iter)['input_ids'].to(device)
        else:
            big_batch = next(train_iter).to(device)
            
        # NOTE(review): this skip is decided per rank. If only some ranks skip,
        # the DDP / all_reduce collectives fall out of step and can hang. Skips
        # on ranks != 0 are also silent. "Reshuffling..." is inaccurate: the
        # loop just fetches the next batch. StopIteration from next() is not
        # handled, so train_iter is presumably infinite; confirm in
        # data.get_dataloaders.
        if big_batch.shape[0] != cfg.training.batch_size:
            mprint(f"Skipping incomplete batch (size {big_batch.shape[0]}). Reshuffling...")
            continue

        B, L = big_batch.shape
        # Source-side batch with the same (B, L) shape as the data, cast to
        # big_batch's device/dtype. sample_limit presumably draws from the
        # graph's limiting (fully noised) distribution; confirm in graph_lib.
        sourcebatch = graph.sample_limit(big_batch.shape).to(big_batch)

        use_optimal_transport = cfg.training.use_optimal_transport

        # NOTE(review): `state['step'] > -1` looks always true for a
        # non-negative step counter; possibly a leftover warm-up gate.
        if use_optimal_transport and state['step'] > -1:
            # NOTE(review): steps_per_iter is not used in the code visible here.
            steps_per_iter = cfg.training.batch_size // cfg.training.diffusion_batch_size
            # Rebinds the outer `ema` name to the EMA object held in `state`.
            ema = state['ema']
            
            try:
                # Use EMA weights and eval mode for OT calculation
                # (the F.embedding lookups below do not depend on train/eval
                # mode; only the EMA weight swap affects the result)
                ema.store(score_model.parameters())
                ema.copy_to(score_model.parameters())
                score_model.eval()

                ot_metric = cfg.training.ot_metric 

                with torch.inference_mode():
                    t_all = torch.zeros(B, device=device)
                    # M[i, j] = -<emb(source_i), emb(data_j)>, flattened over
                    # positions and hidden dims.
                    M = torch.empty(B, B, dtype=torch.float32, device=device)
                    raw_model = score_model.module if isinstance(score_model, DDP) else score_model

                    # 1. Explicit Metric Validation and Weight Retrieval
                    if ot_metric == 'input_embedding_l2':
                        # Use vocab input embeddings
                        if hasattr(raw_model.vocab_embed, 'embedding'):
                            W = raw_model.vocab_embed.embedding
                        else:
                            W = raw_model.vocab_embed.weight
                    else:
                        raise NotImplementedError(f"OT metric '{ot_metric}' is not supported. Use 'input_embedding_l2'.")

                    # 2. Pre-compute Target Vectors
                    # (B, L) -> (B, L, Hidden) -> (B, L*Hidden)
                    target_emb_flat = F.embedding(big_batch, W).view(B, -1)

                    # 3. Compute Cost Matrix M in chunks
                    # Rows of M are filled `chunk` source sequences at a time to
                    # bound the size of each (chunk, L, Hidden) lookup.
                    chunk = cfg.training.chunksize
                    for s in range(0, B, chunk):
                        e = min(s + chunk, B)
                        piece = sourcebatch[s:e]
                        
                        # Always true here: other metrics raised above.
                        if ot_metric == 'input_embedding_l2':
                            # Direct embedding lookup for source tokens
                            model_out = F.embedding(piece, W)

                        # With a fixed set of rows and columns, sum ||a-b||^2 over
                        # an assignment equals constants minus 2 * sum <a, b>.
                        # So minimizing L2 is equivalent to maximizing the dot
                        # product, and M stores -dot for minimization.
                        h_flat = model_out.view(model_out.size(0), -1)
                        dot = torch.mm(h_flat, target_emb_flat.t())
                        M[s:e].copy_(-dot)
                        del model_out

                    # 4. Solve Optimal Transport (Exact)
                    # Uniform marginals: every source and data row carries mass 1/B.
                    a_np = np.full(B, 1.0 / B, dtype=np.float64)
                    b_np = np.full(B, 1.0 / B, dtype=np.float64)
                    # Costs moved to CPU float64 for POT, scaled by 1/L.
                    C_np = M.double().cpu().numpy() / L 

                    # Double-centering for numerical stability
                    # Under fixed marginals, subtracting per-row and per-column
                    # constants (and adding a global one) shifts every feasible
                    # plan's cost equally, so the optimal plan is unchanged.
                    row_mean = C_np.mean(axis=1, keepdims=True)
                    col_mean = C_np.mean(axis=0, keepdims=True)
                    global_mean = C_np.mean()
                    C_res = C_np - row_mean - col_mean + global_mean
                    
                    # Uniform shift so all costs are >= 0; again plan-invariant.
                    if C_res.min() < 0:
                        C_res -= C_res.min()

                    # With equal uniform marginals, an exact solution is a
                    # permutation scaled by 1/B. argmax per row recovers each
                    # source row's matched data row. NOTE(review): if numItermax
                    # is reached, POT warns and the plan may be sub-optimal.
                    P_np = ot.emd(a_np, b_np, C_res, numItermax=100000)
                    plan = torch.from_numpy(P_np).to(device=device, dtype=torch.float32)
                    col_idx = plan.argmax(dim=1)

                # Reorder data so databatch[i] is the row matched to sourcebatch[i].
                databatch = big_batch[col_idx]
                del M

            except NotImplementedError as e:
                # If the metric isn't implemented, we should stop training
                raise e
            except Exception as e:
                # For other runtime errors (CUDA OOM, etc.), fallback to random batching
                # NOTE(review): print() runs on every rank and bypasses mprint/logger.
                databatch = big_batch
                print(f'Error calculating OT: {e}')
            finally:
                # Always restore training weights and mode
                # NOTE(review): the bare except silently hides restore failures
                # (presumably e.g. when store() itself raised), which could leave
                # EMA weights in the live model.
                try:
                    ema.restore(score_model.parameters())
                except:
                    pass
                score_model.train()
        else:
            databatch = big_batch

        # Split the paired batches into consecutive diffusion_batch_size chunks;
        # trailing remainder rows are dropped by the floor division.
        diffusion_batch_size = cfg.training.diffusion_batch_size
        for diff_batch_no in range(databatch.shape[0]//diffusion_batch_size):
            step = state['step']
            diffusion_source_batch = sourcebatch[diff_batch_no*diffusion_batch_size:(diff_batch_no+1)*diffusion_batch_size]
            diffusion_data_batch = databatch[diff_batch_no*diffusion_batch_size:(diff_batch_no+1)*diffusion_batch_size]
            DiffB = diffusion_data_batch.shape[0]

            # Each chunk is cut into `accum` equal minibatches (remainder
            # dropped). train_step_fn presumably accumulates gradients and
            # advances state['step'] once per full accumulation; confirm in
            # losses.get_step_fn.
            for minibatch in range(cfg.training.accum):
                diffusion_source_minibatch = diffusion_source_batch[minibatch*(DiffB//cfg.training.accum):(minibatch+1)*(DiffB//cfg.training.accum)]
                diffusion_data_minibatch = diffusion_data_batch[minibatch*(DiffB//cfg.training.accum):(minibatch+1)*(DiffB//cfg.training.accum)]
                loss = train_step_fn(state, diffusion_source_minibatch, diffusion_data_minibatch)

                # flag to see if there was movement ie a full batch got computed
                # (`step` holds the pre-update value used for the checks below)
                if step != state['step']:
                    if step % cfg.training.log_freq == 0:
                        # Sum across ranks, then divide for the mean. Assumes all
                        # ranks advance state['step'] in lockstep; otherwise this
                        # collective would hang.
                        dist.all_reduce(loss)
                        loss /= world_size

                        mprint("step: %d, training_loss: %.5e" % (step, loss.item()))
                    
                    # Only rank 0 writes the preemption checkpoint.
                    if step % cfg.training.snapshot_freq_for_preemption == 0 and rank == 0:
                        utils.save_checkpoint(checkpoint_meta_dir, state)