import argparse
import copy
import os
import os.path as osp
import time
import torch
import torch.utils.checkpoint
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import DPMSolverMultistepScheduler
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from omegaconf import OmegaConf

from mixofshow.data.lora_dataset import LoraDataset
from mixofshow.data.prompt_dataset import PromptDataset
from mixofshow.pipelines.pipeline_edlora import EDLoRAPipeline, StableDiffusionPipeline
from mixofshow.pipelines.trainer_edlora import EDLoRATrainer 
from mixofshow.utils.convert_edlora_to_diffusers import convert_edlora
from mixofshow.utils.util import MessageLogger, dict2str, reduce_loss_dict, set_path_logger
from test_edlora import visual_validation
from custom_optimizers import AdamW, AdamWr, SGD, SGDr, SoLoRA_SGD, SoLoRA, LoRAPro_SGD, LoRAPro_AdamW

# Raises an error if the minimum required version of diffusers is not installed. Remove at your own risk.
check_min_version('0.18.2')


def _build_optimizer(args, EDLoRA_trainer, optim_kwargs):
    """Create the optimizer selected by ``args.optimizer``.

    Args:
        args: Parsed CLI namespace. ``args.optimizer`` picks the optimizer;
            ``args.optimizer_reg`` is read only for the variants that are
            constructed with a ``reg`` keyword.
        EDLoRA_trainer: Trainer whose ``get_params_to_optimize()`` supplies
            the parameters handed to the optimizer.
        optim_kwargs: Keyword arguments forwarded to the optimizer
            constructor (the ``train.optim_g`` section of the config).

    Returns:
        The constructed optimizer instance.

    Raises:
        ValueError: If ``args.optimizer`` is not one of the supported names.
    """
    # name -> (message printed on selection, optimizer class, takes ``reg``)
    registry = {
        'adamw': ('Using Optimizer AdamW', AdamW, False),
        'scaled_adamw': ('Using Optimizer Scaled AdamW', AdamWr, True),
        'sgd': ('Using Optimizer SGD', SGD, False),
        'scaled_gd': ('Using Optimizer Scaled GD', SGDr, True),
        'solora_sgd': ('Using Optimizer SoLoRA SGD', SoLoRA_SGD, False),
        'solora': ('Using Optimizer SoLoRA', SoLoRA, False),
        'lora_pro_sgd': ('Using LoRA-Pro SGD', LoRAPro_SGD, True),
        'lora_pro_adamw': ('Using Optimizer LoRA-Pro AdamW', LoRAPro_AdamW, True),
    }
    if args.optimizer not in registry:
        # Fail loudly here instead of hitting a NameError on an unbound
        # ``optimizer`` further down the training setup.
        raise ValueError(
            f'Unsupported optimizer {args.optimizer!r}; expected one of {sorted(registry)}')

    banner, optimizer_cls, takes_reg = registry[args.optimizer]
    print(banner)
    params = EDLoRA_trainer.get_params_to_optimize()
    if takes_reg:
        return optimizer_cls(params, **optim_kwargs, reg=args.optimizer_reg)
    return optimizer_cls(params, **optim_kwargs)


def train(root_path, args):
    """Train an EDLoRA model from the YAML config at ``args.opt``.

    Builds the trainer, optimizer, datasets and LR scheduler, then runs the
    optimization loop for ``len(train_dataset) / total_batch_size`` steps,
    saving a checkpoint every ``opt['logger']['save_checkpoint_freq']``
    optimizer steps and once more at the end.

    Args:
        root_path: Project root directory, forwarded to ``set_path_logger``.
        args: Parsed CLI namespace providing ``opt`` (config path),
            ``optimizer`` (optimizer name) and ``optimizer_reg`` (``reg``
            value for the optimizers that accept it).

    Raises:
        ValueError: If ``args.optimizer`` names an unsupported optimizer.
    """
    # load config
    opt = OmegaConf.to_container(OmegaConf.load(args.opt), resolve=True)

    # set accelerator, mix-precision set in the environment by "accelerate config"
    accelerator = Accelerator(gradient_accumulation_steps=opt['gradient_accumulation_steps'])

    # set experiment dir
    with accelerator.main_process_first():
        set_path_logger(accelerator, root_path, args.opt, opt, is_train=True)

    # get logger
    logger = get_logger('mixofshow', log_level='INFO')
    logger.info(accelerator.state, main_process_only=True)

    logger.info(dict2str(opt))

    # If passed along, set the training seed now.
    if opt.get('manual_seed') is not None:
        set_seed(opt['manual_seed'])

    # Load model
    EDLoRA_trainer = EDLoRATrainer(**opt['models'])

    # set optimizer
    train_opt = opt['train']
    optimizer = _build_optimizer(args, EDLoRA_trainer, train_opt['optim_g'])

    # Get the training dataset
    trainset_cfg = opt['datasets']['train']
    train_dataset = LoraDataset(trainset_cfg)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=trainset_cfg['batch_size_per_gpu'], shuffle=True, drop_last=True)

    # Get the validation (visualization prompt) dataset
    valset_cfg = opt['datasets']['val_vis']
    val_dataset = PromptDataset(valset_cfg)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=valset_cfg['batch_size_per_gpu'], shuffle=False)

    # Prepare everything with our `accelerator`.
    EDLoRA_trainer, optimizer, train_dataloader, val_dataloader = accelerator.prepare(EDLoRA_trainer, optimizer, train_dataloader, val_dataloader)

    # The underlying trainer (without any distributed wrapper); used for
    # attribute access that the wrapped model does not forward.
    unwrapped_trainer = accelerator.unwrap_model(EDLoRA_trainer)

    # Train!

    total_batch_size = opt['datasets']['train']['batch_size_per_gpu'] * accelerator.num_processes * opt['gradient_accumulation_steps']
    # NOTE: true division, so total_iter is a float; the loop below therefore
    # runs ceil(total_iter) optimizer steps when the dataset size is not a
    # multiple of total_batch_size.
    total_iter = len(train_dataset) / total_batch_size
    opt['train']['total_iter'] = total_iter

    logger.info('***** Running training *****')
    logger.info(f'  Num examples = {len(train_dataset)}')
    logger.info(f"  Instantaneous batch size per device = {opt['datasets']['train']['batch_size_per_gpu']}")
    logger.info(f'  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}')
    logger.info(f'  Total optimization steps = {total_iter}')
    global_step = 0

    # Scheduler
    # NOTE(review): the horizon is scaled by gradient_accumulation_steps, but
    # lr_scheduler.step() below is only called on optimizer steps — confirm
    # this is the intended decay schedule.
    lr_scheduler = get_scheduler(
        'linear',
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=total_iter * opt['gradient_accumulation_steps'],
    )

    def make_data_yielder(dataloader):
        # Cycle over the dataloader forever, syncing processes between passes.
        while True:
            for batch in dataloader:
                yield batch
            accelerator.wait_for_everyone()

    train_data_yielder = make_data_yielder(train_dataloader)

    msg_logger = MessageLogger(opt, global_step)
    stop_emb_update = False

    # Snapshot of the token embedding table; rows flagged in
    # `index_no_updates` are restored from it after every optimizer step.
    original_embedding = copy.deepcopy(unwrapped_trainer.text_encoder.get_input_embeddings().weight)

    while global_step < opt['train']['total_iter']:
        with accelerator.accumulate(EDLoRA_trainer):

            unwrapped_trainer.unet.train()
            unwrapped_trainer.text_encoder.train()
            loss_dict = {}

            batch = next(train_data_yielder)

            # Fall back to the image masks when the batch has no 'masks' entry.
            if 'masks' in batch:
                masks = batch['masks']
            else:
                masks = batch['img_masks']

            loss = EDLoRA_trainer(batch['images'], batch['prompts'], masks, batch['img_masks'])
            loss_dict['loss'] = loss

            # get fix embedding and learn embedding: start from an all-True
            # mask and clear the concept-token rows while they are still
            # allowed to train.
            index_no_updates = torch.arange(len(unwrapped_trainer.tokenizer)) != -1
            if not stop_emb_update:
                for token_id in unwrapped_trainer.get_all_concept_token_ids():
                    index_no_updates[token_id] = False

            accelerator.backward(loss)

            # Only set external grads on actual optimizer steps
            if accelerator.sync_gradients:
                # Provide external grads to optimizer keyed by Parameter
                lora_param_grads = unwrapped_trainer.collect_lora_param_grads()
                # If optimizer is accelerate-wrapped, forward to the underlying .optimizer
                opt_target = optimizer.optimizer if hasattr(optimizer, 'optimizer') else optimizer
                opt_target.external_grads = lora_param_grads
                optimizer.step()
                # Clear after step
                opt_target.external_grads = {}
                lr_scheduler.step()
                optimizer.zero_grad()

        # Checks if the accelerator has performed an optimization step behind the scenes
        if accelerator.sync_gradients:
            # set no update token to origin
            token_embeds = unwrapped_trainer.text_encoder.get_input_embeddings().weight
            token_embeds.data[index_no_updates, :] = original_embedding.data[index_no_updates, :]

            concept_token_ids = unwrapped_trainer.get_all_concept_token_ids()
            loss_dict['Norm_mean'] = token_embeds[concept_token_ids].norm(dim=-1).mean()
            # Once the mean concept-embedding norm reaches the threshold, stop
            # training the embeddings: refresh the snapshot so that every row
            # (concept tokens included) is restored to it from now on.
            if stop_emb_update is False and float(loss_dict['Norm_mean']) >= train_opt.get('emb_norm_threshold', 5.5e-1):
                stop_emb_update = True
                original_embedding = copy.deepcopy(unwrapped_trainer.text_encoder.get_input_embeddings().weight)

            log_dict = reduce_loss_dict(accelerator, loss_dict)

            global_step += 1

            if global_step % opt['logger']['print_freq'] == 0:
                log_vars = {'iter': global_step}
                log_vars.update({'lrs': lr_scheduler.get_last_lr()})
                log_vars.update(log_dict)
                msg_logger(log_vars)

            if global_step % opt['logger']['save_checkpoint_freq'] == 0:
                # `args` is required by save_and_validation (it reads
                # args.optimizer to build the save path).
                save_and_validation(accelerator, opt, args, EDLoRA_trainer, val_dataloader, global_step, logger)

    # Save the lora layers, final eval
    accelerator.wait_for_everyone()
    save_and_validation(accelerator, opt, args, EDLoRA_trainer, val_dataloader, 'latest_potter_model', logger)


def save_and_validation(accelerator, opt, args, EDLoRA_trainer, val_dataloader, global_step, logger):
    """Write the trainer's delta weights to disk and synchronize all processes.

    The checkpoint is stored at
    ``<opt['path']['models']>/<args.optimizer>/<edlora|lora>_model-<global_step>.pth``.
    Only the main process writes the file; every process then waits at a
    barrier. ``val_dataloader`` is accepted but not used here.

    Args:
        accelerator: Object providing ``is_main_process``, ``save``,
            ``unwrap_model`` and ``wait_for_everyone``.
        opt: Config dict; reads ``opt['models']['enable_edlora']`` and
            ``opt['path']['models']``.
        args: Namespace whose ``optimizer`` attribute names the sub-directory.
        EDLoRA_trainer: (Possibly wrapped) trainer exposing ``delta_state_dict()``.
        val_dataloader: Unused.
        global_step: Step identifier embedded in the file name.
        logger: Logger used to report the saved path.
    """
    prefix = 'edlora' if opt['models']['enable_edlora'] else 'lora'
    checkpoint_name = f'{prefix}_model-{global_step}.pth'
    checkpoint_file = os.path.join(opt['path']['models'], args.optimizer, checkpoint_name)

    if accelerator.is_main_process:
        # Create the per-optimizer directory on first use.
        os.makedirs(os.path.dirname(checkpoint_file), exist_ok=True)
        payload = {'params': accelerator.unwrap_model(EDLoRA_trainer).delta_state_dict()}
        accelerator.save(payload, checkpoint_file)
        logger.info(f'Save state to {checkpoint_file}')

    # Keep the other processes from racing ahead of the checkpoint write.
    accelerator.wait_for_everyone()



if __name__ == '__main__':
    # Command-line interface: config file, optimizer choice and its `reg` value.
    cli = argparse.ArgumentParser()
    cli.add_argument('-opt', type=str, default='options/train/EDLoRA/EDLoRA_hina_Anyv4_B4_Iter1K.yml')
    cli.add_argument(
        '--optimizer',
        default='solora',
        type=str,
        choices=[
            'adamw',
            'scaled_adamw',
            'sgd',
            'scaled_gd',
            'solora_sgd',
            'solora',
            'lora_pro_sgd',
            'lora_pro_adamw',
        ],
    )
    cli.add_argument('--optimizer_reg', default=0.0, type=float)
    # Parsed but not read in this script — presumably supplied by a
    # distributed launcher; TODO confirm.
    cli.add_argument('--local-rank', type=int, default=0)
    args = cli.parse_args()

    # Directory containing this script (the script path joined with '..', normalized).
    parent_of_script = osp.join(__file__, osp.pardir)
    root_path = osp.abspath(parent_of_script)

    # Wall-clock timing of the whole training run.
    t_begin = time.time()
    train(root_path, args)
    t_finish = time.time()

    elapsed_hours = (t_finish - t_begin) / 3600
    print(f"Total training time: {elapsed_hours} hours")
