import datetime
import os
import yaml
import time
import math
import torch
import numpy as np
import seaborn as sns
from torch.utils.data import DataLoader, RandomSampler, BatchSampler
import matplotlib.pyplot as plt
from loguru import logger
import psutil
from cfg import build_cfg,save_all_hparams
import torch.nn.functional as F
from contextlib import nullcontext
from model import DaoConfig, GeST
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group,all_reduce
from dataset import MultiSpatialTarget,SpatialTarget
import scanpy as sc
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data.distributed import DistributedSampler

import wandb


def collate_fn(batch):
    """Collate variable-length samples into zero-padded batch tensors.

    Args:
        batch: sequence of ``(exp, coord, mask)`` tensor triples, one per sample.

    Returns:
        A 3-tuple ``(padded_exp, padded_coord, padded_mask)``. Each entry is the
        result of ``pad_sequence(..., batch_first=True)`` over the matching
        field, i.e. every sample is zero-padded to the longest one in the batch.
    """
    # Transpose the list of triples into three per-field groups, then pad each.
    per_field = zip(*batch)
    exp_seqs, coord_seqs, mask_seqs = per_field
    return tuple(
        pad_sequence(seqs, batch_first=True)
        for seqs in (exp_seqs, coord_seqs, mask_seqs)
    )


def _fit_to_block_size(exp, coord, mask, block_size):
    """Crop or zero-pad a batch along dim 1 so its length equals ``block_size``.

    ``exp`` and ``coord`` are padded on their second-to-last dimension and
    ``mask`` on its last dimension, all with zeros.

    NOTE(review): as in the original code, the padded mask is converted to
    ``bool`` while a cropped mask keeps its incoming dtype -- confirm whether
    the model relies on that difference before unifying it.
    """
    if exp.shape[1] > block_size:
        return exp[:, :block_size], coord[:, :block_size], mask[:, :block_size]
    exp = F.pad(exp, (0, 0, 0, block_size - exp.shape[1]), "constant", 0)
    coord = F.pad(coord, (0, 0, 0, block_size - coord.shape[1]), "constant", 0)
    mask = F.pad(mask, (0, block_size - mask.shape[1]), "constant", 0).bool()
    return exp, coord, mask


def run():
    """Train a GeST model, optionally under DistributedDataParallel.

    Steps:
        1. build the config and set up logging / output directories
        2. initialise DDP (when ``RANK`` is set in the environment) and seed torch
        3. load the optional codebook and the train / validation datasets
        4. build the model and load pretrained or resumed weights
        5. create the optimizer, grad scaler and (optional) LR scheduler
        6. run the training loop with periodic validation and checkpointing

    Raises:
        ValueError: if ``cfg.model`` does not name a GeST model, or if the
            gradient accumulation steps are not divisible by the DDP world size.
    """
    cfg = build_cfg()
    # The script also has a CPU code path (see ``ctx`` below); only touch CUDA
    # when it is actually present.
    if torch.cuda.is_available():
        torch.cuda.init()

    # saving and logging
    timestamp = time.strftime("%Y-%m-%d", time.localtime())
    out_dir = os.path.join(cfg.output_dir, cfg.comment+'_'+timestamp)
    logger.add(os.path.join(out_dir, 'dao_train.log'), level='INFO')
    if 'LOCAL_RANK' not in os.environ or os.environ['LOCAL_RANK'] == '0':
        logger.info("Build cfg success!")
        logger.info(os.path.join(out_dir, 'dao_train.log'))
        logger.info("Saving to {}".format(out_dir))
        logger.info(f"Config:{cfg}")
    ckpt_save_dir = os.path.join(out_dir, 'ckpt')

    # -----------------------------------------------------------------------------
    # unpack the config values used below
    backend = cfg.backend
    compile_model = cfg.compile  # use torch.compile (PyTorch 2.0+) when truthy
    gradient_accumulation_steps = cfg.gradient_accumulation_steps
    batch_size = cfg.batch_size
    block_size = cfg.block_size
    task = cfg.task
    n_layer = cfg.n_layer
    n_head = cfg.n_head
    n_embd = cfg.n_embd
    bias = cfg.bias
    dropout = cfg.dropout
    train_mode = cfg.train_mode
    init_from = cfg.init_from
    device_type = cfg.device
    dtype = cfg.dtype
    data_path = cfg.data_path
    vocab_size = cfg.vocab_size
    loss_len = cfg.loss_len
    modeltype = cfg.model
    # -----------------------------------------------------------------------------

    ddp = int(os.environ.get('RANK', -1)) != -1  # is this a ddp run?
    if ddp:
        init_process_group(backend=backend, timeout=datetime.timedelta(seconds=2000))
        ddp_rank = int(os.environ['RANK'])
        ddp_local_rank = int(os.environ['LOCAL_RANK'])
        ddp_world_size = int(os.environ['WORLD_SIZE'])
        device = f'cuda:{ddp_local_rank}'
        torch.cuda.set_device(device)
        master_process = ddp_rank == 0  # this process does logging, checkpointing etc.
        # The accumulation steps are shared between the participating
        # processes, so divide by the world size rather than by the number of
        # GPUs visible on this node.
        if gradient_accumulation_steps % ddp_world_size != 0:
            raise ValueError(
                f"gradient_accumulation_steps ({gradient_accumulation_steps}) must be "
                f"divisible by the DDP world size ({ddp_world_size})"
            )
        gradient_accumulation_steps //= ddp_world_size
        torch.manual_seed(19491001 + ddp_rank)  # each process gets a different seed
    else:
        ddp_world_size = 1
        master_process = True
        torch.manual_seed(19491001)

    setattr(cfg, 'master_process', master_process)

    # Optional codebook for the 'quantize' task: bundle the codebook with a
    # nearest-neighbour index over the meta-cell PCA, the mean and the PCs.
    codebook = None
    if cfg.codebook != 'none' and cfg.task == 'quantize':
        codebook = np.load(cfg.codebook)
        import pandas as pd
        from sklearn.neighbors import NearestNeighbors
        basedir = cfg.codebook.split('codebook')[0]
        pca = pd.read_csv(f'{basedir}meta_cells_pca.csv', index_col=0).values
        nbrs = NearestNeighbors(n_neighbors=1, algorithm='auto').fit(pca)
        pcaproj = np.load(f'{basedir}PCs.npy')
        mean = np.load(f'{basedir}mean.npy')
        if master_process:
            logger.info(f'codebook loaded from {cfg.codebook} with pca projection and mean shape {codebook.shape} other loaded from {basedir}')
        codebook = ((nbrs, mean, pcaproj), codebook)

    # data_path is either a directory holding train.h5ad / valid.h5ad or a YAML
    # file with 'train' and 'valid' entries.
    if os.path.isdir(data_path):
        adata = sc.read_h5ad(os.path.join(data_path, 'train.h5ad'))
        val_celladata = sc.read_h5ad(os.path.join(data_path, 'valid.h5ad'))
        cell_ds = SpatialTarget(adata, cfg=cfg, shuffle=cfg.idxshuffle, codebook=codebook)
        val_cell_ds = SpatialTarget(val_celladata, cfg=cfg, shuffle=cfg.idxshuffle, codebook=codebook)
    else:
        with open(data_path, 'r') as f:
            d = yaml.safe_load(f)
        cell_ds = MultiSpatialTarget(d['train'], cfg=cfg, shuffle=cfg.idxshuffle, codebook=codebook)
        val_cell_ds = MultiSpatialTarget(d['valid'], cfg=cfg, shuffle=cfg.idxshuffle, codebook=codebook)
        for k, v in d.items():
            setattr(cfg, 'dataset_'+k, v)

    if master_process:
        os.makedirs(ckpt_save_dir, exist_ok=True)
        save_all_hparams(out_dir, cfg)

    # Under DDP every rank reads its own shard through a DistributedSampler;
    # otherwise the train loader shuffles itself (sampler and shuffle are
    # mutually exclusive DataLoader options).
    train_sampler = DistributedSampler(cell_ds) if ddp else None
    val_sampler = DistributedSampler(val_cell_ds) if ddp else None
    loader = DataLoader(cell_ds, batch_size=batch_size, pin_memory=True, num_workers=2,
                        collate_fn=collate_fn, sampler=train_sampler, shuffle=not ddp,
                        prefetch_factor=2)
    valloader = DataLoader(val_cell_ds, batch_size=batch_size, shuffle=False, pin_memory=True,
                           num_workers=1, collate_fn=collate_fn, sampler=val_sampler,
                           prefetch_factor=2)

    torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
    torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
    ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
    ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

    best_val_loss = 1e9

    # model init
    model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size, batch_size=batch_size,
                      bias=bias, dropout=dropout, train_mode=train_mode, task=task, vocab_size=vocab_size, loss_len=loss_len,
                      encoder=cfg.encoder, decoder=cfg.decoder,
                      skipconnect=cfg.skipconnect, noise=cfg.noise, rope_base=cfg.rope_base, loc_emb=cfg.loc_emb,
                      device_type=device_type, modeltype=modeltype,
                      codebook=cfg.codebook)  # start with model_args from command line

    gptconf = DaoConfig(**model_args)
    if 'GeST' not in modeltype:
        # Fail loudly instead of exiting silently with status 0.
        raise ValueError(f"unsupported model type {modeltype!r}: expected a name containing 'GeST'")
    model = GeST(gptconf)
    model.to(device_type)

    # An ``init_from`` of the form 'preAE<path>' means: load only the
    # pretrained auto-encoder parts (inlinear / epx_head) from <path>.
    # NOTE(review): torch.load unpickles arbitrary objects; only point
    # ``init_from`` at trusted checkpoint files.
    from_pretrained_ae = init_from.startswith('preAE')
    ae_ckpt_path = init_from[5:]

    if cfg.encoder == 'scimilarity':
        if from_pretrained_ae:
            model.inlinear.load_state_pretrained(ae_ckpt_path, use_gpu=True)
        else:
            model.inlinear.load_state('/nfs/public/usr/Sgeneration/scimilarity/annotation/encoder.ckpt', use_gpu=True)
    elif cfg.encoder == 'multimlp' and from_pretrained_ae:
        ckpt = torch.load(ae_ckpt_path, map_location=device_type)['model']
        exp_weights = {k.split('inlinear.')[1]: v for k, v in ckpt.items() if k.startswith('inlinear')}
        model.inlinear.load_state_dict(exp_weights)

    if cfg.decoder == 'scimilarity':
        if from_pretrained_ae:
            model.epx_head.load_state_pretrained(ae_ckpt_path, use_gpu=True)
        else:
            model.epx_head.load_state('/nfs/public/usr/Sgeneration/scimilarity/annotation/decoder.ckpt', use_gpu=True)
    elif cfg.decoder == 'multimlp' and from_pretrained_ae:
        ckpt = torch.load(ae_ckpt_path, map_location=device_type)['model']
        exp_weights = {k.split('epx_head.')[1]: v for k, v in ckpt.items() if k.startswith('epx_head')}
        model.epx_head.load_state_dict(exp_weights)

    # Freeze the auto-encoder parts once; loading a state dict afterwards
    # copies values in place and does not re-enable gradients.
    if cfg.train_mode == 'frozenAE':
        for param in model.epx_head.parameters():
            param.requires_grad = False
        for param in model.inlinear.parameters():
            param.requires_grad = False

    # Load full-model weights when resuming / fine-tuning from a checkpoint.
    resume_ckpt = None
    if init_from != 'scratch' and not from_pretrained_ae:
        if master_process:
            logger.info(f"loading model from {init_from}")
        resume_ckpt = torch.load(init_from, map_location=device_type)
        model.load_state_dict(resume_ckpt['model'])
    elif from_pretrained_ae:
        if master_process:
            logger.info(f"loading pretrained AEmodel from {init_from}")
    else:
        if master_process:
            logger.info("Initializing a new model from scratch")

    # Training state shared by every init mode. Previously the 'preAE' branch
    # left ``scaler`` and ``schedular`` undefined, which raised NameError in
    # the training loop.
    iter_num_save = 0
    epoch_save = 0
    local_iter_num = 0  # number of optimizer steps taken so far
    eval_loss = np.array([], dtype=np.float32)  # history of validation losses
    optimizer = model.configure_optimizers(cfg.weight_decay, float(cfg.learning_rate), (cfg.beta1, cfg.beta2), device_type)
    # initialize a GradScaler. If enabled=False scaler is a no-op
    scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'bfloat16' or dtype == 'float16'))
    schedular = None

    if resume_ckpt is not None and train_mode != 'finetune':
        # Full resume: restore optimizer / scheduler state and progress counters.
        optimizer.load_state_dict(resume_ckpt['optimizer'])
        if resume_ckpt['schedular'] is not None:
            schedular = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=1000, T_mult=2, eta_min=cfg.min_lr)
            schedular.load_state_dict(resume_ckpt['schedular'])
        iter_num_save = resume_ckpt['iter_num']
        best_val_loss = resume_ckpt['best_val_loss']
        eval_loss = resume_ckpt['eval_loss']
        epoch_save = resume_ckpt['epoch']
        local_iter_num = resume_ckpt['local_iter_num']
        logger.info(f'resume training from {init_from} at iter {iter_num_save} epoch {epoch_save} with best val loss {best_val_loss:.4f} and local iter num {local_iter_num}')
    resume_ckpt = None  # drop the reference so the loaded tensors can be freed

    if compile_model:
        if master_process:
            logger.info("compiling the model... (takes a ~minute)")
        model = torch.compile(model)  # requires PyTorch 2.0

    # wrap model into DDP container
    if ddp:
        model = DDP(model, device_ids=[ddp_local_rank])

    @torch.no_grad()
    def estimate_loss(loader, master_process, cfg):
        """Return ``{'val': mean_loss}`` over ``loader`` (empty dict on non-master ranks).

        The per-batch losses are averaged with ``nanmean`` on each rank and,
        under DDP, averaged again across ranks.
        """
        out = {}
        model.eval()
        losses = []
        for exp, coord, mask in loader:
            exp, coord, mask = _fit_to_block_size(exp, coord, mask, cfg.block_size)
            exp, coord, mask = exp.to(device_type).float(), coord.to(device_type).float(), mask.to(device_type)
            torch.cuda.empty_cache()
            with ctx:
                _, loss = model(inputs_embeds=exp, coord=coord, mask=mask)
            losses.append(loss.item())
        mean_loss = torch.tensor(losses).nanmean().to(device_type)
        if ddp:
            # No process group exists in single-process runs, so only reduce
            # under DDP. SUM / world_size instead of ReduceOp.AVG because AVG
            # is only implemented by the NCCL backend.
            all_reduce(mean_loss, op=torch.distributed.ReduceOp.SUM)
            mean_loss /= ddp_world_size
        if master_process:
            out['val'] = mean_loss.item()
        model.train()
        return out

    # learning rate decay scheduler (cosine with warmup) TODO investigate
    def get_lr(it, schedular=None):
        """Return the learning rate for optimizer step ``it``.

        Linear warmup for ``cfg.warmup_iters`` steps, ``cfg.min_lr`` after
        ``cfg.lr_decay_iters``; in between either step ``schedular`` (when
        given) or follow a cosine decay down to ``cfg.min_lr``.
        """
        # 1) linear warmup for warmup_iters steps
        if it < cfg.warmup_iters:
            return cfg.learning_rate * it / cfg.warmup_iters
        # 2) if it > lr_decay_iters, return min learning rate
        if it > cfg.lr_decay_iters:
            return cfg.min_lr
        # 3) in between, use the scheduler or cosine decay down to min learning rate
        if schedular is not None:
            schedular.step()
            return schedular.get_last_lr()[0]
        decay_ratio = (it - cfg.warmup_iters) / (cfg.lr_decay_iters - cfg.warmup_iters)
        assert 0 <= decay_ratio <= 1
        coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 0..1
        return cfg.min_lr + coeff * (cfg.learning_rate - cfg.min_lr)

    # training loop
    iter_num_epoch = len(loader)
    logger.info(f'iter num per epoch {iter_num_epoch} with batch acc iter num {iter_num_epoch// gradient_accumulation_steps}')
    t0 = time.time()
    raw_model = model.module if ddp else model  # unwrap DDP container if needed
    running_mfu = -1.0

    def build_checkpoint(ep, iter_num):
        """Snapshot the current training state into a checkpoint dict."""
        return {
            'model': raw_model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'model_args': model_args,
            'iter_num': iter_num,
            'local_iter_num': local_iter_num,
            'best_val_loss': best_val_loss,
            'eval_loss': eval_loss,
            'epoch': ep,
            'schedular': schedular.state_dict() if schedular is not None else None
        }

    if master_process:
        logger.info('--------------training start---------------------')
        # start a new wandb run to track this script
        wandb.init(
            # set the wandb project where this run will be logged
            project="SpatialGeneration",
            name=f"{cfg.comment}_{time.strftime('%Y-%m-%d-%H-%M', time.localtime())}",
            # track hyperparameters and run metadata
            config=cfg
        )
    t1 = time.time()
    dt = t1 - t0

    seqlen = []  # most common sequence length of each batch since the last log line
    reached_max_iters = False

    model.train()
    for ep in range(epoch_save, cfg.epoch):
        if ddp:
            train_sampler.set_epoch(ep)
        for iter_num, (exp, coord, mask) in enumerate(loader):
            # When resuming, skip the batches already consumed -- but only in
            # the epoch the checkpoint was taken in, not in every later epoch.
            if ep == epoch_save and iter_num < iter_num_save:
                continue

            exp, coord, mask = _fit_to_block_size(exp, coord, mask, cfg.block_size)
            seqlen.append(mask.float().sum(-1).mode()[0].item())
            logger.debug(f"SeqLEN {torch.unique(mask.float().sum(-1),return_counts=True)}, {seqlen}")

            exp, coord, mask = exp.to(device_type).float(), coord.to(device_type).float(), mask.to(device_type)
            lr = get_lr(local_iter_num, schedular) if cfg.decay_lr else cfg.learning_rate
            logger.debug(f"lr {lr}")
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

            # Periodic validation, only at the start of an accumulation cycle.
            if local_iter_num % cfg.eval_itervals == 0 and local_iter_num != 0 and iter_num % gradient_accumulation_steps == 0:
                losses = estimate_loss(valloader, master_process, cfg=cfg)
                if master_process:
                    logger.info(f"epoch {ep} step {local_iter_num}: val loss {losses['val']:.4f}")
                    wandb.log({"val loss": losses['val']}, step=(iter_num+ep*iter_num_epoch)//gradient_accumulation_steps)
                    if losses['val'] < best_val_loss:
                        best_val_loss = losses['val']
                        best_path = os.path.join(ckpt_save_dir, 'ckpt_best.pt')
                        logger.info(f"saving best checkpoint to {best_path}")
                        torch.save(build_checkpoint(ep, iter_num), best_path)
                    # local_iter_num is already known to be non-zero here.
                    if local_iter_num % cfg.always_save_checkpoint == 0:
                        periodic_path = os.path.join(ckpt_save_dir, f'ckpt_epoch{ep}_{iter_num//gradient_accumulation_steps}.pt')
                        logger.info(f"saving checkpoint to {periodic_path}")
                        torch.save(build_checkpoint(ep, iter_num), periodic_path)
                    eval_loss = np.append(eval_loss, losses['val'])

            if ddp:
                # Only synchronise gradients on the last micro-step of a cycle.
                model.require_backward_grad_sync = ((iter_num+1) % gradient_accumulation_steps == 0)

            t_forward_s = time.time()
            with ctx:
                _, loss = model(inputs_embeds=exp, coord=coord, mask=mask)
                loss = loss / gradient_accumulation_steps  # scale the loss to account for gradient accumulation
            t_forward_e = time.time()
            t_forward_d = t_forward_e - t_forward_s
            scaler.scale(loss).backward()
            t_back_e = time.time()
            t_back_d = t_back_e - t_forward_e

            if (iter_num+1) % gradient_accumulation_steps == 0:
                # clip the gradient
                if cfg.grad_clip != 0.0:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
                # step the optimizer and scaler if training in fp16
                scaler.step(optimizer)
                scaler.update()
                # flush the gradients as soon as we can, no need for this memory anymore
                optimizer.zero_grad(set_to_none=True)
                local_iter_num += 1
                # timing and logging
                t1 = time.time()
                dt = t1 - t0
                t0 = t1

                logger.debug(f"*******forward time {t_forward_d:.4f}")
                logger.debug(f"*******back time {t_back_d:.4f}")
                if ((iter_num+1) // gradient_accumulation_steps) % cfg.log_interval == 0 and master_process:
                    lossf = loss.item() * gradient_accumulation_steps
                    if local_iter_num >= 5:  # let the training loop settle a bit
                        mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt)
                        running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu
                    logger.info(f"epoch {ep}, iter {(iter_num+1)// gradient_accumulation_steps}/{iter_num_epoch // gradient_accumulation_steps}: loss {lossf:.4f}, time {dt:.2f}s, seqlen {np.mean(seqlen)}, mfu {running_mfu*100:.2f}%")
                    wandb.log({"loss": lossf, "time": dt, "lr": lr, "epoch": ep, "seqlen": np.mean(seqlen)}, step=(iter_num+ep*iter_num_epoch)//gradient_accumulation_steps)

                    seqlen = []

            # termination conditions
            if local_iter_num > cfg.max_iters:
                reached_max_iters = True
                break
        # Leave the epoch loop too; a bare inner ``break`` would still run one
        # batch in every remaining epoch.
        if reached_max_iters:
            break

    if master_process:
        wandb.finish()
    if ddp:
        destroy_process_group()
    
# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    run()