import json
import logging
import os
from copy import deepcopy
from pathlib import Path

import hydra
import torch
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration, set_seed
from dataset import CellDataModule, to_rgb
from diffusers.models import AutoencoderKL
from loss import SILoss
from metrics_utils import calculate_metrics_from_scratch
from models.sit import SiT_models
from omegaconf import DictConfig, OmegaConf
from tqdm.auto import tqdm
from utils.data_utils import preprocess_raw_image, process_perturbation_samples
from utils.generation_utils import (
    generate_and_process_samples,
    generate_and_process_samples_multi_celltype,
    generate_perturbation_matched_samples,
    process_latents_through_vae,
)
from utils.log_utils import create_logger, grid_image
from utils.model_utils import requires_grad, sample_posterior_2, update_ema
from utils.utils import load_encoders

# Module-level logger. Inside main() the name ``logger`` is rebound as a local
# (an experiment logger from create_logger), so this one is shadowed there.
logger = logging.getLogger(__name__)
# NOTE(review): not referenced in the visible code of this file. The names and
# values suggest per-channel (R, G, B) CLIP image-normalization mean/std —
# confirm whether they are still needed.
CLIP_DEFAULT_MEAN = (0.481455, 0.457827, 0.408211)
CLIP_DEFAULT_STD = (0.26863, 0.261303, 0.275777)


@torch.no_grad()
def _encode_images_to_latents(images, vae, latents_scale, latents_bias, latent_size):
    """Encode a batch of multi-channel images into channel-stacked VAE latents.

    Every image channel is encoded independently:

    1. Replicate it to three channels (the VAE takes RGB input).
    2. Map it from [0, 1] to [-1, 1].
    3. Pass it through ``vae.encode`` and sample with ``sample_posterior_2``.

    The resulting 4-channel latents of the ``C`` channels are concatenated
    along the channel axis.

    Args:
        images: ``(B, C, H, W)`` tensor on the VAE's device. Values are
            assumed to lie in [0, 1] (the ``* 2 - 1`` mapping relies on it).
        vae: ``AutoencoderKL`` used for encoding.
        latents_scale: ``(1, 4, 1, 1)`` scale forwarded to ``sample_posterior_2``.
        latents_bias: ``(1, 4, 1, 1)`` bias forwarded to ``sample_posterior_2``.
        latent_size: Spatial size of the latents (``H // 8`` for this VAE).

    Returns:
        ``(B, C * 4, latent_size, latent_size)`` latent tensor.
    """
    batch, channels, height, width = images.shape
    # One grayscale image per channel, replicated to RGB for the VAE.
    per_channel = images.reshape(batch * channels, 1, height, width).repeat(1, 3, 1, 1)
    posterior = vae.encode(per_channel * 2 - 1).latent_dist
    latents = sample_posterior_2(
        posterior.mean,
        posterior.std,
        latents_scale=latents_scale,
        latents_bias=latents_bias,
    )
    # (B*C, 4, h, w) -> (B, C*4, h, w): channel c occupies latent channels 4c..4c+3.
    return latents.reshape(batch, channels * 4, latent_size, latent_size)


@hydra.main(config_path="./cfgs", config_name="diffusion_sit_full", version_base=None)
def main(cfg: DictConfig) -> None:
    """Train a SiT latent-diffusion model on multi-channel cell images.

    Hydra composes ``cfg`` from ``./cfgs/diffusion_sit_full`` plus CLI
    overrides. The training loop:

    * encodes every image channel separately with the
      ``stabilityai/sd-vae-ft-mse`` VAE and stacks the per-channel latents;
    * computes features with the encoders returned by ``load_encoders`` and
      passes them to ``SILoss`` as ``zs``. The returned projection loss is
      weighted by ``args.proj_coeff``;
    * maintains an EMA copy of the model;
    * saves a checkpoint every ``checkpointing_steps`` optimizer steps;
    * every ``sampling_steps`` optimizer steps (and at step 1), generates
      samples, computes FID/KID (Inception) and FOD/KOD (OpenPhenom) on each
      rank, and saves an extra checkpoint whenever the average FID across
      ranks improves.

    Samples whose (perturbation, cell type) is (1137, 1) are excluded from
    the training batches.

    Args:
        cfg: Composed Hydra config. The keys read are the ``args.*``
            attributes below.
    """
    args = cfg
    # set accelerator
    logging_dir = Path(args.output_dir, args.logging_dir)
    accelerator_project_config = ProjectConfiguration(
        project_dir=args.output_dir, logging_dir=logging_dir
    )

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
    )

    # Best average FID seen so far; only the main process updates it.
    min_recorded_avg_fid = float("inf")

    if accelerator.is_main_process:
        # Results folder (holds all experiment subfolders).
        os.makedirs(args.output_dir, exist_ok=True)
        save_dir = os.path.join(args.output_dir, args.exp_name)
        os.makedirs(save_dir, exist_ok=True)
        args_dict = OmegaConf.to_container(args, resolve=True)
        # Save the resolved config next to the experiment outputs.
        json_path = os.path.join(save_dir, "args.json")
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(args_dict, f, indent=4)
        checkpoint_dir = f"{save_dir}/checkpoints"  # Stores saved model checkpoints
        os.makedirs(checkpoint_dir, exist_ok=True)
        logger = create_logger(save_dir)
        logger.info(f"Experiment directory created at {save_dir}")
    else:
        # ``logger`` is a local of main(); bind it on every rank so that a
        # stray call outside an ``is_main_process`` guard cannot raise
        # UnboundLocalError.
        logger = logging.getLogger(__name__)
    device = accelerator.device
    if torch.backends.mps.is_available():
        accelerator.native_amp = False
    if args.seed is not None:
        # Different seed per rank so ranks draw different noise.
        set_seed(args.seed + accelerator.process_index)

    # Create model:
    assert (
        args.resolution % 8 == 0
    ), "Image size must be divisible by 8 (for the VAE encoder)."
    latent_size = args.resolution // 8

    if args.enc_type is None:
        raise NotImplementedError()
    encoders, encoder_types, architectures = load_encoders(
        args.enc_type, device, args.resolution
    )
    z_dims = (
        [encoder.embed_dim for encoder in encoders] if args.enc_type != "None" else [0]
    )
    block_kwargs = {"fused_attn": args.fused_attn, "qk_norm": args.qk_norm}
    model = SiT_models[args.model](
        input_size=latent_size,
        num_classes=args.num_classes,
        use_cfg=(args.cfg_prob > 0),
        z_dims=z_dims,
        encoder_depth=args.encoder_depth,
        in_channels=args.in_channels,
        **block_kwargs,
    )

    model = model.to(device)
    # EMA of the model, for use after training; never receives gradients.
    ema = deepcopy(model).to(device)
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device)
    requires_grad(ema, False)

    # Per-latent-channel scale/bias for the VAE's 4 latent channels.
    latents_scale = (
        torch.tensor([0.18215, 0.18215, 0.18215, 0.18215]).view(1, 4, 1, 1).to(device)
    )
    latents_bias = torch.tensor([0.0, 0.0, 0.0, 0.0]).view(1, 4, 1, 1).to(device)

    # create loss function
    loss_fn = SILoss(
        prediction=args.prediction,
        path_type=args.path_type,
        encoders=encoders,
        accelerator=accelerator,
        latents_scale=latents_scale,
        latents_bias=latents_bias,
        weighting=args.weighting,
    )
    if accelerator.is_main_process:
        logger.info(f"SiT Parameters: {sum(p.numel() for p in model.parameters()):,}")

    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    # AdamW; lr, betas, weight decay and eps all come from the config.
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Setup data:
    datamodule = CellDataModule(cfg)
    train_dataloader = datamodule.get_train_loader()
    if accelerator.is_main_process:
        logger.info(
            f"Dataset contains {len(train_dataloader.dataset):,} images ({args.data_dir})"
        )

    # Prepare models for training:
    update_ema(ema, model, accelerator, decay=0)  # Initialize EMA with synced weights
    model.train()  # important! This enables embedding dropout for classifier-free guidance
    ema.eval()  # EMA model should always be in eval mode

    # resume:
    global_step = 0
    if args.resume_step > 0:
        ckpt_name = str(args.resume_step).zfill(7) + ".pt"
        ckpt = torch.load(
            f"{os.path.join(args.output_dir, args.exp_name)}/checkpoints/{ckpt_name}",
            map_location="cpu",
            # Checkpoints pickle ``args`` (an OmegaConf DictConfig). The
            # weights-only unpickler (torch.load's default since torch 2.6)
            # rejects it. These files are written by this script itself.
            weights_only=False,
        )
        model.load_state_dict(ckpt["model"])
        ema.load_state_dict(ckpt["ema"])
        optimizer.load_state_dict(ckpt["opt"])
        global_step = ckpt["steps"]

    model, optimizer, train_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader
    )

    if accelerator.is_main_process:
        tracker_config = OmegaConf.to_container(args, resolve=True)
        accelerator.init_trackers(
            project_name=args.task_name,
            config=tracker_config,
            init_kwargs={"wandb": {"name": f"{args.exp_name}"}},
        )

    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=global_step,
        desc="Steps",
        # Only show the progress bar once on each machine.
        disable=not accelerator.is_local_main_process,
    )

    sample_batch_size = 1
    # Each rank monitors one perturbation (round-robin over this list).
    selected_perturbations = [1138, 1137, 1108, 1124, 375, 25, 1107, 966]
    process_index = accelerator.process_index
    selected_perturbation = selected_perturbations[
        process_index % len(selected_perturbations)
    ]

    # Fixed cell type ID used for the ground-truth and fixed-noise samples.
    fixed_cell_type = 1

    # Find a ground-truth image for (selected_perturbation, fixed_cell_type).
    gt_raw_images = None
    for batch_images, batch_labels, batch_cell_types in train_dataloader:
        hits = torch.nonzero(
            (batch_labels == selected_perturbation)
            & (batch_cell_types == fixed_cell_type)
        ).flatten()
        if hits.numel() > 0:
            first_hit = int(hits[0])
            gt_raw_images = batch_images[first_hit : first_hit + 1]
            break
    if gt_raw_images is None:
        # No exact match in this rank's data: use the first image of the
        # first batch instead.
        fallback_images, _, _ = next(iter(train_dataloader))
        gt_raw_images = fallback_images[:sample_batch_size]

    assert gt_raw_images.shape[-1] == args.resolution
    gt_xs = _encode_images_to_latents(
        gt_raw_images.to(device), vae, latents_scale, latents_bias, latent_size
    )

    # Create fixed noise for this process
    fixed_noise = torch.randn(
        (1, args.in_channels, latent_size, latent_size), device=device
    )

    # Set up conditioning for fixed noise samples
    fixed_class_ids = torch.tensor([selected_perturbation], device=device)
    fixed_cell_type_ids = torch.tensor([fixed_cell_type], device=device)

    # clip_grad_norm_ only runs on gradient-sync steps. Start from zero so the
    # per-iteration log never reads an unbound name during accumulation.
    grad_norm = torch.zeros((), device=device)

    while True:
        model.train()
        for raw_image, y, ct in train_dataloader:
            # Drop every sample with (perturbation, cell type) == (1137, 1).
            mask = ~((y == 1137) & (ct == 1))
            raw_image = raw_image[mask]
            y = y[mask]
            ct = ct[mask]
            if len(raw_image) == 0:
                continue
            raw_image = raw_image.to(device)
            y = y.to(device)
            ct = ct.to(device)
            if args.legacy:
                # In our early experiments, we accidentally apply label dropping twice:
                # once in train.py and once in sit.py.
                # We keep this option for exact reproducibility with previous runs.
                drop_ids = torch.rand(y.shape[0], device=y.device) < args.cfg_prob
                labels = torch.where(drop_ids, args.num_classes, y)
            else:
                labels = y
            with torch.no_grad():
                x = _encode_images_to_latents(
                    raw_image, vae, latents_scale, latents_bias, latent_size
                )

                zs = []
                with accelerator.autocast():
                    for encoder, encoder_type, _ in zip(
                        encoders, encoder_types, architectures
                    ):
                        raw_image_ = preprocess_raw_image(raw_image, encoder_type)
                        z = None
                        if "mocov3" in encoder_type:
                            z = encoder.forward_features(raw_image_)
                            z = z[:, 1:]
                        if "dinov2" in encoder_type:
                            z = encoder.forward_features(raw_image_)
                            z = z["x_norm_patchtokens"]
                        if "openphenom" in encoder_type:
                            z = encoder.forward_features(raw_image_)
                        if z is None:
                            # Never append a stale or missing feature tensor.
                            raise ValueError(
                                f"Unsupported encoder type: {encoder_type}"
                            )
                        zs.append(z)

            with accelerator.accumulate(model):
                model_kwargs = dict(y=labels, ct=ct)
                loss, proj_loss = loss_fn(model, x, model_kwargs, zs=zs)
                loss_mean = loss.mean()
                proj_loss_mean = proj_loss.mean()
                loss = loss_mean + proj_loss_mean * args.proj_coeff

                ## optimization
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = model.parameters()
                    grad_norm = accelerator.clip_grad_norm_(
                        params_to_clip, args.max_grad_norm
                    )
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)

                if accelerator.sync_gradients:
                    update_ema(ema, model, accelerator)

            # True only on the micro-step that completes an optimizer step.
            # Step-based actions must key on it, otherwise they repeat for
            # every accumulation micro-step at the same global_step.
            completed_step = accelerator.sync_gradients
            if completed_step:
                progress_bar.update(1)
                global_step += 1

            if completed_step and global_step % args.checkpointing_steps == 0:
                if accelerator.is_main_process:
                    checkpoint = {
                        # unwrap_model works with and without a DDP wrapper.
                        "model": accelerator.unwrap_model(model).state_dict(),
                        "ema": ema.state_dict(),
                        "opt": optimizer.state_dict(),
                        "args": args,
                        "steps": global_step,
                    }
                    checkpoint_path = f"{checkpoint_dir}/{global_step:07d}.pt"
                    torch.save(checkpoint, checkpoint_path)
                    logger.info(f"Saved checkpoint to {checkpoint_path}")

            if completed_step and (
                global_step == 1 or global_step % args.sampling_steps == 0
            ):
                model.eval()

                # List of all cell types to generate samples for
                all_cell_types = [0, 1, 2, 3]

                with torch.no_grad():
                    # Samples from the fixed noise created before training.
                    samples_fixed = generate_and_process_samples(
                        model,
                        fixed_noise,
                        fixed_class_ids,
                        fixed_cell_type_ids,
                        vae,
                        latent_size,
                        args.resolution,
                        latents_bias,
                        latents_scale,
                        args.path_type,
                        C=6,
                        device=device,
                        heun=False,
                    )
                    # Fresh noise each sampling round, shared across cell types.
                    changing_noise = torch.randn(
                        (1, args.in_channels, latent_size, latent_size), device=device
                    )

                    # Generate samples for each cell type using changing noise
                    samples_changing = generate_and_process_samples_multi_celltype(
                        model,
                        changing_noise,
                        fixed_class_ids,
                        all_cell_types,
                        vae,
                        latent_size,
                        args.resolution,
                        latents_bias,
                        latents_scale,
                        args.path_type,
                        C=6,
                        device=device,
                        heun=False,
                    )
                    # Process real samples for this perturbation for metrics calculation
                    perturbation_samples, perturbation_metadata = (
                        process_perturbation_samples(
                            datamodule,
                            selected_perturbation,
                            num_samples=100,
                            device=device,
                            accelerator=accelerator,
                        )
                    )

                    # [fid, kid, fod, kod] for this rank. NaN means "no
                    # metrics here". Every rank must still take part in the
                    # gathers below, or ranks that do have metrics would hang.
                    local_metrics = [float("nan")] * 4
                    if perturbation_samples is not None:
                        generated_samples, _ = generate_perturbation_matched_samples(
                            model,
                            selected_perturbation,
                            perturbation_metadata,
                            vae,
                            latent_size,
                            args.resolution,
                            latents_bias,
                            latents_scale,
                            args.path_type,
                            device,
                        )
                        if generated_samples is not None:
                            # Each process calculates metrics on its own samples
                            local_fid, local_kid, _ = calculate_metrics_from_scratch(
                                perturbation_samples,
                                generated_samples,
                                feature_extractor="inception_v3",
                            )
                            local_fod, local_kod, _ = calculate_metrics_from_scratch(
                                perturbation_samples,
                                generated_samples,
                                feature_extractor="openphenom",
                            )
                            local_metrics = [
                                float(local_fid),
                                float(local_kid),
                                float(local_fod),
                                float(local_kod),
                            ]

                    # Fixed dtype so every rank contributes an identical tensor
                    # type to the gather (the NaN row and real rows alike).
                    metrics_tensor = torch.tensor(
                        [local_metrics], dtype=torch.float32, device=accelerator.device
                    )
                    perturbation_id_tensor = torch.tensor(
                        [selected_perturbation], device=accelerator.device
                    )
                    gathered_metrics = accelerator.gather(metrics_tensor)
                    gathered_perturbations = accelerator.gather(perturbation_id_tensor)

                    # Main process logs metrics and handles FID-based checkpointing.
                    if accelerator.is_main_process:
                        fid_column = gathered_metrics[:, 0]
                        all_fids = fid_column[~torch.isnan(fid_column)].tolist()
                        if all_fids:
                            average_fid = sum(all_fids) / len(all_fids)
                            accelerator.log(
                                {"metrics/average_fid_across_processes": average_fid},
                                step=global_step,
                            )

                            if average_fid < min_recorded_avg_fid:
                                min_recorded_avg_fid = average_fid
                                logger.info(
                                    f"New minimum average FID: {min_recorded_avg_fid:.8f} at step {global_step}. Saving checkpoint."
                                )
                                new_best_fid_checkpoint_path = os.path.join(
                                    checkpoint_dir,
                                    f"min_AVG_FID_{min_recorded_avg_fid:.8f}.pt",
                                )
                                checkpoint = {
                                    "model": accelerator.unwrap_model(
                                        model
                                    ).state_dict(),
                                    "ema": ema.state_dict(),
                                    "opt": optimizer.state_dict(),
                                    "args": args,
                                    "steps": global_step,
                                    "avg_fid": min_recorded_avg_fid,
                                }
                                torch.save(checkpoint, new_best_fid_checkpoint_path)
                                logger.info(
                                    f"Saved FID-based checkpoint to {new_best_fid_checkpoint_path}"
                                )

                        # Log each rank's metrics under its perturbation ID,
                        # skipping ranks that produced none.
                        for process_metrics, perturbation_id in zip(
                            gathered_metrics, gathered_perturbations
                        ):
                            if torch.isnan(process_metrics[0]):
                                continue
                            process_perturbation = int(perturbation_id.item())
                            accelerator.log(
                                {
                                    f"metrics/perturbation_{process_perturbation}/fid": float(
                                        process_metrics[0].item()
                                    ),
                                    f"metrics/perturbation_{process_perturbation}/kid": float(
                                        process_metrics[1].item()
                                    ),
                                    f"metrics/perturbation_{process_perturbation}/fod": float(
                                        process_metrics[2].item()
                                    ),
                                    f"metrics/perturbation_{process_perturbation}/kod": float(
                                        process_metrics[3].item()
                                    ),
                                },
                                step=global_step,
                            )

                # Gather samples across processes
                samples_fixed_gathered = accelerator.gather(
                    samples_fixed.to(torch.float32)
                )
                samples_changing_gathered = accelerator.gather(
                    samples_changing.to(torch.float32)
                ).squeeze()

                # Convert grayscale to RGB
                rgb_samples_fixed = torch.stack(
                    [
                        to_rgb(img.cpu()[None]).squeeze(0)
                        for img in samples_fixed_gathered
                    ]
                )
                rgb_samples_changing = torch.stack(
                    [
                        to_rgb(img.cpu()[None]).squeeze(0)
                        for img in samples_changing_gathered
                    ]
                )

                # [processes*cell_types, ...] -> [processes, cell_types, ...]
                num_cell_types = len(all_cell_types)
                num_processes = accelerator.num_processes
                rgb_samples_changing = rgb_samples_changing.reshape(
                    num_processes, num_cell_types, *rgb_samples_changing.shape[1:]
                )
                # -> [cell_types, processes, ...]
                rgb_samples_changing = rgb_samples_changing.transpose(1, 0)
                # -> [cell_types*processes, ...] for visualization
                rgb_samples_changing = rgb_samples_changing.reshape(
                    -1, *rgb_samples_changing.shape[2:]
                )

                # Create captions
                fixed_noise_caption = f"Fixed noise samples - Process {process_index}, Pert {selected_perturbation}"
                multi_cell_caption = (
                    f"Multiple cell types for Pert {selected_perturbation}"
                )
                # Decode ground truth samples for visualization
                with torch.no_grad():
                    gt_samples = process_latents_through_vae(
                        gt_xs,
                        vae,
                        latent_size,
                        args.resolution,
                        latents_bias,
                        latents_scale,
                        C=6,
                    )
                gt_samples = accelerator.gather(gt_samples.to(torch.float32))
                if accelerator.is_main_process:
                    rgb_gt_samples = torch.stack(
                        [to_rgb(img.cpu()[None]).squeeze(0) for img in gt_samples]
                    )
                    accelerator.log(
                        {
                            "gt_samples": grid_image(rgb_gt_samples, nrow=8),
                            "samples_fixed_noise": grid_image(
                                rgb_samples_fixed, nrow=8, caption=fixed_noise_caption
                            ),
                            "samples_changing_noise_multi_cell": grid_image(
                                rgb_samples_changing, nrow=8, caption=multi_cell_caption
                            ),
                        },
                        step=global_step,
                    )

                logging.info("Generating all sample types done.")
                model.train()

            logs = {
                "loss": accelerator.gather(loss_mean).mean().detach().item(),
                "proj_loss": accelerator.gather(proj_loss_mean).mean().detach().item(),
                "grad_norm": accelerator.gather(grad_norm).mean().detach().item(),
            }
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break
        if global_step >= args.max_train_steps:
            break

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        logger.info("Done!")
    accelerator.end_training()


if __name__ == "__main__":
    # main() is wrapped by @hydra.main, which composes the config and passes it
    # as ``cfg``; that is why it is called here without arguments.
    main()
