import os
import torch
import argparse
import itertools
import numpy as np
from tqdm import tqdm
import torch.optim as optim
from torchvision.utils import save_image

import sys

from CelebaExperiment.embedding import JointCelebaEmbedding
from cfg.Scheduler import GradualWarmupScheduler
from cfg.dataloader_pickle import PickleDataset, transback
from cfg.diffusion import GaussianDiffusion
from cfg.unet import Unet
from cfg.utils import get_named_beta_schedule

# sys.path.append('../retrain_trick'); sys.path.append('../Morpho-MNIST')
#from dataloader_cifar import load_data, transback
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import get_rank, init_process_group, destroy_process_group, all_gather, get_world_size
import torch.nn.functional as F
from torch import Tensor
import torch
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler





parser = argparse.ArgumentParser(description='test for diffusion model')
parser.add_argument('--exp_name', type=str, default="TrainP(I2|I1,Sex)")
parser.add_argument('--train_pkl', type=str, default="./base_data/ground_truth.pkl")
parser.add_argument('--val_pkl', type=str, default="./base_data/ground_truth.pkl")
parser.add_argument('--interval',type=int,default=50,help='epoch interval between two evaluations')
parser.add_argument('--moddir',type=str,default='.XrayLLM/Chexray_diffusion/trained_models',help='model addresses')
parser.add_argument('--samdir',type=str,default='./generated_samples',help='sample addresses')
parser.add_argument('--genbatch',type=int,default=16,help='batch size for sampling process')
parser.add_argument('--batchsize',type=int,default=16,help='batch size per device for training Unet model')
parser.add_argument('--epoch',type=int,default=1001,help='epochs for training')

parser.add_argument('--datakey',  type=str, default="I2", help='comma separated list of keys to be concatenated in channel dimension to make a single joint model')
parser.add_argument('--labkey', type=str, default="Sex", help='which key contains the discrete conditioning')
parser.add_argument('--condkey', type=str, default="I1", help='which key contains the image conditioning')
parser.add_argument('--inch',type=int,default=3,help='input channels for Unet model')
parser.add_argument('--outch',type=int,default=3,help='output channels for Unet model')
parser.add_argument('--clsnum',type=int,default=2,help='num of label classes')

parser.add_argument('--numworkers',type=int,default=4,help='num workers for training Unet model')
parser.add_argument('--modch',type=int,default=64,help='model channels for Unet model')
parser.add_argument('--T',type=int,default=1000,help='timesteps for Unet model')
parser.add_argument('--chmul',type=list,default=[1,2,2,2],help='architecture parameters training Unet model')
parser.add_argument('--numres',type=int,default=2,help='number of resblocks for each block in Unet model')
parser.add_argument('--cdim',type=int,default=64,help='dimension of conditional embedding')
parser.add_argument('--useconv',type=bool,default=True,help='whether use convlution in downsample')
parser.add_argument('--droprate',type=float,default=0.1,help='dropout rate for model')
parser.add_argument('--dtype',default=torch.float32)
parser.add_argument('--lr',type=float,default=2e-4,help='learning rate')
parser.add_argument('--w',type=float,default=1.8,help='hyperparameters for classifier-free guidance strength')
parser.add_argument('--v',type=float,default=0.3,help='hyperparameters for the variance of posterior distribution')
parser.add_argument('--multiplier',type=float,default=2.5,help='multiplier for warmup')
parser.add_argument('--threshold',type=float,default=0.1,help='threshold for classifier-free guidance')
parser.add_argument('--num_steps',type=int,default=50,help='sampling steps for DDIM')
parser.add_argument('--eta',type=float,default=0,help='eta for variance during DDIM sampling process')
parser.add_argument('--select',type=str,default='linear',help='selection stragies for DDIM')
parser.add_argument('--ddim',type=lambda x:(str(x).lower() in ['true','1', 'yes']),default=False,help='whether to use ddim')
parser.add_argument('--local_rank',default=-1,type=int,help='node rank for distributed training')


# args = parser.parse_args()
params, unknown = parser.parse_known_args()

print('exp name', params.exp_name)

params.moddir= params.moddir+'/'+params.exp_name
params.samdir= params.samdir+'/'+params.exp_name


def cycler(loader):
    while True:
        for batch in loader:
            yield batch


def load_data(dataset: PickleDataset, batchsize: int)-> tuple[DataLoader, DistributedSampler]:
        trainloader = DataLoader(dataset,
                                 batch_size=batchsize,
                                 shuffle=True,
                                 drop_last=True)
        return trainloader

# set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# load data
train_data = PickleDataset(params.train_pkl)
val_data = PickleDataset(params.val_pkl)
dataloader = load_data(train_data, params.batchsize)
val_loader = load_data(val_data, params.genbatch // torch.cuda.device_count())
val_cycler = cycler(val_loader)

data_keys = params.datakey.split(',')
datacat = lambda batch: torch.cat([batch[k] for k in data_keys], dim=1).contiguous()
datashape = None

# initialize models
net = Unet(
            in_ch = params.inch,
            mod_ch = params.modch,
            out_ch = params.outch,
            ch_mul = params.chmul,
            num_res_blocks = params.numres,
            cdim = params.cdim,
            use_conv = params.useconv,
            droprate = params.droprate,
            dtype = params.dtype
        )



cemblayer = JointCelebaEmbedding(num_labels=params.clsnum,
                           d_model=params.cdim, channels=3,
                           dim=params.cdim, hw=32).to(device)

lastpath = os.path.join(params.moddir,'last_epoch.pt')
if os.path.exists(lastpath):
    lastepc = torch.load(lastpath)['last_epoch']
    # load checkpoints
    checkpoint = torch.load(os.path.join(params.moddir, f'ckpt_{lastepc}_checkpoint.pt'), map_location='cpu')
    net.load_state_dict(checkpoint['net'])
    cemblayer.load_state_dict(checkpoint['cemblayer'])
else:
    lastepc = 0

betas = get_named_beta_schedule(num_diffusion_timesteps = params.T)
diffusion = GaussianDiffusion(
                dtype = params.dtype,
                model = net,
                betas = betas,
                w = params.w,
                v = params.v,
                device = device
            )



# optimizer settings
optimizer = torch.optim.AdamW(
                itertools.chain(
                    diffusion.model.parameters(),
                    cemblayer.parameters()
                ),
                lr = params.lr,
                weight_decay = 1e-4
            )

cosineScheduler = optim.lr_scheduler.CosineAnnealingLR(
                        optimizer = optimizer,
                        T_max = params.epoch,
                        eta_min = 0,
                        last_epoch = -1
                    )
warmUpScheduler = GradualWarmupScheduler(
                        optimizer = optimizer,
                        multiplier = params.multiplier,
                        warm_epoch = params.epoch // 10,
                        after_scheduler = cosineScheduler,
                        last_epoch = lastepc
                    )
if lastepc != 0:
    optimizer.load_state_dict(checkpoint['optimizer'])
    warmUpScheduler.load_state_dict(checkpoint['scheduler'])


# training
cnt = torch.cuda.device_count()
for epc in range(lastepc, params.epoch):
    # turn into train mode
    diffusion.model.train()
    cemblayer.train()
    # batch iterations
    with tqdm(dataloader, dynamic_ncols=True) as tqdmDataLoader:
        for batch in tqdmDataLoader:
            optimizer.zero_grad()
            x_0 = datacat(batch).to(device)
            datashape = tuple(x_0.shape[1:])
            b = x_0.shape[0]
            cemb = cemblayer(batch[params.condkey].to(device),  #user
                             batch[params.labkey].to(device))
            cemb = F.dropout1d(cemb, params.threshold)

            loss = diffusion.trainloss(x_0, cemb = cemb)
            loss.backward()
            optimizer.step()
            tqdmDataLoader.set_postfix(
                ordered_dict={
                    "epoch": epc + 1,
                    "loss: ": loss.item(),
                    "batch per device: ":x_0.shape[0],
                    "img shape: ": x_0.shape[1:],
                    "LR": optimizer.state_dict()['param_groups'][0]["lr"]
                }
            )




    warmUpScheduler.step()
    torch.cuda.empty_cache()

    # if (epc + 1) % params.interval == 0:
    if (epc + 1) % 1 == 0:
        os.makedirs(params.moddir, exist_ok=True)
        os.makedirs(params.samdir, exist_ok=True)

        diffusion.model.eval()
        cemblayer.eval()
        # generating samples
        # Generates genbatch pictures in 2 columns
        # column 0: conditioning image
        # column 1: generated image (1)
        # column 2: generated image (2)


        all_conds = []
        all_samples = []
        each_device_batch = params.genbatch // cnt
        val_batch = next(val_cycler)
        with torch.no_grad():
            cond = val_batch[params.condkey].to(device)
            cemb = cemblayer(cond,                          #user
                             val_batch[params.labkey].to(device))
            genshape = (each_device_batch ,) + datashape
            if params.ddim:
                generated = diffusion.ddim_sample(genshape, params.num_steps, params.eta, params.select, cemb = cemb)
            else:
                generated = diffusion.sample(genshape, cemb = cemb)

            cond = transback(cond)
            img = transback(generated)

            final_imgs = torch.cat([cond, img], dim=1)
            final_imgs = final_imgs.reshape(-1, 3, 64, 64).contiguous()
            save_image(final_imgs, os.path.join(params.samdir, f'generated_{epc+1}_pict.png'), nrow = 2)
            print('Image saved as ',os.path.join(params.samdir, f'generated_{epc+1}_pict.png'))


        # save checkpoints
        checkpoint = {
                    'net':diffusion.model.state_dict(),
                    'cemblayer':cemblayer.state_dict(),
                    'optimizer':optimizer.state_dict(),
                    'scheduler':warmUpScheduler.state_dict()
                }
        torch.save({'last_epoch':epc+1}, os.path.join(params.moddir,'last_epoch.pt'))
        torch.save(checkpoint, os.path.join(params.moddir, f'ckpt_{epc+1}_checkpoint.pt'))
        print('Model saved as ',os.path.join(params.moddir, f'ckpt_{epc+1}_checkpoint.pt'))


    torch.cuda.empty_cache()
