import logging
import math
import os
import time
import torch.nn.functional as F
from diffusers.training_utils import compute_density_for_timestep_sampling, compute_loss_weighting_for_sd3

from sgEncoderTraining.global_var import *
from tqdm import tqdm


def unwrap_model(model):
    """Return ``model.module`` when that attribute exists, else ``model`` itself.

    Lets callers reach the underlying object regardless of whether it has
    been placed inside a wrapper that exposes it as ``.module``.
    """
    return getattr(model, 'module', model)



class AverageMeter:
    """Keep the most recent value of a metric together with its running mean.

    Attributes (all reset to 0 by ``reset``):
        val: the value passed to the latest ``update`` call.
        sum: accumulated ``val * n`` over all updates.
        count: accumulated ``n`` over all updates.
        avg: ``sum / count`` as of the latest update.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the last value, the running sum, the count and the mean."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and recompute the running mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Mean over everything seen since the last reset.
        self.avg = self.sum / self.count


def validate_by_iter(step,
                     model,
                     dataloader,
                     epoch,
                     args,
                     vae,
                     transformer,
                     noise_scheduler_copy,
                     accelerator,
                     tb_writer=None):
    """Run one pass over ``dataloader`` and log the mean validation loss.

    For each batch the images are encoded by ``vae``, noised at randomly
    sampled timesteps, and fed to ``transformer`` conditioned on the
    embeddings returned by ``model``. The loss is the weighted squared error
    between the preconditioned prediction and the clean latents. All of this
    runs under ``torch.no_grad()``.

    Only the main process accumulates and logs the loss; nothing is gathered
    from other processes here, so the reported value covers the batches seen
    by the main process.

    Args:
        step: Iteration index shown in the progress bar and log line, and
            used as the x-axis value for the TensorBoard scalar.
        model: Callable invoked as
            ``model(all_triples, all_isolated_items, all_global_ids)`` that
            returns ``(prompt_embeds, pooled_embeds)``. It is put in eval
            mode and left there; the caller is responsible for switching
            back to training mode.
        dataloader: Iterable yielding 8-element batches ordered as
            ``(imgs, triples, global_ids, isolated_items, text_prompts,
            original_sizes, crop_top_lefts, img_ids)``; ``imgs`` is a
            sequence of tensors that is concatenated along dim 0.
        epoch: Epoch index (displayed as ``epoch + 1`` in the progress bar).
        args: Options object; reads ``val_batch_size``, ``epochs``,
            ``weighting_scheme``, ``logit_mean``, ``logit_std`` and
            ``mode_scale``.
        vae: Autoencoder providing ``encode(...).latent_dist.sample()`` and
            ``config.shift_factor`` / ``config.scaling_factor``.
        transformer: Denoiser called with the noisy latents, timesteps and
            the two embeddings; element 0 of its output is used.
        noise_scheduler_copy: Scheduler exposing ``sigmas``, ``timesteps``
            and ``config.num_train_timesteps``.
        accelerator: Provides ``device``, ``mixed_precision`` and
            ``is_main_process``.
        tb_writer: Optional writer with ``add_scalar(name, value, step)``.
    """

    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
        """Look up the sigma for each timestep, shaped to broadcast over ``n_dim`` dims."""
        sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
        schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
        timesteps = timesteps.to(accelerator.device)
        # Position of every sampled timestep inside the scheduler's table.
        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

        sigma = sigmas[step_indices].flatten()
        # Append trailing singleton dims so sigma broadcasts against the latents.
        while len(sigma.shape) < n_dim:
            sigma = sigma.unsqueeze(-1)
        return sigma

    model.eval()

    batch_size = args.val_batch_size

    # Latents are cast to the dtype matching the accelerator's mixed-precision mode.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    total_loss_m = AverageMeter()

    # Show a progress bar on the main process only.
    if accelerator.is_main_process:
        dataloader = tqdm(dataloader, desc=f"VAL Epoch {epoch + 1}/{args.epochs}_{step}")

    for batch in dataloader:
        # The last four fields are part of the batch layout but unused here.
        (all_imgs, all_triples, all_global_ids, all_isolated_items,
         _text_prompts, _original_sizes, _crop_top_lefts, _img_ids) = batch

        all_imgs = torch.cat(all_imgs, dim=0).to(accelerator.device)
        with torch.no_grad():
            model_input = vae.encode(all_imgs).latent_dist.sample()
            model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
            model_input = model_input.to(dtype=weight_dtype)

            # Sample noise that we'll add to the latents
            noise = torch.randn_like(model_input)
            bsz = model_input.shape[0]

            # Sample a random timestep for each image
            # for weighting schemes where we sample timesteps non-uniformly
            u = compute_density_for_timestep_sampling(
                weighting_scheme=args.weighting_scheme,
                batch_size=bsz,
                logit_mean=args.logit_mean,
                logit_std=args.logit_std,
                mode_scale=args.mode_scale,
            )
            indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
            timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)

            # Add noise according to flow matching.
            # zt = (1 - texp) * x + texp * z1
            sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
            noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise

            prompt_embeds, pooled_embeds = model(all_triples, all_isolated_items, all_global_ids)
            model_pred = transformer(
                hidden_states=noisy_model_input,
                timestep=timesteps,
                encoder_hidden_states=prompt_embeds,
                pooled_projections=pooled_embeds,
                return_dict=False,
            )[0]
            # Follow: Section 5 of https://huggingface.co/papers/2206.00364.
            # Preconditioning of the model outputs.
            model_pred = model_pred * (-sigmas) + noisy_model_input
            # these weighting schemes use a uniform timestep sampling
            # and instead post-weight the loss
            weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
            target = model_input
            # Per-sample mean over all latent elements, then mean over the batch.
            loss = torch.mean(
                (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
                1,
            )
            val_loss = loss.mean()

        if accelerator.is_main_process:
            # val_loss is a mean over bsz samples, so weight it by bsz (not the
            # configured batch size) to keep the running average correct when
            # a batch holds a different number of samples.
            total_loss_m.update(val_loss.item(), bsz)

    if accelerator.is_main_process:

        logging.info(
            f"Validate Epoch: {epoch} "
            f"Validate Iteration: {step} "
            f"Validate Batch Size: {batch_size} "
            f"Average mse loss of ori_img2sdxl_img: {total_loss_m.avg:#.5g} "
        )

        log_data = {
            "val_loss": total_loss_m.avg,
        }
        for name, val in log_data.items():
            name = "validate/" + name
            if tb_writer is not None:
                tb_writer.add_scalar(name, val, step)


def train_by_iters(model,
                   dataloader,
                   val_dataloader,
                   epoch,
                   optimizer,
                   scheduler,
                   args,
                   vae,
                   transformer,
                   noise_scheduler_copy,
                   accelerator,
                   tb_writer=None,
                   val_count=10):
    """Train ``model`` for one pass over ``dataloader`` with periodic validation.

    Each batch is encoded by ``vae`` (without gradients), noised at randomly
    sampled timesteps and passed through ``transformer`` conditioned on the
    embeddings returned by ``model``. The weighted squared error between the
    preconditioned prediction and the clean latents is back-propagated via
    ``accelerator``. Roughly ``val_count`` times per epoch, and at the end of
    the epoch, the running loss is logged, a checkpoint is optionally written
    and ``validate_by_iter`` is run on ``val_dataloader``.

    Args:
        model: Trainable module called as
            ``model(all_triples, all_isolated_items, all_global_ids)`` and
            returning ``(prompt_embeds, pooled_embeds)``. After unwrapping it
            must expose ``get_alpha()`` and a ``logit_scale`` tensor.
        dataloader: Sized iterable yielding 8-element batches ordered as
            ``(imgs, triples, global_ids, isolated_items, text_prompts,
            original_sizes, crop_top_lefts, img_ids)``.
        val_dataloader: Loader forwarded to ``validate_by_iter``.
        epoch: Epoch index; the global step is
            ``len(dataloader) * epoch + batch_index``.
        optimizer: Optimizer stepped and zeroed after every batch.
        scheduler: Callable invoked with the global step before each batch
            (presumably sets the learning rate; confirm against its
            definition).
        args: Options object; reads ``epochs``, ``batch_size``,
            ``weighting_scheme``, ``logit_mean``, ``logit_std``,
            ``mode_scale``, ``norm_gradient_clip``, ``save_logs``,
            ``save_frequency``, ``checkpoint_path`` and ``name``.
        vae: Autoencoder used to turn images into latents.
        transformer: Denoiser whose first output is used as the prediction.
        noise_scheduler_copy: Scheduler exposing ``sigmas``, ``timesteps``
            and ``config.num_train_timesteps``.
        accelerator: Provides ``device``, ``mixed_precision``,
            ``is_main_process``, ``accumulate``, ``backward``,
            ``sync_gradients`` and ``clip_grad_norm_``.
        tb_writer: Optional writer with ``add_scalar`` and ``flush``.
        val_count: Desired number of validation runs per epoch. Values that
            would give a validation interval below one batch are clamped so
            validation runs every batch instead of raising.
    """

    def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
        """Look up the sigma for each timestep, shaped to broadcast over ``n_dim`` dims."""
        sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
        schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
        timesteps = timesteps.to(accelerator.device)
        # Position of every sampled timestep inside the scheduler's table.
        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

        sigma = sigmas[step_indices].flatten()
        # Append trailing singleton dims so sigma broadcasts against the latents.
        while len(sigma.shape) < n_dim:
            sigma = sigma.unsqueeze(-1)
        return sigma

    model.train()  # Ensure model is in training mode

    num_batches_per_epoch = len(dataloader)
    # The interval is used as a modulus below, so it must never be 0. That
    # would happen when the loader has fewer batches than val_count (or when
    # val_count is not positive), hence both clamps.
    val_frequency = max(1, num_batches_per_epoch // max(1, val_count))

    sample_digits = math.ceil(math.log(num_batches_per_epoch + 1, 10))

    total_loss_m = AverageMeter()
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    end = time.time()

    # Latents are cast to the dtype matching the accelerator's mixed-precision mode.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Only wrap with tqdm on main process
    if accelerator.is_main_process:
        dataloader = tqdm(dataloader, total=num_batches_per_epoch, desc=f"Epoch {epoch + 1}/{args.epochs}")

    for i, batch in enumerate(dataloader):
        # Run the forward/backward pass inside accelerate's accumulation context.
        with accelerator.accumulate(model):
            step = num_batches_per_epoch * epoch + i
            scheduler(step)

            # The last four fields are part of the batch layout but unused here.
            (all_imgs, all_triples, all_global_ids, all_isolated_items,
             _text_prompts, _original_sizes, _crop_top_lefts, _img_ids) = batch

            all_imgs = torch.cat(all_imgs, dim=0).to(accelerator.device)

            data_time_m.update(time.time() - end)

            # Encode images with no_grad to save memory
            with torch.no_grad():
                model_input = vae.encode(all_imgs).latent_dist.sample()
                model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
                model_input = model_input.to(dtype=weight_dtype)

            # Sample noise that we'll add to the latents
            noise = torch.randn_like(model_input)
            bsz = model_input.shape[0]

            # Sample timesteps
            u = compute_density_for_timestep_sampling(
                weighting_scheme=args.weighting_scheme,
                batch_size=bsz,
                logit_mean=args.logit_mean,
                logit_std=args.logit_std,
                mode_scale=args.mode_scale,
            )
            indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
            timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)

            # Add noise according to flow matching
            sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
            noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise

            # Forward pass through SG encoder
            prompt_embeds, pooled_embeds = model(all_triples, all_isolated_items, all_global_ids)

            # Forward pass through transformer
            model_pred = transformer(
                hidden_states=noisy_model_input,
                timestep=timesteps,
                encoder_hidden_states=prompt_embeds,
                pooled_projections=pooled_embeds,
                return_dict=False,
            )[0]

            # Preconditioning of the model outputs
            model_pred = model_pred * (-sigmas) + noisy_model_input

            # Compute loss
            weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
            target = model_input

            # Detach target to prevent gradient flow through VAE.
            # Per-sample mean over all latent elements, then mean over the batch.
            loss = torch.mean(
                (weighting.float() * (model_pred.float() - target.float().detach()) ** 2).reshape(target.shape[0], -1),
                1,
            )
            loss = loss.mean()

            accelerator.backward(loss)

            if i % 10 == 0:
                # Go through unwrap_model so this also works when the model is
                # not wrapped and therefore has no `.module` attribute.
                print("alpha now is:", unwrap_model(model).get_alpha())

            # Gradient clipping (only on steps where optimizer will step)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), args.norm_gradient_clip)

            optimizer.step()
            optimizer.zero_grad()

            # Clamp logit scale
            with torch.no_grad():
                unwrap_model(model).logit_scale.clamp_(0, math.log(100))

        # Update metrics
        batch_time_m.update(time.time() - end)
        end = time.time()
        batch_count = i + 1

        # loss is a mean over bsz samples, so weight it by bsz to keep the
        # running average correct when a batch holds a different sample count.
        total_loss_m.update(loss.detach().item(), bsz)

        # Logging (only log when we actually update weights)
        if accelerator.sync_gradients and accelerator.is_main_process:
            num_samples = batch_count * args.batch_size
            percent_complete = 100.0 * batch_count / num_batches_per_epoch

            logging.info(
                f"Train Epoch: {epoch} [{num_samples:>{sample_digits}} ({percent_complete:.0f}%)] "
                f"Total Loss: {total_loss_m.val:#.5g} ({total_loss_m.avg:#.4g}) "
                f"Data (t): {data_time_m.avg:.3f} "
                f"Batch (t): {batch_time_m.avg:.3f}, {args.batch_size / batch_time_m.val:#g}/s "
                f"LR: {optimizer.param_groups[0]['lr']:5f} "
            )

            log_data = {
                "batch_loss": total_loss_m.val,
                "data_time": data_time_m.val,
                "batch_time": batch_time_m.val,
                "samples_per_second": args.batch_size / batch_time_m.val,
                "lr": optimizer.param_groups[0]["lr"],
            }

            for name, val in log_data.items():
                name = "train/" + name
                if tb_writer is not None:
                    tb_writer.add_scalar(name, val, step)

        # Validation: every val_frequency global steps and at the end of the epoch.
        if (step + 1) % val_frequency == 0 or (step + 1) % num_batches_per_epoch == 0:
            if accelerator.is_main_process:
                if tb_writer is not None:
                    avg_loss = total_loss_m.avg
                    tb_writer.add_scalar("iter_avg_loss", avg_loss, step)
                    tb_writer.flush()
                    logging.info(f"Average Loss for Epoch {epoch} Iter {step}: {avg_loss:.4f}")

                completed_epoch = epoch + 1

                should_save = args.save_logs and (
                    completed_epoch == args.epochs
                    or (args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0)
                )
                if should_save:
                    # Build the state dicts only when a file is actually
                    # written. Use unwrap_model to get the original model.
                    checkpoint_dict = {
                        "epoch": completed_epoch,
                        "iteration": step,
                        "name": args.name,
                        "state_dict": unwrap_model(model).state_dict(),
                        "optimizer": optimizer.state_dict(),
                    }
                    torch.save(
                        checkpoint_dict,
                        os.path.join(args.checkpoint_path, f"epoch_{completed_epoch}_iter_{step}.pt"),
                    )

            # Run validation
            model.eval()
            with torch.no_grad():
                validate_by_iter(step, model, val_dataloader, epoch, args, vae,
                                 transformer, noise_scheduler_copy, accelerator, tb_writer)

            # Set model back to training mode after validation
            model.train()

