import datetime
import os
import os.path
import gc
from itertools import chain

import numpy as np
import torch
import wandb
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.nn.functional as F

import data
import losses
import sampling
import graph_lib
import noise_lib
import utils
from model import SEDDWotSM,SEDD,SEDDWot,SEDDSM
from model.ema import ExponentialMovingAverage
from transformers import GPT2TokenizerFast, GPT2LMHeadModel

from omegaconf import OmegaConf
import torch.nn as nn

# Enable cuDNN's benchmark mode (kernel selection per input shape); see the
# torch.backends.cudnn documentation.
torch.backends.cudnn.benchmark = True


# Debug toggle: uncomment to turn on autograd anomaly detection.
# torch.autograd.set_detect_anomaly(True)
class CustomDDP(nn.parallel.DistributedDataParallel):
    """``DistributedDataParallel`` that also exposes the wrapped module's attributes.

    Attribute lookups that DDP itself cannot satisfy fall through to
    ``self.module``, so callers can write ``ddp_model.some_attr`` instead of
    ``ddp_model.module.some_attr``.

    ``forward`` is intentionally inherited from DDP. A previous override called
    ``self.module`` directly, which skipped DDP's forward bookkeeping
    (``reducer.prepare_for_backward``). Without that step the reducer's autograd
    hooks return early, so gradients were not synchronised across ranks and
    every rank silently trained its own replica.
    """

    def __getattr__(self, name):
        try:
            return super().__getattr__(name)
        except AttributeError:
            # If ``module`` itself is missing (a partially-constructed instance),
            # falling back to ``self.module`` would recurse forever; re-raise.
            if name == "module":
                raise
            return getattr(self.module, name)

def cleanup():
    """Destroy the default process group if one is active.

    ``run_multiprocess`` calls this from a ``finally`` clause. If
    ``init_process_group`` itself failed, there is no group to destroy, and
    calling ``destroy_process_group`` unconditionally would raise a second
    error that masks the original one.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


def run_multiprocess(cfg):
    """Run one training worker inside an NCCL process group.

    Joins the default process group, then reads this worker's placement from
    the ``LOCAL_RANK``, ``RANK`` and ``WORLD_SIZE`` environment variables
    (expected to be set by the launcher, e.g. torchrun) and hands off to
    ``_run``. ``cleanup`` runs whether training returns or raises.
    """
    try:
        dist.init_process_group(backend="nccl")
        env = os.environ
        local_rank = int(env["LOCAL_RANK"])
        global_rank = int(env["RANK"])
        n_workers = int(env["WORLD_SIZE"])
        _run(
            local_rank=local_rank,
            gloab_rank=global_rank,
            world_size=n_workers,
            cfg=cfg,
        )
    finally:
        cleanup()


def _build_score_model(cfg, device):
    """Instantiate the SEDD variant selected by ``cfg.model`` and move it to ``device``.

    ``remove_time_condition`` selects the ``SEDDWot*`` classes and
    ``add_softmax`` selects the ``*SM`` classes. Both flags are compared with
    ``== False`` exactly as before, so any value not equal to ``False`` picks
    the second option.
    """
    keep_time = cfg.model.remove_time_condition == False  # noqa: E712
    no_softmax = cfg.model.add_softmax == False  # noqa: E712
    if keep_time:
        model_cls = SEDD if no_softmax else SEDDSM
    else:
        model_cls = SEDDWot if no_softmax else SEDDWotSM
    return model_cls(cfg).to(device)


def _generative_perplexity(sample, cfg, device, world_size):
    """Return the GPT-2 perplexity of ``sample``, averaged over chunks and ranks.

    ``sample`` is scored in chunks of ``cfg.eval.perplexity_batch_size`` rows
    with the pretrained ``GPT2LMHeadModel`` loaded from ``cfg.gpt_dir``. Rows
    that do not fill a whole chunk are ignored. The value is the mean of
    per-sequence ``exp(mean next-token NLL)``, all-reduced and divided by
    ``world_size``, so every rank must call this collectively.

    Returns ``None`` when not even one full chunk is available. Previously that
    case crashed with ``ZeroDivisionError``.
    """
    chunk = cfg.eval.perplexity_batch_size
    n_chunks = sample.shape[0] // chunk
    if n_chunks == 0:
        return None
    with torch.no_grad():
        eval_model = GPT2LMHeadModel.from_pretrained(cfg.gpt_dir).to(device).eval()
        total_perplexity = 0
        for i in range(n_chunks):
            s = sample[i * chunk:(i + 1) * chunk]
            # Move the vocabulary axis next to the batch axis, as F.cross_entropy expects.
            logits = eval_model(s).logits.transpose(-1, -2)
            # Next-token prediction: logits at position t are scored against token t+1.
            nll = F.cross_entropy(logits[..., :-1], s[..., 1:], reduction="none")
            total_perplexity += nll.mean(dim=-1).exp().mean()
        total_perplexity /= n_chunks
        dist.all_reduce(total_perplexity)
        total_perplexity /= world_size
        del eval_model
    return total_perplexity


def _run(local_rank, gloab_rank, world_size, cfg):
    """Train a SEDD score model on one data-parallel rank.

    Args:
        local_rank: CUDA device index for this process on its node.
        gloab_rank: global rank. Rank 0 owns directory creation, logging,
            wandb and checkpoint writing.
        world_size: total number of ranks; all-reduced metrics are divided by it.
        cfg: OmegaConf training configuration.

    Expects the default process group to be initialised already (see
    ``run_multiprocess``): metrics are combined with ``dist.all_reduce``.
    """
    torch.cuda.set_device(local_rank)
    work_dir = cfg.work_dir

    # Create directories for experimental logs
    sample_dir = os.path.join(work_dir, "samples")
    checkpoint_dir = os.path.join(work_dir, "checkpoints")
    # NOTE: despite the name, this is the path of the rolling preemption checkpoint file.
    checkpoint_meta_dir = os.path.join(work_dir, "checkpoints-meta", "checkpoint.pth")
    if gloab_rank == 0:
        utils.makedirs(sample_dir)
        utils.makedirs(checkpoint_dir)
        utils.makedirs(os.path.dirname(checkpoint_meta_dir))
        wandb.init(dir=os.path.abspath(work_dir), project='sedd', config=OmegaConf.to_container(cfg, resolve=True),
                   name=cfg.wandb_name, job_type='train')
        logger = utils.get_logger(os.path.join(work_dir, "logs"))

    def mprint(msg):
        """Log ``msg`` on rank 0 only; other ranks have no logger."""
        if gloab_rank == 0:
            logger.info(msg)

    mprint(work_dir)
    mprint(cfg)
    device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
    if device.type == "cuda":
        mprint("Found {} CUDA devices.".format(torch.cuda.device_count()))
        for i in range(torch.cuda.device_count()):
            props = torch.cuda.get_device_properties(i)
            mprint(
                "{} \t Memory: {:.2f}GB".format(
                    props.name, props.total_memory / (1024 ** 3)
                )
            )
    else:
        mprint("WARNING: Using device {}".format(device))
    mprint(f"Found {os.cpu_count()} total number of CPUs.")

    # build token graph
    graph = graph_lib.get_graph(cfg, device)

    # build score model and wrap it for data-parallel training
    score_model = _build_score_model(cfg, device)
    score_model = CustomDDP(score_model, device_ids=[local_rank], static_graph=True, find_unused_parameters=True)

    num_parameters = sum(p.numel() for p in score_model.parameters())
    mprint(f"Number of parameters in the model: {num_parameters}")

    ema = ExponentialMovingAverage(
        score_model.parameters(), decay=cfg.training.ema)
    mprint(score_model)
    mprint(f"EMA: {ema}")

    # build noise
    noise = noise_lib.get_noise(cfg).to(device)
    noise = DDP(noise, device_ids=[local_rank], static_graph=True)
    sampling_eps = 1e-5

    # build optimization state; the noise module's parameters are optimised too
    optimizer = losses.get_optimizer(cfg, chain(score_model.parameters(), noise.parameters()))
    mprint(f"Optimizer: {optimizer}")
    scaler = torch.cuda.amp.GradScaler()
    mprint(f"Scaler: {scaler}")
    state = dict(optimizer=optimizer, scaler=scaler, model=score_model, noise=noise, ema=ema, step=0)

    # load in state (utils.restore_checkpoint may replace it, e.g. when resuming)
    state = utils.restore_checkpoint(checkpoint_meta_dir, state, device)
    initial_step = int(state['step'])

    # load in tokenizer
    tokenizer = GPT2TokenizerFast.from_pretrained(cfg.gpt_dir)
    # Build data iterators
    train_ds, eval_ds = data.get_dataloaders(cfg)

    train_iter = iter(train_ds)
    eval_iter = iter(eval_ds)

    # Build one-step training and evaluation functions
    optimize_fn = losses.optimization_manager(cfg)
    if cfg.training.loss_type == 'cross_entropy' and cfg.model.add_softmax != True:  # noqa: E712
        # Explicit raise rather than ``assert`` so the check survives ``python -O``.
        raise ValueError("training.loss_type='cross_entropy' requires model.add_softmax=True")

    train_step_fn = losses.get_step_fn(noise, graph, True, optimize_fn, cfg.training.accum, cfg.training.loss_type)
    eval_step_fn = losses.get_step_fn(noise, graph, False, optimize_fn, cfg.training.accum, cfg.training.loss_type)

    if cfg.training.snapshot_sampling:
        # Each rank draws batch_size // (2 * ngpus * accum) sequences per snapshot.
        sampling_shape = (cfg.training.batch_size // (2 * cfg.ngpus * cfg.training.accum), cfg.model.length)
        sampling_fn = sampling.get_sampling_fn(cfg, graph, noise, sampling_shape, sampling_eps, device)

    num_train_steps = cfg.training.n_iters
    mprint(f"Starting training loop at step {initial_step}.")

    def log_and_wandb(matric, step, name="train_loss"):
        """All-reduce ``matric`` in place, average it over ranks, and log it from rank 0."""
        dist.all_reduce(matric)
        matric /= world_size

        mprint(f"step: %d, {name}: %.5e" % (step, matric.item()))
        if gloab_rank == 0:
            wandb.log({name: matric.item()}, step=step)

    while state['step'] < num_train_steps + 1:
        step = state['step']

        if cfg.train_set.name != "text8":
            batch = next(train_iter)['input_ids'].to(device)
        else:
            batch = next(train_iter).to(device)
        loss = train_step_fn(state, batch)

        # state['step'] only advances once a full (accumulated) batch has been
        # applied; all periodic work below is keyed off that movement.
        if step != state['step']:
            if step % cfg.training.log_freq == 0:
                log_and_wandb(loss, step, 'train_loss')

            if step % cfg.training.snapshot_freq_for_preemption == 0 and gloab_rank == 0:
                utils.save_checkpoint(checkpoint_meta_dir, state)

            if step % cfg.training.eval_freq == 0:
                if cfg.valid_set.name != "text8":
                    eval_batch = next(eval_iter)['input_ids'].to(device)
                else:
                    # Fix: this branch used to pull from train_iter, evaluating on
                    # (and silently consuming) training data.
                    eval_batch = next(eval_iter).to(device)
                eval_loss = eval_step_fn(state, eval_batch)

                dist.all_reduce(eval_loss)
                eval_loss /= world_size

                mprint("step: %d, evaluation_loss: %.5e" % (step, eval_loss.item()))
                if gloab_rank == 0:
                    wandb.log({'eval_loss': eval_loss.item()}, step=step)

            if (step > 0 and step % cfg.training.snapshot_freq == 0) or step == num_train_steps:
                # Save the checkpoint.
                if gloab_rank == 0:
                    utils.save_checkpoint(os.path.join(
                        checkpoint_dir, f'checkpoint_{step}.pth'), state)

                # Generate and save samples
                if cfg.training.snapshot_sampling:
                    # Release the training tensors and cached GPU memory before sampling.
                    del loss, batch
                    torch.cuda.empty_cache()
                    mprint(f"Generating text at step: {step}")

                    this_sample_dir = os.path.join(sample_dir, "iter_{}".format(step))
                    # NOTE(review): every rank calls this; assumes utils.makedirs
                    # tolerates an already-existing directory -- confirm.
                    utils.makedirs(this_sample_dir)

                    # Sample with the EMA copy of the weights, then restore the stored ones.
                    ema.store(score_model.parameters())
                    ema.copy_to(score_model.parameters())
                    sample = sampling_fn(score_model)
                    ema.restore(score_model.parameters())

                    sentences = tokenizer.batch_decode(sample)

                    file_name = os.path.join(this_sample_dir, f"sample_{gloab_rank}.txt")
                    # Explicit UTF-8: decoded text may contain non-ASCII characters.
                    with open(file_name, 'w', encoding="utf-8") as file:
                        for sentence in sentences:
                            file.write(sentence + "\n")
                            file.write("============================================================================================\n")

                    if cfg.eval.perplexity:
                        total_perplexity = _generative_perplexity(sample, cfg, device, world_size)
                        if total_perplexity is None:
                            mprint(f"Skipping generative perplexity at step: {step}; fewer than "
                                   f"{cfg.eval.perplexity_batch_size} samples per rank.")
                        else:
                            perplexity = total_perplexity.item()
                            mprint(f"Generative Perplexity at step: {step}. Perplexity: {perplexity:.3f}.")
                            if gloab_rank == 0:
                                wandb.log({'GPT2 perplexity': perplexity}, step=step)

                    dist.barrier()
