import datetime
import os
import os.path
import gc
from itertools import chain

import numpy as np
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.nn.functional as F

import data
import losses
import sampling
import graph_lib
import noise_lib
import utils
from model import SEDD
from model.ema import ExponentialMovingAverage
from model import utils as mutils
from transformers import GPT2TokenizerFast, GPT2LMHeadModel
from load_model import load_model
import copy
from ipdb import set_trace as debug

# NOTE(review): this enables autograd anomaly detection for the whole process
# as soon as the module is imported. Per the PyTorch docs, anomaly mode records
# forward-pass tracebacks and raises on backward ops that produce NaN. That adds
# substantial overhead to every training step, so it is a debugging aid;
# consider disabling it for long production runs.
torch.autograd.set_detect_anomaly(True)

# torch.backends.cudnn.benchmark = True
# torch.autograd.set_detect_anomaly(True)

def multistep_generation(student, graph, noise, batch, with_data, steps, device, eps=1e-5, eps_probs=1e-12):
    """Noise ``batch`` and let ``student`` denoise it in one jump of size ``sigma``.

    A diffusion time ``t`` is drawn per sequence. ``batch`` is corrupted with
    ``graph.sample_transition`` at noise level ``sigma = noise(t)[0]``. The
    student's (exponentiated) score at that level is combined via
    ``graph.staggered_score`` and ``graph.transp_transition``, using the full
    ``sigma`` as the step. The result is clamped, floored and normalised into
    a distribution over the last dimension.

    Args:
        student: score network. It is wrapped with
            ``mutils.get_score_fn(..., train=True, sampling=False)``, so
            gradients flow through it.
        graph: diffusion graph exposing ``sample_transition``,
            ``staggered_score``, ``transp_transition`` and ``absorb``.
        noise: noise schedule; ``noise(t)[0]`` is used as sigma.
        batch: clean token ids. Assumed to be (batch, seq_len); TODO confirm.
        with_data: must be True. Building the noised input without data is
            not implemented.
        steps: if negative, ``t ~ U(1e-3, 1)``. If positive, ``t`` is drawn
            uniformly from the first ``steps`` points of
            ``linspace(1, eps, steps + 1)``. Zero is rejected.
        device: device on which the times are sampled.
        eps: final time of the discrete grid (used only when ``steps > 0``).
        eps_probs: floor added before renormalising so no entry is exactly 0.

    Returns:
        Probabilities normalised over the last dimension. When
        ``graph.absorb`` is set, the last entry of that dimension
        (presumably the mask token) is dropped first.

    Raises:
        NotImplementedError: if ``with_data`` is False.
        ValueError: if ``steps == 0`` (the time grid would be empty).
    """
    if not with_data:
        # The data-free path (rolling the student out from pure noise) was
        # never implemented. Raise explicitly rather than via ``assert``,
        # which ``python -O`` strips and would turn into a NameError below.
        raise NotImplementedError(
            "multistep_generation currently requires with_data=True")
    if steps == 0:
        raise ValueError(
            "steps must be negative (continuous t) or a positive grid size, got 0")

    if steps < 0:
        # Continuous time, bounded away from 0.
        t = (1 - 1e-3) * torch.rand(batch.shape[0], device=device) + 1e-3
    else:
        # Discrete grid; the final point ``eps`` is never selected.
        timesteps = torch.linspace(1, eps, steps + 1, device=device)
        curr_steps = torch.randint(0, steps, (batch.shape[0],), device=device, dtype=torch.long)
        t = timesteps[curr_steps]

    sigma = noise(t)[0]
    perturbed_batch = graph.sample_transition(batch, sigma[:, None])

    log_score_fn = mutils.get_score_fn(student, train=True, sampling=False)
    score = log_score_fn(perturbed_batch, sigma).exp()
    stag_score = graph.staggered_score(score, sigma[:, None])
    probs = stag_score * graph.transp_transition(perturbed_batch, sigma)
    # Truncate the absorbing-state column before normalising.
    if graph.absorb:
        probs = probs[..., :-1]
    probs = probs.clamp(min=0.0) + eps_probs
    return probs / probs.sum(dim=-1, keepdim=True)


def setup(rank, world_size, port):
    """Join the NCCL default process group for this worker.

    Rendezvous uses the default ``env://`` method on ``localhost:<port>``.
    Collective operations time out after 30 minutes.
    """
    os.environ.update(MASTER_ADDR="localhost", MASTER_PORT=str(port))

    group_timeout = datetime.timedelta(minutes=30)
    dist.init_process_group(
        backend="nccl",
        rank=rank,
        world_size=world_size,
        timeout=group_timeout,
    )


def cleanup():
    """Destroy the default process group if one is currently initialized.

    ``run_multiprocess`` calls this from a ``finally`` clause, so it also runs
    when ``setup`` itself failed. Skipping the teardown in that case keeps the
    original error from being masked by a second one raised here.
    """
    if dist.is_initialized():
        dist.destroy_process_group()


def run_multiprocess(rank, world_size, cfg, port):
    """Worker entry point: join the process group, train, and always tear down.

    Args:
        rank: this worker's rank (``_run`` also uses it as the CUDA device index).
        world_size: total number of workers.
        cfg: experiment config forwarded to ``_run``.
        port: TCP port used for the localhost rendezvous.
    """
    try:
        setup(rank=rank, world_size=world_size, port=port)
        _run(rank=rank, world_size=world_size, cfg=cfg)
    finally:
        # Executed on success and on failure so the process group is released.
        cleanup()


def _next_input_ids(data_iter, dataset_name, device):
    """Return the next batch of token ids from ``data_iter``, moved to ``device``.

    For ``dataset_name == "text8"`` the iterator yields tensors directly. For
    every other dataset it yields a mapping with an ``'input_ids'`` entry.
    """
    batch = next(data_iter)
    if dataset_name != "text8":
        batch = batch['input_ids']
    return batch.to(device)


def _all_reduce_mean(value, world_size):
    """Average the tensor ``value`` over all ranks in place and return it."""
    dist.all_reduce(value)
    value /= world_size
    return value


def _write_samples(file_name, sentences):
    """Write decoded samples to ``file_name``, each followed by a separator line."""
    # Explicit UTF-8: decoded GPT-2 text can contain characters the platform
    # default encoding cannot represent.
    with open(file_name, 'w', encoding='utf-8') as file:
        for sentence in sentences:
            file.write(sentence + "\n")
            file.write("============================================================================================\n")


def _generative_perplexity(sample, batch_size, device):
    """Score ``sample`` with GPT-2 large and return this rank's mean perplexity.

    ``sample`` is split into ``sample.shape[0] // batch_size`` chunks; any
    remainder rows are ignored. Per-sequence perplexity is the exp of the
    mean next-token cross-entropy. It is averaged within each chunk and then
    across chunks. The scoring model is loaded for this call only and its
    reference is dropped afterwards.
    """
    with torch.no_grad():
        eval_model = GPT2LMHeadModel.from_pretrained("gpt2-large").to(device).eval()
        num_batches = sample.shape[0] // batch_size
        total_perplexity = torch.zeros((), device=device)
        for i in range(num_batches):
            s = sample[i * batch_size:(i + 1) * batch_size]
            # cross_entropy wants the class (vocab) dimension second.
            logits = eval_model(s).logits.transpose(-1, -2)
            perplexity = F.cross_entropy(
                logits[..., :-1], s[..., 1:], reduction="none"
            ).mean(dim=-1).exp().mean()
            total_perplexity += perplexity
        del eval_model
    return total_perplexity / num_batches


def _sample_entropy(sample, device):
    """Return the unigram token entropy (nats) per sequence, averaged over ``sample``."""
    total_entropy = torch.tensor(0.0, device=device)
    for seq in sample:
        _, counts = torch.unique(seq, return_counts=True, sorted=False)
        total_entropy += torch.special.entr(counts.float() / counts.sum()).sum().item()
    return total_entropy / sample.shape[0]


def _run(rank, world_size, cfg):
    """Per-rank training / distillation loop.

    Depending on ``cfg.distill.is_distill``, this builds one of two setups:

    * a distillation setup: a teacher loaded from ``cfg.distill.distill_model``,
      plus student and fake-score copies, each with its own EMA, optimizer and
      grad scaler;
    * a fresh ``SEDD`` model with a single optimizer.

    It then restores ``<work_dir>/checkpoints-meta/checkpoint.pth`` if present
    and trains until ``cfg.training.n_iters``. Along the way it logs
    rank-averaged losses, writes checkpoints, evaluates on the validation
    split, and optionally samples with the EMA weights to report GPT-2
    perplexity and unigram entropy.

    Args:
        rank: index of this process; also used as its CUDA device.
        world_size: number of processes, used to average metrics across ranks.
        cfg: experiment configuration.
    """
    torch.cuda.set_device(rank)
    work_dir = cfg.work_dir
    is_distill = cfg.distill.is_distill

    # Create directories for experimental logs
    sample_dir = os.path.join(work_dir, "samples")
    checkpoint_dir = os.path.join(work_dir, "checkpoints")
    checkpoint_meta_dir = os.path.join(work_dir, "checkpoints-meta", "checkpoint.pth")
    if rank == 0:
        utils.makedirs(sample_dir)
        utils.makedirs(checkpoint_dir)
        utils.makedirs(os.path.dirname(checkpoint_meta_dir))

    # logging: only rank 0 creates a logger; mprint is a no-op on other ranks.
    if rank == 0:
        logger = utils.get_logger(os.path.join(work_dir, "logs"))

    def mprint(msg):
        if rank == 0:
            logger.info(msg)

    mprint(work_dir)
    mprint(cfg)
    device = torch.device(f"cuda:{rank}" if torch.cuda.is_available() else "cpu")
    if device.type == "cuda":
        mprint("Found {} CUDA devices.".format(torch.cuda.device_count()))
        for i in range(torch.cuda.device_count()):
            props = torch.cuda.get_device_properties(i)
            mprint(
                "{} \t Memory: {:.2f}GB".format(
                    props.name, props.total_memory / (1024 ** 3)
                )
            )
    else:
        mprint("WARNING: Using device {}".format(device))
    mprint(f"Found {os.cpu_count()} total number of CPUs.")

    if is_distill:
        teacher_model, graph, noise = load_model(cfg.distill.distill_model, device)
        fake_model, student = copy.deepcopy(teacher_model), copy.deepcopy(teacher_model)

        teacher_model = DDP(teacher_model, device_ids=[rank], static_graph=True, find_unused_parameters=True)
        fake_model = DDP(fake_model, device_ids=[rank], static_graph=True, find_unused_parameters=True)
        student = DDP(student, device_ids=[rank], static_graph=True, find_unused_parameters=True)

        mutils.freeze_model(teacher_model)

        num_parameters_student = sum(p.numel() for p in student.parameters())
        mprint(f"Number of parameters in the student: {num_parameters_student}")

        num_parameters_fake = sum(p.numel() for p in fake_model.parameters())
        mprint(f"Number of parameters in the fake model: {num_parameters_fake}")

        ema_student = ExponentialMovingAverage(
            student.parameters(), decay=cfg.training.ema)

        ema_fake = ExponentialMovingAverage(
            fake_model.parameters(), decay=cfg.training.ema)

        mprint(student)
        mprint(f"EMA student: {ema_student}")

        mprint(fake_model)
        mprint(f"EMA fake: {ema_fake}")

        noise = DDP(noise, device_ids=[rank], static_graph=True)

        # The student optimizer also updates the noise schedule's parameters.
        optimizer_student = losses.get_optimizer(cfg, chain(student.parameters(), noise.parameters()))
        mprint(f"Optimizer student: {optimizer_student}")
        optimizer_fake = losses.get_optimizer(cfg, chain(fake_model.parameters()))
        mprint(f"Optimizer fake: {optimizer_fake}")
        scaler_student = torch.cuda.amp.GradScaler()
        scaler_fake = torch.cuda.amp.GradScaler()
        mprint(f"Scaler student: {scaler_student}")
        mprint(f"Scaler fake: {scaler_fake}")
        state = dict(optimizer_student=optimizer_student, optimizer_fake=optimizer_fake,
                     scaler_student=scaler_student, scaler_fake=scaler_fake, student=student,
                     fake_model=fake_model, teacher_model=teacher_model,
                     noise=noise, ema_student=ema_student, ema_fake=ema_fake, step=0)
    else:
        graph = graph_lib.get_graph(cfg, device)

        score_model = SEDD(cfg).to(device)
        score_model = DDP(score_model, device_ids=[rank], static_graph=True, find_unused_parameters=True)

        num_parameters = sum(p.numel() for p in score_model.parameters())
        mprint(f"Number of parameters in the model: {num_parameters}")

        ema = ExponentialMovingAverage(
            score_model.parameters(), decay=cfg.training.ema)
        mprint(score_model)
        mprint(f"EMA: {ema}")

        # build noise
        noise = noise_lib.get_noise(cfg).to(device)
        noise = DDP(noise, device_ids=[rank], static_graph=True)

        # build optimization state
        optimizer = losses.get_optimizer(cfg, chain(score_model.parameters(), noise.parameters()))
        mprint(f"Optimizer: {optimizer}")
        scaler = torch.cuda.amp.GradScaler()
        mprint(f"Scaler: {scaler}")
        state = dict(optimizer=optimizer, scaler=scaler, model=score_model, noise=noise, ema=ema, step=0)

    # load in state
    state = utils.restore_checkpoint(checkpoint_meta_dir, state, device, is_distill)
    initial_step = int(state['step'])

    # load in tokenizer
    tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')

    # Build data iterators
    train_ds, eval_ds = data.get_dataloaders(cfg)
    train_iter = iter(train_ds)
    eval_iter = iter(eval_ds)

    # Build one-step training and evaluation functions; they share every
    # setting except the train flag.
    optimize_fn = losses.optimization_manager(cfg)
    step_fn_kwargs = dict(
        is_distill=is_distill,
        num_fake_steps=cfg.distill.num_fake_steps,
        reinforce_coef=cfg.distill.regularization.reinforce_coef,
        entropy_regularization_coef=cfg.distill.regularization.entropy_coef,
        forward_kl_coef=cfg.distill.regularization.forward_kl_coef,
        gumbel_softmax_relaxation=cfg.distill.gumbel_softmax_relaxation,
    )
    train_step_fn = losses.get_step_fn(noise, graph, True, optimize_fn, cfg.training.accum, **step_fn_kwargs)
    eval_step_fn = losses.get_step_fn(noise, graph, False, optimize_fn, cfg.training.accum, **step_fn_kwargs)

    sampling_eps = 1e-5
    if cfg.training.snapshot_sampling:
        sampling_shape = (cfg.training.batch_size // (cfg.ngpus * cfg.training.accum), cfg.model.length)
        # Guarantees _generative_perplexity sees at least one full chunk.
        assert (cfg.training.batch_size // (cfg.ngpus * cfg.training.accum)) // cfg.eval.perplexity_batch_size > 0
        sampling_fn = sampling.get_sampling_fn(cfg, graph, noise, sampling_shape, sampling_eps, device)

    def student_probs(clean_batch):
        # Only called in distillation mode, where ``student`` is bound.
        return multistep_generation(
            student, graph, noise, clean_batch,
            cfg.distill.multistep.with_data,
            cfg.distill.multistep.steps, device,
            eps=sampling_eps)

    num_train_steps = cfg.training.n_iters
    mprint(f"Starting training loop at step {initial_step}.")

    while state['step'] < num_train_steps + 1:
        step = state['step']

        batch = _next_input_ids(train_iter, cfg.data.train, device)
        if is_distill:
            loss_student, loss_fake = train_step_fn(state, student_probs(batch), batch)
        else:
            loss = train_step_fn(state, batch)

        # The step counter only advances once a full (accumulated) batch has
        # been applied; logging / eval / checkpointing happen only then.
        if step == state['step']:
            continue

        if step % cfg.training.log_freq == 0:
            if is_distill:
                loss_student = _all_reduce_mean(loss_student, world_size)
                loss_fake = _all_reduce_mean(loss_fake, world_size)
                mprint("step: %d, training_loss_student: %.5e" % (step, loss_student.item()))
                mprint("step: %d, training_loss_fake: %.5e" % (step, loss_fake.item()))
            else:
                loss = _all_reduce_mean(loss, world_size)
                mprint("step: %d, training_loss: %.5e" % (step, loss.item()))

        if step % cfg.training.snapshot_freq_for_preemption == 0 and rank == 0:
            utils.save_checkpoint(checkpoint_meta_dir, state, is_distill=is_distill)

        if step % cfg.training.eval_freq == 0:
            # Always evaluate on the held-out split (text8 included).
            eval_batch = _next_input_ids(eval_iter, cfg.data.valid, device)
            if is_distill:
                with torch.no_grad():
                    eval_probs = student_probs(eval_batch)
                eval_loss_student, eval_loss_fake = eval_step_fn(state, eval_probs, eval_batch)
                eval_loss_student = _all_reduce_mean(eval_loss_student, world_size)
                eval_loss_fake = _all_reduce_mean(eval_loss_fake, world_size)
                mprint("step: %d, evaluation_loss_student: %.5e" % (step, eval_loss_student.item()))
                mprint("step: %d, evaluation_loss_fake: %.5e" % (step, eval_loss_fake.item()))
            else:
                eval_loss = _all_reduce_mean(eval_step_fn(state, eval_batch), world_size)
                mprint("step: %d, evaluation_loss: %.5e" % (step, eval_loss.item()))

        if step % cfg.training.snapshot_freq == 0 or step == num_train_steps:
            # Save the checkpoint.
            save_step = step // cfg.training.snapshot_freq
            if rank == 0:
                utils.save_checkpoint(os.path.join(
                    checkpoint_dir, f'checkpoint_{save_step}.pth'), state, is_distill=is_distill)

            # Generate and save samples
            if cfg.training.snapshot_sampling:
                mprint(f"Generating text at step: {step}")

                this_sample_dir = os.path.join(sample_dir, "iter_{}".format(step))
                utils.makedirs(this_sample_dir)

                # Sample with the EMA weights, then restore the live weights.
                if is_distill:
                    sample_model, sample_ema = student, ema_student
                else:
                    sample_model, sample_ema = score_model, ema
                sample_ema.store(sample_model.parameters())
                sample_ema.copy_to(sample_model.parameters())
                sample = sampling_fn(sample_model)
                sample_ema.restore(sample_model.parameters())

                sentences = tokenizer.batch_decode(sample)
                _write_samples(os.path.join(this_sample_dir, f"sample_{rank}.txt"), sentences)

                if cfg.eval.perplexity:
                    total_perplexity = _all_reduce_mean(
                        _generative_perplexity(sample, cfg.eval.perplexity_batch_size, device),
                        world_size)
                    mprint(f"Generative Perplexity at step: {step}. Perplexity: {total_perplexity:.3f}.")

                if cfg.eval.entropy:
                    total_entropy = _all_reduce_mean(_sample_entropy(sample, device), world_size)
                    mprint(f"Entropy at step: {step}. Entropy: {total_entropy:.3f}.")

                dist.barrier()
