import os, sys
import argparse, os, sys, datetime, glob, importlib
from tabnanny import verbose
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import numpy as np

import torch
import argparse
from model.utils import add_parent_path, set_seeds

# Exp & Model
from models import get_model, get_model_id, add_model_args
from solver import ExperimentPL, add_exp_args

# Data
add_parent_path(level=1)
from modules.datasets.data import get_data, get_data_id, add_data_args,get_pl_datamodule

# Optim
from model.wdecay import get_optim, get_optim_id, add_optim_args

from pytorch_lightning.trainer import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
from pytorch_lightning import seed_everything
from PIL import Image
import torch
import torchvision
import pytorch_lightning
import pytorch_lightning as pl
from omegaconf import OmegaConf
import wandb
from pytorch_lightning.utilities.distributed import rank_zero_only
# import torch
# torch.autograd.set_detect_anomaly(True)

###########
## Setup ##
###########

# Command-line interface: a few script-level flags, followed by the options
# registered by the experiment, data, model and optimizer helper modules.
parser = argparse.ArgumentParser()
parser.add_argument("--debug", type=int, default=0)       # not read anywhere in this script
parser.add_argument("--gpus", type=int, default=8)        # forwarded to Trainer(devices=...)
parser.add_argument("--cond", default=None)               # forwarded to ExperimentPL(cond=...)
parser.add_argument("--ckpt_path", type=str, default=None)  # not read anywhere in this script

# Each helper receives the parser and is expected to register its own options.
add_exp_args(parser)
add_data_args(parser)
add_model_args(parser)
add_optim_args(parser)

args = parser.parse_args()

# Seed through both the project helper and Lightning's seeding utility.
# `args.seed` is presumably registered by one of the add_*_args helpers above.
set_seeds(args.seed)
seed_everything(args.seed)

##################
## Specify data ##
##################

# Build the Lightning datamodule from the parsed CLI arguments and call its
# setup() eagerly here, before it is handed to the trainer.
data_module = get_pl_datamodule(args)
data_module.setup()
# Identifier for the data configuration; it becomes part of the run name.
data_id = get_data_id(args)

###################
## Specify model ##
###################

# Instantiate the model from the parsed CLI arguments.
model = get_model(args)
# Identifier for the model configuration; it becomes part of the run name.
model_id = get_model_id(args)

#######################
## Specify optimizer ##
#######################

# Create the optimizer and scheduler for the model's parameters; both are
# passed to the experiment wrapper rather than to the trainer directly.
optimizer, scheduler = get_optim(args, model)
# Identifier for the optimizer configuration; it becomes part of the run name.
optim_id = get_optim_id(args)


# Run name: launch timestamp followed by the model, data and optimizer
# identifiers, joined with underscores.
now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
nowname = "_".join((now, model_id, data_id, optim_id))
# Checkpoint directory under the log root; in this script it is only
# referenced by the currently disabled callback code.
check_path = os.path.join(args.log_home, "check")
# default_logger_cfgs = {
#     "wandb": {
#         "target": "pytorch_lightning.loggers.WandbLogger",
#         "params": {
#             "name": nowname,
#             "save_dir": args.log_home,
#             "id": nowname,
#         }
#     },
#     "testtube": {
#         "target": "pytorch_lightning.loggers.TestTubeLogger",
#         "params": {
#             "name": "testtube",
#             "save_dir": args.log_home,
#         }
#     },
# }

# default_modelckpt_cfg = {
#     "target": "pytorch_lightning.callbacks.ModelCheckpoint",
#     "params": {
#         "dirpath": check_path,
#         "filename": "{epoch:06}",
#         "verbose": True,
#         "save_last": True,
#     }
# }


# # add callback which sets up log directory
# default_callbacks_cfg = {
#     "setup_callback": {
#         "target": "main.SetupCallback",
#         "params": {
#             "now": now,
#             "logdir": args.log_home,
#             "ckptdir": check_path,
#             "cfgdir": args.log_home,
#         }
#     },
#     "image_logger": {
#         "target": "ImageLogger",
#         "params": {
#             "batch_frequency": 750,
#             "max_images": 4,
#             "clamp": True
#         }
#     },
#     "learning_rate_logger": {
#         "target": "main.LearningRateMonitor",
#         "params": {
#             "logging_interval": "step",
#             #"log_momentum": True
#         }
#     },
# }

# class SetupCallback(Callback):
#     def __init__(self, now, logdir, ckptdir):
#         super().__init__()
#         # self.resume = resume
#         self.now = now
#         self.logdir = logdir
#         self.ckptdir = ckptdir
#         # self.cfgdir = cfgdir
#         # self.config = config
#         # self.lightning_config = lightning_config

#     def on_pretrain_routine_start(self, trainer, pl_module):
#         if trainer.global_rank == 0:
#             # Create logdirs and save configs
#             os.makedirs(self.logdir, exist_ok=True)
#             os.makedirs(self.ckptdir, exist_ok=True)
#             # os.makedirs(self.cfgdir, exist_ok=True)

#             # print("Project config")
#             # print(self.config.pretty())
#             # OmegaConf.save(self.config,
#             #                os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))

#             # print("Lightning config")
#             # print(self.lightning_config.pretty())
#             # OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}),
#             #                os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)))

#         else:
#             # ModelCheckpoint callback created log directory --- remove it
#             if os.path.exists(self.logdir):
#                 dst, name = os.path.split(self.logdir)
#                 dst = os.path.join(dst, "child_runs", name)
#                 os.makedirs(os.path.split(dst)[0], exist_ok=True)
#                 try:
#                     os.rename(self.logdir, dst)
#                 except FileNotFoundError:
#                     pass


# class ImageLogger(Callback):
#     def __init__(self, batch_frequency, max_images, clamp=True, increase_log_steps=True):
#         super().__init__()
#         self.batch_freq = batch_frequency
#         self.max_images = max_images
#         self.logger_log_images = {
#             pl.loggers.WandbLogger: self._wandb,
#             pl.loggers.TestTubeLogger: self._testtube,
#         }
#         self.log_steps = [2 ** n for n in range(int(np.log2(self.batch_freq)) + 1)]
#         if not increase_log_steps:
#             self.log_steps = [self.batch_freq]
#         self.clamp = clamp

#     @rank_zero_only
#     def _wandb(self, pl_module, images, batch_idx, split):
#         grids = dict()
#         for k in images:
#             grid = torchvision.utils.make_grid(images[k])
#             grids[f"{split}/{k}"] = wandb.Image(grid)
#         pl_module.logger.experiment.log(grids)

#     @rank_zero_only
#     def _testtube(self, pl_module, images, batch_idx, split):
#         for k in images:
#             grid = torchvision.utils.make_grid(images[k])
#             grid = (grid+1.0)/2.0 # -1,1 -> 0,1; c,h,w

#             tag = f"{split}/{k}"
#             pl_module.logger.experiment.add_image(
#                 tag, grid,
#                 global_step=pl_module.global_step)

#     @rank_zero_only
#     def log_local(self, save_dir, split, images,
#                   global_step, current_epoch, batch_idx):
#         root = os.path.join(save_dir, "images", split)
#         for k in images:
#             grid = torchvision.utils.make_grid(images[k], nrow=4)

#             grid = (grid+1.0)/2.0 # -1,1 -> 0,1; c,h,w
#             grid = grid.transpose(0,1).transpose(1,2).squeeze(-1)
#             grid = grid.numpy()
#             grid = (grid*255).astype(np.uint8)
#             filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
#                 k,
#                 global_step,
#                 current_epoch,
#                 batch_idx)
#             path = os.path.join(root, filename)
#             os.makedirs(os.path.split(path)[0], exist_ok=True)
#             Image.fromarray(grid).save(path)

#     def log_img(self, pl_module, batch, batch_idx, split="train"):
#         if (self.check_frequency(batch_idx) and  # batch_idx % self.batch_freq == 0
#                 hasattr(pl_module, "log_images") and
#                 callable(pl_module.log_images) and
#                 self.max_images > 0):
#             logger = type(pl_module.logger)

#             is_train = pl_module.training
#             if is_train:
#                 pl_module.eval()

#             with torch.no_grad():
#                 images = pl_module.log_images(batch, split=split, pl_module=pl_module)

#             for k in images:
#                 N = min(images[k].shape[0], self.max_images)
#                 images[k] = images[k][:N]
#                 if isinstance(images[k], torch.Tensor):
#                     images[k] = images[k].detach().cpu()
#                     if self.clamp:
#                         images[k] = torch.clamp(images[k], -1., 1.)

#             self.log_local(pl_module.logger.save_dir, split, images,
#                            pl_module.global_step, pl_module.current_epoch, batch_idx)

#             logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
#             logger_log_images(pl_module, images, pl_module.global_step, split)

#             if is_train:
#                 pl_module.train()

#     def check_frequency(self, batch_idx):
#         # if (batch_idx % self.batch_freq) == 0 or (batch_idx in self.log_steps):
#         if (batch_idx % self.batch_freq) == 0 :
#             try:
#                 self.log_steps.pop(0)
#             except IndexError:
#                 pass
#             return True
#         return False

#     def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
#         if self.check_frequency(batch_idx):
#             self.log_img(pl_module, batch, batch_idx, split="train")


#     def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
#         if self.check_frequency(batch_idx):
#             self.log_img(pl_module, batch, batch_idx, split="val")


##############
## Training ##
##############

# Keyword arguments for the Lightning Trainer: DDP across `args.gpus` GPU
# devices, Weights & Biases logging keyed by the run name, and a per-step
# learning-rate monitor. The setup, checkpoint and image-logging callbacks
# that used to be registered here are currently disabled.
trainer_kwargs = {
    "devices": args.gpus,
    "accelerator": "gpu",
    "strategy": "ddp",
    "logger": pytorch_lightning.loggers.WandbLogger(
        name=nowname,
        save_dir=args.log_home,
        id=nowname,
    ),
    "max_epochs": args.epochs,
    "callbacks": [LearningRateMonitor(logging_interval="step")],
}


# from collections import OrderedDict
# new_state_dict = OrderedDict()

# ckpt_path='/home/hu/UniDm/2022-09-16T04-53-34_UNIDIFFUSION_cub200_expdecay/2022-09-16T04-53-34_UNIDIFFUSION_cub200_expdecay/checkpoints/epoch=270-step=120053.ckpt'
# checkpoint = torch.load(ckpt_path,map_location='cpu')
# ckpt = {k[6:]: checkpoint['state_dict'][k] for k in checkpoint['state_dict'].keys()}
# model.load_state_dict(ckpt)

# Optional conditioning value from --cond (None when the flag is omitted).
# The previous if/else assigned args.cond in one branch and None in the other,
# which is exactly args.cond in both cases, so it is read directly.
cond = args.cond

# Wrap the model, optimizer and scheduler in the project's Lightning experiment.
exp = ExperimentPL(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    cond=cond,
)

# Build the Lightning trainer from the assembled keyword arguments and run
# training on the experiment with the prepared datamodule.
# NOTE(review): --ckpt_path is parsed by this script but not passed to fit()
# here — confirm whether it is consumed elsewhere (e.g. inside get_model) or
# should be forwarded as fit(..., ckpt_path=args.ckpt_path) to resume training.
trainer = Trainer(**trainer_kwargs)
trainer.fit(exp, data_module)
