import json
import logging
import math
import os
import time

import numpy as np
import torch
import torch.distributed
import torch.nn as nn
import torch.nn.functional as F

try:
    import wandb
except ImportError:
    wandb = None

from ..open_clip import get_input_dtype
from .distributed import is_master
from .zero_shot import zero_shot_eval
from .precision import get_autocast
from utils import save_images
from vqgan import create_vqgan_loss


class AverageMeter(object):
    """Track the latest observed value together with a weighted running mean.

    ``update(val, n)`` stores ``val`` as the most recent value and folds it into
    the running sum with weight ``n``; ``avg`` is ``sum / count``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero out the latest value, the running mean, the sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and recompute the running mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def postprocess_clip_output(model_out):
    """Name the first three positional outputs of a CLIP-style forward pass.

    Element 0 becomes ``image_features``, element 1 ``text_features`` and
    element 2 ``logit_scale``; any further elements are ignored.
    """
    image_features, text_features, logit_scale = model_out[0], model_out[1], model_out[2]
    return dict(
        image_features=image_features,
        text_features=text_features,
        logit_scale=logit_scale,
    )


def unwrap_model(model):
    """Return ``model.module`` when ``model`` has that attribute, else ``model``.

    Wrappers such as ``DistributedDataParallel`` expose the wrapped network as
    ``.module``; plain modules are returned unchanged.
    """
    return model.module if hasattr(model, 'module') else model


def backward(total_loss, scaler):
    """Backpropagate ``total_loss``, scaling it through ``scaler`` when one is given.

    Args:
        total_loss: scalar loss tensor.
        scaler: an object with a ``scale(loss)`` method (e.g. an AMP grad scaler),
            or ``None`` to call ``backward()`` on the raw loss.
    """
    loss_to_backprop = total_loss if scaler is None else scaler.scale(total_loss)
    loss_to_backprop.backward()


def calculate_adaptive_weight(nll_loss, g_loss, last_layer):
    """Scale factor balancing the adversarial loss against the reconstruction loss.

    Computes ``||d nll_loss / d last_layer|| / (||d g_loss / d last_layer|| + 1e-4)``,
    clamps it to ``[0.3, 10]`` and detaches it so it is treated as a constant.
    Both gradient computations keep the graph (``retain_graph=True``) so the
    caller can still backpropagate through the losses afterwards.
    """
    eps = 1e-4
    nll_grad_norm = torch.norm(torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0])
    gen_grad_norm = torch.norm(torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0])
    ratio = nll_grad_norm / (gen_grad_norm + eps)
    return ratio.clamp(0.3, 10).detach()


def gan_loss(inputs, recons, discriminator, d_loss, g_loss, device, mode=None):
    """Run the discriminator and compute the adversarial loss(es) for ``mode``.

    Args:
        inputs: real images; passed (detached) to the discriminator in
            'disc' and 'eval' modes.
        recons: reconstructed images.
        discriminator: callable ``(fake, real) -> (logits_fake, logits_real)``.
        d_loss: callable ``(logits_real, logits_fake) -> loss``.
        g_loss: callable ``(logits_fake) -> loss``.
        device: device for the zero tensors returned for losses not computed.
        mode: 'gen' computes only the generator loss, without detaching
            ``recons`` and passing ``None`` as the real input. 'disc' computes
            only the discriminator loss on detached inputs. 'eval' computes both
            on detached inputs. Any other value computes nothing.

    Returns:
        ``(loss_gen, loss_disc, logits_avg)``; ``logits_avg`` maps
        'logits_real' / 'logits_fake' to the detached means of whichever
        logits were produced.
    """
    loss_gen = torch.zeros((), device=device)
    loss_disc = torch.zeros((), device=device)
    logits_avg = {}

    if mode == 'gen':
        fake_logits, _ = discriminator(recons.contiguous(), None)
        loss_gen = g_loss(fake_logits)
        logits_avg['logits_fake'] = fake_logits.detach().mean()
        return loss_gen, loss_disc, logits_avg

    if mode not in ('disc', 'eval'):
        return loss_gen, loss_disc, logits_avg

    fake_logits, real_logits = discriminator(recons.contiguous().detach(), inputs.contiguous().detach())
    if mode == 'eval':
        loss_gen = g_loss(fake_logits)
    loss_disc = d_loss(real_logits, fake_logits)

    logits_avg['logits_real'] = real_logits.detach().mean()
    logits_avg['logits_fake'] = fake_logits.detach().mean()
    return loss_gen, loss_disc, logits_avg


def train_one_epoch(model, original_state, data, loss, epoch, optimizer, scaler, scheduler,
                    discriminator, disc_optimizer, disc_scheduler, d_loss, g_loss, p_loss,
                    args, tb_writer=None):
    """Train the stage-1 model for one epoch, with an optional GAN objective.

    For each batch the generator (``model``) is optimized on:
      * MSE between its ``hidden_state_26`` / ``pooler_output`` and the
        second-to-last hidden state / pooled output of ``original_state``
        (computed under ``torch.no_grad()``),
      * pixel MSE between ``img_recon`` and the input images,
      * the two quantizer losses returned by the model,
      * the perceptual loss ``p_loss``,
      * once the GAN warm-up is over, an adaptively weighted generator loss.
    The discriminator is then updated on the detached reconstructions.

    Args:
        model: model (optionally wrapped with ``.module``) whose
            ``model(images, None)`` returns the 8-tuple
            ``(clip_loss_dict, hidden_state_26, pooler_output, img_recon,
            code_visual, code_semantic, quant_loss_visual, quant_loss_semantic)``.
        original_state: reference model called as
            ``original_state(pixel_values=..., output_hidden_states=True)``.
        data: dict whose ``'train'`` entry exposes ``set_epoch`` and ``dataloader``.
        loss: not used by the stage-1 objective; kept for interface compatibility.
        epoch: current epoch index.
        optimizer, scaler, scheduler: generator optimizer, optional AMP grad
            scaler (``None`` disables scaling) and per-step LR scheduler callable.
        discriminator, disc_optimizer, disc_scheduler: discriminator module and
            its optimizer / scheduler.
        d_loss, g_loss: discriminator / generator adversarial loss callables.
        p_loss: perceptual loss module.
        args: run configuration.
        tb_writer: optional TensorBoard writer.

    Raises:
        NotImplementedError: if ``args.stage`` is not 1.
    """
    if args.stage != 1:
        raise NotImplementedError(f"args.stage={args.stage} is not supported; only stage 1 is implemented.")

    device = torch.device(args.device)
    autocast = get_autocast(args.precision, device_type=device.type)
    input_dtype = get_input_dtype(args.precision)

    print(f"args.precision: {args.precision}")
    print(f"autocast: {autocast}")
    print(f"input_dtype: {input_dtype}")

    model.train()
    unwrap_model(model).siglip_model.gradient_checkpointing_enable({"use_reentrant": False})

    # These modules do not change during the epoch, so move/configure them once
    # instead of on every batch.
    original_state = original_state.to(device=device)
    perceptual_loss = p_loss.to(device).eval()

    data['train'].set_epoch(epoch)  # set epoch in process safe manner via sampler or shared_epoch
    dataloader = data['train'].dataloader
    num_batches_per_epoch = dataloader.num_batches // args.accum_freq
    sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10))

    gan_start_epoch = args.gan_start_epoch
    d_weight = 0.75  # scale on the generator adversarial loss (and, after the first GAN epoch, the disc loss)
    p_weight = 1.0   # weight of the perceptual loss
    sem_weight = 1   # weight of the semantic (feature-regression) loss

    losses_m = {}
    batch_time_m = AverageMeter()
    data_time_m = AverageMeter()
    end = time.time()
    for i, batch in enumerate(dataloader):
        i_accum = i // args.accum_freq
        step = num_batches_per_epoch * epoch + i_accum

        if not args.skip_scheduler:
            scheduler(step)

        images = batch[0].to(device=device, dtype=input_dtype, non_blocking=True)

        data_time_m.update(time.time() - end)
        optimizer.zero_grad()

        # The GAN loss switches on halfway through ``gan_start_epoch``. Compare
        # i_accum (optimizer steps) with num_batches_per_epoch, which is also
        # counted in optimizer steps.
        use_discriminator = epoch > gan_start_epoch or (
            epoch == gan_start_epoch and i_accum >= num_batches_per_epoch / 2)
        # Only filled by gan_loss when the discriminator is active; reset per
        # batch so diagnostics never report values from an earlier batch.
        logits_avg = {}

        with autocast():
            (clip_loss_dict, hidden_state_26, pooler_output, img_recon,
             code_visual, code_semantic, quant_loss_visual, quant_loss_semantic) = model(images, None)
            with torch.no_grad():
                outputs_gt = original_state(pixel_values=images, output_hidden_states=True)
                hidden_state_26_gt = outputs_gt.hidden_states[-2]
                pooler_output_gt = outputs_gt.pooler_output  # pooled features

            semantic_loss_26 = F.mse_loss(hidden_state_26, hidden_state_26_gt)
            semantic_loss_pooler = F.mse_loss(pooler_output, pooler_output_gt)
            semantic_loss = 5 * semantic_loss_26 + semantic_loss_pooler
            losses = {"semantic_loss_26th-layer": semantic_loss_26,
                      "semantic_loss_pooler": semantic_loss_pooler}

            logit_scale = clip_loss_dict["logit_scale"]

            rec_pixl_loss = F.mse_loss(img_recon, images)
            losses["recon_loss"] = rec_pixl_loss
            losses["quant_loss_visual"] = quant_loss_visual
            losses["quant_loss_semantic"] = quant_loss_semantic

            loss_pcpt = perceptual_loss(img_recon, images)
            losses["pcpt_loss"] = loss_pcpt

            if use_discriminator:
                discriminator.train()
                loss_gen, _, logits_avg = gan_loss(images, img_recon, discriminator, None, g_loss, device, mode='gen')
                losses["gen_loss"] = loss_gen
                # Balance the adversarial gradient against the reconstruction
                # gradient at the decoder's output layer.
                g_weight = calculate_adaptive_weight(rec_pixl_loss + p_weight * loss_pcpt,
                                                     loss_gen,
                                                     last_layer=unwrap_model(model).decoder.conv_out.weight)
            else:
                loss_gen = torch.zeros((), device=device)
                g_weight = torch.zeros((), device=device)

            total_loss = (sem_weight * semantic_loss
                          + 10 * rec_pixl_loss
                          + 40 * (quant_loss_visual + quant_loss_semantic)
                          + p_weight * loss_pcpt
                          + g_weight * d_weight * loss_gen)

        if not torch.isfinite(total_loss):
            logging.info(f"input images has nan: {torch.isnan(images).any()}")
            logging.info(f"img_recon has nan: {torch.isnan(img_recon).any()}")
            logging.info(f"logits_fake: {logits_avg.get('logits_fake')}")

        losses["total_loss"] = total_loss
        backward(total_loss, scaler)

        if scaler is not None:
            if args.grad_clip_norm is not None:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            if args.grad_clip_norm is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip_norm, norm_type=2.0)
            optimizer.step()

        # Discriminator update on detached reconstructions (gan_loss 'disc' mode).
        if use_discriminator:
            disc_optimizer.zero_grad()
            with autocast():
                _, loss_disc, disc_logits_avg = gan_loss(images, img_recon, discriminator, d_loss, None, device, mode='disc')

            if not torch.isfinite(loss_disc):
                logging.info(f"disc logits_fake: {disc_logits_avg['logits_fake']}")
                logging.info(f"disc logits_real: {disc_logits_avg['logits_real']}")

            losses["disc_loss"] = loss_disc
            # Backward and optimizer step run outside autocast, as the AMP docs recommend.
            # NOTE(review): this path bypasses ``scaler``; with fp16 autocast the
            # discriminator gradients are unscaled and could underflow — confirm.
            if epoch == gan_start_epoch:
                loss_disc.backward()
            else:
                (d_weight * loss_disc).backward()

            disc_optimizer.step()
            disc_scheduler.step()

        batch_time_m.update(time.time() - end)
        end = time.time()
        batch_count = i_accum + 1
        if is_master(args) and (i_accum % args.log_every_n_steps == 0 or batch_count == num_batches_per_epoch):
            batch_size = len(images)
            num_samples = batch_count * batch_size * args.accum_freq * args.world_size
            samples_per_epoch = dataloader.num_samples
            percent_complete = 100.0 * batch_count / num_batches_per_epoch

            # NOTE loss is coarsely sampled, just master node and per log update
            for key, val in losses.items():
                if key not in losses_m:
                    losses_m[key] = AverageMeter()
                losses_m[key].update(val.item(), batch_size)

            logit_scale_scalar = logit_scale.item()
            loss_log = " ".join(
                f"{loss_name.capitalize()}: {loss_m.val:#.5g} ({loss_m.avg:#.5g})"
                for loss_name, loss_m in losses_m.items()
            )
            samples_per_second = args.accum_freq * args.batch_size * args.world_size / batch_time_m.val
            samples_per_second_per_gpu = args.accum_freq * args.batch_size / batch_time_m.val
            logging.info(
                f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] "
                f"Data (t): {data_time_m.avg:.3f} "
                f"Batch (t): {batch_time_m.avg:.3f}, {samples_per_second:#g}/s, {samples_per_second_per_gpu:#g}/s/gpu "
                f"LR: {optimizer.param_groups[0]['lr']:5f} "
                f"Logit Scale: {logit_scale_scalar:.3f} " + loss_log
            )

            # Save train loss / etc. Using non avg meter values as loggers have their own smoothing
            log_data = {
                "data_time": data_time_m.val,
                "batch_time": batch_time_m.val,
                "samples_per_second": samples_per_second,
                "samples_per_second_per_gpu": samples_per_second_per_gpu,
                "scale": logit_scale_scalar,
                "lr": optimizer.param_groups[0]["lr"],
            }
            log_data.update({name: meter.val for name, meter in losses_m.items()})

            log_data = {"train/" + name: val for name, val in log_data.items()}

            if tb_writer is not None:
                for name, val in log_data.items():
                    tb_writer.add_scalar(name, val, step)

            if args.wandb:
                assert wandb is not None, 'Please install wandb.'
                log_data['step'] = step  # for backwards compatibility
                wandb.log(log_data, step=step)

            # resetting batch / data time meters per log window
            batch_time_m.reset()
            data_time_m.reset()

            if i % 5000 == 0:
                imgsave_path = os.path.join(args.imgsave_path, f"epoch_{epoch}")
                os.makedirs(imgsave_path, exist_ok=True)
                logging.info(f"g_weight: {g_weight}")
                save_images(img_recon, f'{imgsave_path}/{i}.jpg')
                save_images(images, f'{imgsave_path}/{i}_gt.jpg')
    # end for


def evaluate(model, data, epoch, args, tb_writer=None, tokenizer=None):
    """Evaluate ``model`` on the master process and log the resulting metrics.

    Always runs ``zero_shot_eval``. When ``data`` has a ``'val'`` split and
    ``args.val_frequency`` selects this epoch (or it is the last epoch), it also
    computes the symmetric contrastive loss, retrieval metrics
    (``get_clip_metrics``) and, if the model emits token logits/labels, a
    generative cross-entropy loss averaged over all validation samples.

    Args:
        model: model to evaluate; switched to eval mode here.
        data: dict of dataset wrappers. ``'val'`` (optional) is evaluated;
            ``'train'`` is only used to derive the wandb step.
        epoch: current epoch number.
        args: run configuration.
        tb_writer: optional TensorBoard writer.
        tokenizer: forwarded to ``zero_shot_eval``.

    Returns:
        dict of metric name -> value; empty on non-master ranks and possibly
        empty when nothing was evaluated.
    """
    metrics = {}
    if not is_master(args):
        return metrics
    device = torch.device(args.device)
    model.eval()

    zero_shot_metrics = zero_shot_eval(model, data, epoch, args, tokenizer=tokenizer)
    metrics.update(zero_shot_metrics)

    autocast = get_autocast(args.precision, device_type=device.type)
    input_dtype = get_input_dtype(args.precision)

    if 'val' in data and (args.val_frequency and ((epoch % args.val_frequency) == 0 or epoch == args.epochs)):
        dataloader = data['val'].dataloader
        num_samples = 0
        samples_per_val = dataloader.num_samples

        # FIXME this does not scale past small eval datasets
        # all_image_features @ all_text_features will blow up memory and compute very quickly
        cumulative_loss = 0.0
        cumulative_gen_loss = 0.0
        all_image_features, all_text_features = [], []
        with torch.inference_mode():
            for i, batch in enumerate(dataloader):
                images, texts = batch
                images = images.to(device=device, dtype=input_dtype, non_blocking=True)
                texts = texts.to(device=device, non_blocking=True)

                with autocast():
                    # NOTE(review): the output is indexed by key here, while
                    # train_one_epoch unpacks a tuple from model(images, None);
                    # confirm the model returns a dict for this call.
                    model_out = model(images, texts)
                    image_features = model_out["image_features"]
                    text_features = model_out["text_features"]
                    logit_scale = model_out["logit_scale"]
                    # features are accumulated in CPU tensors, otherwise GPU memory exhausted quickly
                    # however, system RAM is easily exceeded and compute time becomes problematic
                    all_image_features.append(image_features.cpu())
                    all_text_features.append(text_features.cpu())
                    logit_scale = logit_scale.mean()
                    logits_per_image = logit_scale * image_features @ text_features.t()
                    logits_per_text = logits_per_image.t()

                    batch_size = images.shape[0]
                    labels = torch.arange(batch_size, device=device).long()
                    total_loss = (
                        F.cross_entropy(logits_per_image, labels) +
                        F.cross_entropy(logits_per_text, labels)
                    ) / 2

                    gen_loss = maybe_compute_generative_loss(model_out)

                cumulative_loss += total_loss * batch_size
                num_samples += batch_size
                if gen_loss is not None:
                    # Accumulate on every batch (not only on logging batches) so
                    # the average below covers all evaluated samples.
                    cumulative_gen_loss += gen_loss * batch_size

                if is_master(args) and (i % 100) == 0:
                    logging.info(
                        f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]\t"
                        f"Clip Loss: {cumulative_loss / num_samples:.6f}\t")

                    if gen_loss is not None:
                        logging.info(
                            f"Generative Loss: {cumulative_gen_loss / num_samples:.6f}\t")

            val_metrics = get_clip_metrics(
                image_features=torch.cat(all_image_features),
                text_features=torch.cat(all_text_features),
                logit_scale=logit_scale.cpu(),
            )
            loss = cumulative_loss / num_samples
            metrics.update(
                {**val_metrics, "clip_val_loss": loss.item(), "epoch": epoch, "num_samples": num_samples}
            )
            if gen_loss is not None:
                gen_loss = cumulative_gen_loss / num_samples
                metrics.update({"val_generative_loss": gen_loss.item()})

    if not metrics:
        return metrics

    logging.info(
        f"Eval Epoch: {epoch} "
        + "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()])
    )

    log_data = {"val/" + name: val for name, val in metrics.items()}

    if args.save_logs:
        if tb_writer is not None:
            for name, val in log_data.items():
                tb_writer.add_scalar(name, val, epoch)

        with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f:
            f.write(json.dumps(metrics))
            f.write("\n")

    if args.wandb:
        assert wandb is not None, 'Please install wandb.'
        if 'train' in data:
            dataloader = data['train'].dataloader
            num_batches_per_epoch = dataloader.num_batches // args.accum_freq
            step = num_batches_per_epoch * epoch
        else:
            step = None
        log_data['epoch'] = epoch
        wandb.log(log_data, step=step)

    return metrics


def get_clip_metrics(image_features, text_features, logit_scale):
    """Retrieval metrics for paired embeddings, where row ``i`` of each input is a matching pair.

    For both directions (``image_to_text`` and ``text_to_image``) this reports
    the 1-based mean rank, the 1-based floored median rank, and Recall@{1, 5, 10}
    of the matching item among all candidates.
    """
    sim_i2t = (logit_scale * image_features @ text_features.t()).detach().cpu()
    directions = {
        "image_to_text": sim_i2t,
        "text_to_image": sim_i2t.t().detach().cpu(),
    }
    targets = torch.arange(len(text_features)).view(-1, 1)

    metrics = {}
    for direction, scores in directions.items():
        order = torch.argsort(scores, descending=True)
        # Column position of the matching item in each row's ordering = 0-based rank.
        ranks = torch.where(order == targets)[1].detach().cpu().numpy()
        metrics[f"{direction}_mean_rank"] = ranks.mean() + 1
        metrics[f"{direction}_median_rank"] = np.floor(np.median(ranks)) + 1
        metrics.update({f"{direction}_R@{k}": np.mean(ranks < k) for k in (1, 5, 10)})

    return metrics


def maybe_compute_generative_loss(model_out):
    """Token cross-entropy when ``model_out`` has both ``"logits"`` and ``"labels"``, else ``None``.

    The logits are permuted with ``permute(0, 2, 1)`` so the class dimension sits
    at index 1, as ``F.cross_entropy`` expects. This presumably means logits are
    (batch, seq, vocab) — confirm with the model.
    """
    if "logits" not in model_out or "labels" not in model_out:
        return None
    class_first_logits = model_out["logits"].permute(0, 2, 1)
    return F.cross_entropy(class_first_logits, model_out["labels"])