import torch
import utils
from absl import logging
import os
import wandb
from torch.utils.data import DataLoader
from libs.eval_pipeline import SourceGroupEvaluatorMultiplier
import json
import datetime
from torch.nn.parallel import DistributedDataParallel as DDP
from model import FeedModel

def train(config):
    """Train ``FeedModel`` with HuggingFace Accelerate.

    Builds the model and the training dataloader, wraps the trainable objects
    with ``accelerator.prepare``, then loops over training steps with periodic
    evaluation (all processes), logging and checkpointing (main process only)
    until ``config.max_step`` is reached.

    Step bookkeeping: ``global_step`` counts optimizer updates and
    ``total_step = global_step * config.total_batch_size``. The
    ``eval_interval`` / ``log_interval`` / ``save_interval`` settings and
    ``max_step`` are all compared against ``total_step``.

    Args:
        config: Run configuration (an ``ml_collections``-style config with
            ``to_dict()``). Fields read here: ``dataset``, ``batch_size``,
            ``num_workers``, ``eval_list``, ``meta_dir``, ``workdir``,
            ``ckpt_root``, ``total_batch_size``, ``eval_interval``,
            ``log_interval``, ``save_interval``, ``max_step``,
            ``train_feed``, ``train_adp`` and ``train_nnet``. The whole config
            is also passed to ``utils.setup`` and ``FeedModel``.
    """
    # --- Prepare the models ---------------------------------------------------
    accelerator, device = utils.setup(config)
    feedModel = FeedModel(device, config)

    # --- Data -----------------------------------------------------------------
    train_dataset = utils.get_dataset(**config.dataset)
    train_dataset_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        num_workers=config.num_workers,
        pin_memory=True,
        shuffle=True,
        drop_last=True,
    )

    # --- Hand the trainable objects to the accelerator ------------------------
    # Each prepared object must be stored back on the attribute it was read
    # from; otherwise the rest of training keeps using the unprepared object.
    # NOTE(review): feed_model may now be wrapped (e.g. in DDP) just like nnet;
    # confirm FeedModel does not call custom methods on it without unwrapping.
    (
        feedModel.train_state.nnet,
        feedModel.train_state.optimizer,
        feedModel.train_state.feed_model,
        train_dataset_loader,
        feedModel.train_state.lr_scheduler,
    ) = accelerator.prepare(
        feedModel.train_state.nnet,
        feedModel.train_state.optimizer,
        feedModel.train_state.feed_model,
        train_dataset_loader,
        feedModel.train_state.lr_scheduler,
    )

    train_data_generator = utils.get_data_generator(
        train_dataset_loader, enable_tqdm=accelerator.is_main_process, desc='train')

    ev = SourceGroupEvaluatorMultiplier(config.eval_list, output_path=None, device=device,
                                        process_index=accelerator.process_index,
                                        num_processes=accelerator.num_processes,
                                        eval_face=True)
    # Bind FeedModel's generation function to the evaluator instance and expose
    # it as ``ev.gen_one``.
    # NOTE(review): presumably ev.test() calls gen_one to produce samples — confirm.
    ev.gen_one = feedModel.gen_one_function.__get__(ev, type(ev))

    if accelerator.is_main_process:
        logging.info("saving meta data")
        os.makedirs(config.meta_dir, exist_ok=True)
        with open(os.path.join(config.meta_dir, "config.json"), "w") as f:
            json.dump(config.to_dict(), f, indent=4)

    total_step = 0   # samples processed: global_step * config.total_batch_size
    global_step = 0  # optimizer updates performed

    def train_step():
        """Run one optimizer update and return metrics averaged over processes."""
        nonlocal global_step
        feedModel.train_mode()
        iter_dict = next(train_data_generator)
        loss, loss_img, loss_clip_img = feedModel.compute_loss(iter_dict)
        accelerator.backward(loss.mean())
        feedModel.train_state.optimizer.step()
        feedModel.train_state.lr_scheduler.step()
        global_step += 1
        feedModel.train_state.optimizer.zero_grad()

        # gather() is a collective operation, so every process must reach it
        # even though only the main process logs the metrics.
        metrics = {
            'loss': accelerator.gather(loss.detach().mean()).mean().item(),
            'loss_img': accelerator.gather(loss_img.detach().mean()).mean().item(),
            'loss_clip_img': accelerator.gather(loss_clip_img.detach().mean()).mean().item(),
        }
        # accelerator.scaler is only created for fp16 mixed precision and is
        # None otherwise; report a neutral scale of 1.0 instead of crashing.
        scaler = accelerator.scaler
        metrics['scale'] = scaler.get_scale() if scaler is not None else 1.0
        metrics['lr'] = feedModel.train_state.optimizer.param_groups[0]['lr']
        return metrics

    @torch.no_grad()
    @torch.autocast(device_type='cuda')
    def evaluation(total_step):
        """Run the evaluator and write its outputs to <workdir>/eval/<total_step>."""
        # train_step() leaves the model in training mode; switch it before sampling.
        feedModel.eval_mode()
        accelerator.print("evaluation")
        eval_path = os.path.join(config.workdir, "eval", f"{total_step:06}")
        os.makedirs(eval_path, exist_ok=True)
        ev.set_output_path(eval_path)
        ev.test()
        accelerator.wait_for_everyone()
        accelerator.print("evaluation done")
        if accelerator.is_main_process:
            ev.gather_files()

    def loop():
        """Alternate training with evaluation, logging and checkpointing until max_step."""
        nonlocal total_step
        log_step = 0
        eval_step = 0
        save_step = config.save_interval
        while True:
            # NOTE(review): train_step() switches back to train mode right away,
            # so this call looks redundant; kept in case eval_mode has other effects.
            feedModel.eval_mode()
            metrics = train_step()

            total_step = global_step * config.total_batch_size

            # Every process takes part in evaluation (the evaluator was built with
            # the process index/count). Because eval_step starts at 0, the first
            # evaluation happens right after the first training step.
            if total_step >= eval_step:
                evaluation(total_step)
                eval_step += config.eval_interval
                accelerator.wait_for_everyone()

            if accelerator.is_main_process:
                feedModel.eval_mode()
                if total_step >= log_step:
                    logging.info(utils.dct2str(dict(step=total_step, **metrics)))
                    wandb.log(utils.add_prefix(metrics, 'train'), step=total_step)
                    log_step += config.log_interval

                if total_step >= save_step:
                    logging.info(f'Save and eval checkpoint {total_step}...')
                    # torch.save does not create parent directories.
                    os.makedirs(config.ckpt_root, exist_ok=True)
                    # Save the underlying modules, not the distributed wrappers.
                    eval_feed_model = accelerator.unwrap_model(feedModel.train_state.feed_model)
                    eval_nnet = accelerator.unwrap_model(feedModel.train_state.nnet)
                    if config.train_feed or config.train_adp:
                        torch.save(eval_feed_model.state_dict(),
                                   os.path.join(config.ckpt_root, f'{total_step:06}.pt'))
                    if config.train_nnet:
                        torch.save(eval_nnet.state_dict(),
                                   os.path.join(config.ckpt_root, f'{total_step:06}_nnet.pt'))
                    save_step += config.save_interval

            # Non-main processes wait here while the main process logs/saves.
            accelerator.wait_for_everyone()

            if total_step >= config.max_step:
                break

    loop()

def save_source_files(target_dir, source_root=None):
    """Snapshot the training sources into ``target_dir`` for reproducibility.

    Copies this script, ``utils.py`` and every ``*.py`` file directly inside
    ``libs/`` into ``target_dir``. The layout is flat: only base names are
    kept, so files that share a base name overwrite each other.

    Args:
        target_dir: Destination directory; created if it does not exist.
        source_root: Directory that contains ``utils.py`` and ``libs/``.
            Defaults to the directory of this script, so the snapshot does not
            depend on the current working directory.

    Raises:
        FileNotFoundError: If ``utils.py`` or the ``libs`` directory is missing.
    """
    import shutil

    if source_root is None:
        source_root = os.path.dirname(os.path.abspath(__file__))
    os.makedirs(target_dir, exist_ok=True)

    libs_dir = os.path.join(source_root, "libs")
    file_list = [__file__, os.path.join(source_root, "utils.py")]
    # Sorted for a deterministic copy order; only regular .py files are copied.
    file_list.extend(
        path
        for path in (os.path.join(libs_dir, name) for name in sorted(os.listdir(libs_dir)))
        if path.endswith(".py") and os.path.isfile(path)
    )

    for src in file_list:
        shutil.copyfile(src, os.path.join(target_dir, os.path.basename(src)))


from absl import flags
from absl import app
from ml_collections import config_flags

# Command-line flags (parsed by absl before main() runs). Only --config is required.
FLAGS = flags.FLAGS
config_flags.DEFINE_config_file(
    "config", None, "Training configuration.", lock_config=False)
# NOTE(review): main() in this file derives config.workdir from --logdir and
# never reads FLAGS.workdir; this flag appears unused here.
flags.DEFINE_string("workdir", "workdir", "Work unit directory.")
flags.DEFINE_string("resume_ckpt_path", None, "The path containing the train state to resume.")
flags.DEFINE_string("logdir", "logs", "base log dir")
flags.DEFINE_string("wandb_run_prefix", None, "prefix of wandb run")
flags.DEFINE_string("wandb_mode", "offline", "offline / online")
flags.DEFINE_string("nnet_path", "models/uvit_v1.pth", "Path to the nnet weights file.")
flags.mark_flags_as_required(["config"])


def main(argv):
    """absl entry point: derive run paths from flags, snapshot sources, then train.

    Args:
        argv: Command-line arguments left over after absl flag parsing (unused).
    """
    del argv  # unused
    config = FLAGS.config
    config.log_dir = FLAGS.logdir
    config.config_name = utils.get_config_name()
    config.data_name = config.dataset.name
    config.hparams = utils.get_hparams()
    # Minute (not second) resolution so that processes launched together derive
    # the same run directory. NOTE: processes straddling a minute boundary
    # would still end up with different directories.
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M")
    folder_name = f"{config.config_name}-{config.data_name}-{config.hparams}-{now}"
    config.workdir = os.path.join(config.log_dir, folder_name)
    config.ckpt_root = os.path.join(config.workdir, 'ckpts')
    config.meta_dir = os.path.join(config.workdir, "meta")
    config.resume_ckpt_path = FLAGS.resume_ckpt_path
    config.nnet_path = FLAGS.nnet_path
    os.makedirs(config.workdir, exist_ok=True)
    save_source_files(config.meta_dir)

    # wandb run name and mode
    if FLAGS.wandb_run_prefix is not None:
        # Use the config's run name when it defines one, else the run folder name.
        base_name = config.get("wandb_run_name", folder_name)
        config.wandb_run_name = f"{FLAGS.wandb_run_prefix}-{base_name}"
    else:
        config.wandb_run_name = folder_name
    config.wandb_mode = FLAGS.wandb_mode

    train(config)


if __name__ == "__main__":
    # absl parses the flags defined in this module, then calls main(argv).
    app.run(main=main)
