import copy

from datetime import datetime
from torch import optim
import torch.utils.tensorboard as tensorboard
from sgEncoderTraining.sgEncoder.create_sg_encoder_sd3 import create_model_and_transforms
from sgEncoderTraining.training.logger import setup_logging
from configs.configs_laion_sd3 import parse_args
from sgEncoderTraining.training.scheduler import cosine_lr
from sgEncoderTraining.training.train_and_val_one_iter_sd3 import train_by_iters
from sgEncoderTraining.datasets.laion_dataset import build_laion_loaders
import logging
import os
import torch
import torch.utils.checkpoint
from transformers import PretrainedConfig, CLIPTokenizer, T5TokenizerFast

from diffusers import (
    AutoencoderKL,
    FlowMatchEulerDiscreteScheduler,
    SD3Transformer2DModel,
)


from accelerate import Accelerator, DistributedDataParallelKwargs


def trainer():
    """Train the scene-graph (SG) encoder on top of a frozen Stable Diffusion 3 backbone.

    Steps:
      1. Parse CLI arguments.
      2. Build an ``Accelerator`` (mixed precision, gradient accumulation,
         DDP with ``find_unused_parameters=True``).
      3. Create the experiment's log / tensorboard / checkpoint directories
         (main process only).
      4. Load the SD3 tokenizers, text encoders, VAE, transformer and
         flow-matching noise scheduler, and freeze all of them.
      5. Build the SG encoder via ``create_model_and_transforms``, create the
         AdamW optimizer and cosine LR schedule.
      6. Run ``train_by_iters`` once per epoch.

    Returns:
        -1 on every process if an experiment with the same name already exists;
        otherwise ``None`` once all epochs have run.
    """
    args = parse_args()

    # find_unused_parameters=True lets DDP tolerate parameters that receive no
    # gradient in a given step (NOTE(review): confirm still required; it adds overhead).
    accelerator = Accelerator(kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)],
                              mixed_precision=args.precision,
                              gradient_accumulation_steps=args.accumulation_steps,)

    def load_text_encoders(class_one, class_two, class_three):
        """Load the three SD3 text encoders from ``args.stable_diffusion_checkpoint``.

        ``class_one``, ``class_two`` and ``class_three`` are the model classes used for
        the ``text_encoder``, ``text_encoder_2`` and ``text_encoder_3`` subfolders.
        """
        text_encoder_one = class_one.from_pretrained(
            args.stable_diffusion_checkpoint, subfolder="text_encoder", revision=args.revision, variant=args.variant, cache_dir=args.cache_dir
        )
        text_encoder_two = class_two.from_pretrained(
            args.stable_diffusion_checkpoint, subfolder="text_encoder_2", revision=args.revision, variant=args.variant, cache_dir=args.cache_dir
        )
        text_encoder_three = class_three.from_pretrained(
            args.stable_diffusion_checkpoint, subfolder="text_encoder_3", revision=args.revision, variant=args.variant, cache_dir=args.cache_dir
        )
        return text_encoder_one, text_encoder_two, text_encoder_three

    if torch.cuda.is_available():
        # Favour speed over bit-exact reproducibility: allow TF32 matmuls and
        # cuDNN autotuning, and do not force deterministic kernels.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False

    def import_model_class_from_model_name_or_path(
            pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
    ):
        """Return the transformers class named by ``architectures[0]`` in the subfolder's config.

        Only ``CLIPTextModelWithProjection`` and ``T5EncoderModel`` are supported.

        Raises:
            ValueError: if the config names any other architecture.
        """
        text_encoder_config = PretrainedConfig.from_pretrained(
            pretrained_model_name_or_path, subfolder=subfolder, revision=revision, cache_dir=args.cache_dir
        )
        model_class = text_encoder_config.architectures[0]
        if model_class == "CLIPTextModelWithProjection":
            from transformers import CLIPTextModelWithProjection

            return CLIPTextModelWithProjection
        elif model_class == "T5EncoderModel":
            from transformers import T5EncoderModel

            return T5EncoderModel
        else:
            raise ValueError(f"{model_class} is not supported.")

    if args.name is None:
        # NOTE(review): every process computes its own timestamp here. Only the
        # main process uses args.name for paths below, so any mismatch is harmless.
        args.name = '-'.join([
            datetime.now().strftime("%Y_%m_%d-%H_%M_%S"),
            f"lr_{args.lr}",
            f"b_{args.batch_size}",
            f"j_{args.workers}",
            f"p_{accelerator.mixed_precision}",
        ])

    args.log_path = None
    experiment_exists = False
    if accelerator.is_main_process:
        log_base_path = os.path.join(args.logs, args.name)
        os.makedirs(log_base_path, exist_ok=True)
        # Fall back to the accelerate process index if parse_args() defines no rank.
        rank = getattr(args, 'rank', accelerator.process_index)
        log_filename = f'out-{rank}' if args.log_local else 'out.log'
        args.log_path = os.path.join(log_base_path, log_filename)
        experiment_exists = os.path.exists(args.log_path)

    # Only the main process inspects the filesystem, but every process must reach
    # the same decision. A lone early return on the main process could leave the
    # other ranks waiting in later collective calls (e.g. DDP setup in prepare()).
    from accelerate.utils import broadcast_object_list
    experiment_exists = broadcast_object_list([experiment_exists])[0]
    if experiment_exists:
        if accelerator.is_main_process:
            print(
                f"Error. Experiment {args.name} already exists. Use --name to specify a new experiment."
            )
        return -1

    # Non-main processes have log_path None here.
    args.log_level = logging.DEBUG if args.debug else logging.INFO
    setup_logging(args.log_path, args.log_level)

    args.tensorboard = 'tensorboard' in args.report_to or 'all' in args.report_to
    if accelerator.is_main_process:
        args.tensorboard_path = os.path.join(args.logs, args.name, "tensorboard") if args.tensorboard else ''
        args.checkpoint_path = os.path.join(args.logs, args.name, "checkpoints")
        for dirname in [args.tensorboard_path, args.checkpoint_path]:
            if dirname:
                os.makedirs(dirname, exist_ok=True)
    else:
        # Only the main process gets real paths and creates directories.
        args.tensorboard_path = ''
        args.checkpoint_path = ''

    torch.manual_seed(args.seed)

    # Weight dtype for the frozen modules and the SG encoder, following the
    # accelerate mixed-precision mode.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    train_dataloader, val_dataloader = build_laion_loaders(args)

    tokenizer_one = CLIPTokenizer.from_pretrained(
        args.stable_diffusion_checkpoint,
        subfolder="tokenizer",
        revision=args.revision,
        cache_dir=args.cache_dir
    )

    tokenizer_two = CLIPTokenizer.from_pretrained(
        args.stable_diffusion_checkpoint,
        subfolder="tokenizer_2",
        revision=args.revision,
        cache_dir=args.cache_dir
    )

    tokenizer_three = T5TokenizerFast.from_pretrained(
        args.stable_diffusion_checkpoint,
        subfolder="tokenizer_3",
        revision=args.revision,
        cache_dir=args.cache_dir
    )

    text_encoder_cls_one = import_model_class_from_model_name_or_path(
        args.stable_diffusion_checkpoint, args.revision
    )

    text_encoder_cls_two = import_model_class_from_model_name_or_path(
        args.stable_diffusion_checkpoint, args.revision, subfolder="text_encoder_2"
    )

    text_encoder_cls_three = import_model_class_from_model_name_or_path(
        args.stable_diffusion_checkpoint, args.revision, subfolder="text_encoder_3"
    )

    # Load scheduler and models
    noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
        args.stable_diffusion_checkpoint, subfolder="scheduler"
    )
    noise_scheduler_copy = copy.deepcopy(noise_scheduler)
    text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders(
        text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three
    )
    # The VAE is kept in float32 regardless of the mixed-precision mode.
    vae = AutoencoderKL.from_pretrained(
        args.stable_diffusion_checkpoint,
        subfolder="vae",
        revision=args.revision,
        variant=args.variant,
        cache_dir=args.cache_dir
    ).to(accelerator.device, dtype=torch.float32)

    transformer = SD3Transformer2DModel.from_pretrained(
        args.stable_diffusion_checkpoint, subfolder="transformer", revision=args.revision, variant=args.variant, cache_dir=args.cache_dir
    ).to(accelerator.device, dtype=weight_dtype)

    # Requires the xformers package to be installed.
    transformer.enable_xformers_memory_efficient_attention()

    # We only train the additional adapter SGencoder layers
    vae.requires_grad_(False)
    text_encoder_one.requires_grad_(False)
    text_encoder_two.requires_grad_(False)
    text_encoder_three.requires_grad_(False)
    transformer.requires_grad_(False)

    text_encoder_one.to(accelerator.device, dtype=weight_dtype)
    text_encoder_two.to(accelerator.device, dtype=weight_dtype)
    text_encoder_three.to(accelerator.device, dtype=weight_dtype)

    # NOTE(review): the trainable SG encoder is also cast to weight_dtype. With
    # fp16 mixed precision, accelerate's grad scaler will likely refuse to unscale
    # fp16 gradients. Confirm fp16 is not used, or keep trainable params in fp32.
    model = create_model_and_transforms(
        args,
        text_encoders=[text_encoder_one, text_encoder_two, text_encoder_three],
        tokenizers=[tokenizer_one, tokenizer_two, tokenizer_three],
        model_config_json=args.model_config_json,
        precision=accelerator.mixed_precision,
        device=accelerator.device,
        force_quick_gelu=args.force_quick_gelu,
        pretrained_image=args.pretrained_image,
    ).to(accelerator.device, dtype=weight_dtype)

    # Drop the local references; the encoders/tokenizers were handed to the model factory.
    del text_encoder_one, text_encoder_two, text_encoder_three, tokenizer_one, tokenizer_two, tokenizer_three

    # To start from a pretrained SG encoder checkpoint, uncomment:
    #checkpoint = torch.load(args.pretrained_path, map_location=accelerator.device)
    #model.load_state_dict(checkpoint['state_dict'])

    if accelerator.is_main_process:
        logging.info("Model:")
        logging.info(f"{str(model)}")
        logging.info("Params:")
        params_file = os.path.join(args.logs, args.name, "params.txt")
        with open(params_file, "w") as f:
            for name in sorted(vars(args)):
                val = getattr(args, name)
                logging.info(f"  {name}: {val}")
                f.write(f"{name}: {val}\n")

    # NOTE(review): when args.image_dir is empty, optimizer stays None but is still
    # passed to cosine_lr / prepare / train_by_iters. Confirm that path is supported.
    optimizer = None
    if args.image_dir:

        def exclude(n, p):
            """Params that get no weight decay: 1-D tensors, bn/ln/bias params, logit_scale."""
            return p.ndim < 2 or "bn" in n or "ln" in n or "bias" in n or 'logit_scale' in n

        trainable = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
        gain_or_bias_params = [p for n, p in trainable if exclude(n, p)]
        rest_params = [p for n, p in trainable if not exclude(n, p)]

        optimizer = optim.AdamW(
            [
                {"params": gain_or_bias_params, "weight_decay": 0.},
                {"params": rest_params, "weight_decay": args.wd},
            ],
            lr=args.lr,
            betas=(args.beta1, args.beta2),
            eps=args.eps,
        )

    start_epoch = 0

    # NOTE(review): len(train_dataloader) is taken before accelerator.prepare(),
    # which may shard the loader across processes, and it ignores gradient
    # accumulation. Confirm this matches how train_by_iters steps the scheduler.
    total_steps = len(train_dataloader) * args.epochs
    scheduler = cosine_lr(optimizer, args.lr, args.warmup, total_steps)

    args.save_logs = args.logs and args.logs.lower() != 'none' and accelerator.is_main_process
    writer = None
    if args.save_logs and args.tensorboard:
        writer = tensorboard.SummaryWriter(args.tensorboard_path)
        logging.debug('Finished loading tensorboard.')

    model, train_dataloader, val_dataloader, optimizer, scheduler, vae, transformer, noise_scheduler_copy = accelerator.prepare(
        model, train_dataloader, val_dataloader, optimizer, scheduler, vae, transformer, noise_scheduler_copy
    )

    for epoch in range(start_epoch, args.epochs):
        if accelerator.is_main_process:
            logging.info(f'Start epoch {epoch}')

        train_by_iters(model,
                       train_dataloader,
                       val_dataloader,
                       epoch,
                       optimizer,
                       scheduler,
                       args,
                       vae,
                       transformer,
                       noise_scheduler_copy,
                       accelerator,
                       writer,
                       val_count=args.val_times_per_epoch)

if __name__ == "__main__":
    # Propagate trainer()'s status as the process exit code so launchers can
    # detect failure: -1 when the experiment already exists, None (exit 0) on success.
    raise SystemExit(trainer())
