import os

import torch
from dataset import to_rgb
from metrics_utils import calculate_metrics_from_scratch
from omegaconf import OmegaConf
from tqdm.auto import tqdm
from utils.checkpoint import save_checkpoint
from utils.data_utils import preprocess_raw_image, process_perturbation_samples
from utils.generation_utils import (
    generate_and_process_samples,
    generate_and_process_samples_multi_celltype,
    generate_perturbation_matched_samples,
    process_latents_through_vae,
)
from utils.log_utils import grid_image
from utils.model_utils import sample_posterior_2, update_ema


def _find_reference_sample(train_dataloader, perturbation, cell_type):
    """Return a single reference image for ``(perturbation, cell_type)``.

    Scans ``train_dataloader`` (batches of ``(images, labels, cell_types)``)
    for the first sample whose label equals ``perturbation`` and whose cell
    type equals ``cell_type``.

    Returns:
        ``(images, found)``: ``images`` is a one-item batch slice and
        ``found`` tells whether it actually matched. If nothing matches, a
        copy of the first sample of the first batch is returned with
        ``found=False`` so the caller still has an image to visualize.

    Raises:
        ValueError: if the dataloader yields no samples at all.
    """
    fallback = None
    for images, labels, cell_types in train_dataloader:
        if fallback is None and len(images) > 0:
            # Clone so the fallback does not keep the whole batch alive.
            fallback = images[:1].clone()
        for i, (label, cell) in enumerate(zip(labels, cell_types)):
            if label == perturbation and cell == cell_type:
                return images[i : i + 1], True
    if fallback is None:
        raise ValueError(
            "train_dataloader yielded no samples; cannot select a reference image."
        )
    return fallback, False


@torch.no_grad()
def _encode_to_latents(
    raw_images, vae, latents_scale, latents_bias, in_channels, latent_size, device
):
    """Encode ``(B, C, H, W)`` images to ``(B, in_channels, latent_size, latent_size)``.

    Every channel is encoded separately: it is viewed as a single-channel
    image, repeated to 3 channels, mapped with ``x * 2 - 1`` (presumably
    [0, 1] -> [-1, 1]; confirm against the dataset), passed through
    ``vae.encode`` and sampled/normalized by ``sample_posterior_2``. The
    per-channel latents are then viewed back as one tensor per image, which
    assumes ``C * <vae latent channels> == in_channels``.
    """
    batch, channels, height, width = raw_images.shape
    x = raw_images.to(device).view(batch * channels, 1, height, width)
    x = x.repeat(1, 3, 1, 1) * 2 - 1
    posterior = vae.encode(x).latent_dist
    latents = sample_posterior_2(
        posterior.mean,
        posterior.std,
        latents_scale=latents_scale,
        latents_bias=latents_bias,
    )
    return latents.view(batch, in_channels, latent_size, latent_size)


def _encoder_features(raw_image, encoders, encoder_types, architectures, accelerator):
    """Return one feature tensor per auxiliary encoder for ``raw_image``.

    Runs without gradients under ``accelerator.autocast()``. The features
    taken from ``encoder.forward_features`` depend on the encoder type:
    ``mocov3`` drops index 0 along dim 1 (presumably the CLS token),
    ``dinov2`` takes the ``"x_norm_patchtokens"`` entry, and ``openphenom``
    uses the output unchanged.
    """
    zs = []
    z = None
    with torch.no_grad(), accelerator.autocast():
        # ``architectures`` is otherwise unused, but ``zip`` stops at the
        # shortest of the three sequences.
        for encoder, encoder_type, _arch in zip(encoders, encoder_types, architectures):
            encoder_input = preprocess_raw_image(raw_image, encoder_type)
            if "mocov3" in encoder_type:
                z = encoder.forward_features(encoder_input)[:, 1:]
            if "dinov2" in encoder_type:
                z = encoder.forward_features(encoder_input)["x_norm_patchtokens"]
            if "openphenom" in encoder_type:
                z = encoder.forward_features(encoder_input)
            # NOTE(review): an unrecognized encoder_type appends the previous
            # encoder's features (or None); kept for compatibility.
            zs.append(z)
    return zs


def _local_perturbation_metrics(
    model,
    perturbation,
    datamodule,
    vae,
    latent_size,
    args,
    latents_bias,
    latents_scale,
    device,
    accelerator,
):
    """Compute this process's ``[fid, kid, fod, kod]`` for ``perturbation``.

    FID/KID use the ``inception_v3`` feature extractor and FOD/KOD the
    ``openphenom`` one. Returns ``None`` when either the reference samples or
    the generated samples are unavailable (the helpers returned ``None``).
    """
    perturbation_samples, perturbation_metadata = process_perturbation_samples(
        datamodule,
        perturbation,
        num_samples=100,
        device=device,
        accelerator=accelerator,
    )
    if perturbation_samples is None:
        return None
    generated_samples, _generation_metadata = generate_perturbation_matched_samples(
        model,
        perturbation,
        perturbation_metadata,
        vae,
        latent_size,
        args.resolution,
        latents_bias,
        latents_scale,
        args.path_type,
        device,
    )
    if generated_samples is None:
        return None
    fid, kid, _kid_sd = calculate_metrics_from_scratch(
        perturbation_samples, generated_samples, feature_extractor="inception_v3"
    )
    fod, kod, _kod_sd = calculate_metrics_from_scratch(
        perturbation_samples, generated_samples, feature_extractor="openphenom"
    )
    return [fid, kid, fod, kod]


def _to_rgb_batch(images):
    """Apply ``to_rgb`` to each image (moved to CPU) and stack the results."""
    return torch.stack([to_rgb(img.cpu()[None]).squeeze(0) for img in images])


def train_loop(
    args,
    accelerator,
    model,
    ema,
    vae,
    encoders,
    optimizer,
    loss_fn,
    latents_scale,
    latents_bias,
    encoder_types,
    architectures,
    datamodule,
    train_dataloader,
    checkpoint_dir,
    logger,
    device,
    min_recorded_avg_fid,
):
    """Train ``model`` until ``args.max_train_steps`` optimizer steps are done.

    Per batch, images are encoded channel-wise into VAE latents and passed,
    together with auxiliary-encoder features, to ``loss_fn``. The objective
    ``loss + args.proj_coeff * proj_loss`` is optimized under
    ``accelerator.accumulate``, and ``ema`` is updated on every
    gradient-synchronizing step.

    On synchronizing steps only:
      * every ``args.checkpointing_steps`` steps, the main process saves a checkpoint;
      * at step 1 and every ``args.sampling_steps`` steps, each process
        generates samples for its assigned perturbation and computes
        FID/KID/FOD/KOD. The main process then logs image grids and metrics,
        and saves an extra checkpoint whenever the cross-process average FID
        drops below ``min_recorded_avg_fid``.

    Returns:
        The lowest average FID recorded. On the main process this includes
        improvements made during this call. Other processes return the value
        passed in.
    """
    # Perturbation ids used for visualization/metrics, assigned to processes
    # round-robin by process index.
    selected_perturbations = [0, 1, 2, 3, 4, 5, 0, 1]  # Example perturbations
    process_index = accelerator.process_index
    selected_perturbation = selected_perturbations[
        process_index % len(selected_perturbations)
    ]
    # Read in_channels from the underlying model when it is wrapped (`.module`).
    in_channels = getattr(model, "module", model).in_channels
    latent_size = args.resolution // 8
    fixed_cell_type = 1
    all_cell_types = [0, 1, 2, 3]

    # Reference ("ground truth") image; used only for visualization.
    gt_raw_images, gt_found = _find_reference_sample(
        train_dataloader, selected_perturbation, fixed_cell_type
    )
    if not gt_found:
        logger.warning(
            f"No sample with perturbation {selected_perturbation} and cell type "
            f"{fixed_cell_type} found; using the first training sample as reference."
        )
    if gt_raw_images.shape[-1] != args.resolution:
        raise ValueError(
            f"Reference image width {gt_raw_images.shape[-1]} does not match "
            f"args.resolution={args.resolution}."
        )
    gt_xs = _encode_to_latents(
        gt_raw_images, vae, latents_scale, latents_bias, in_channels, latent_size, device
    )

    # NOTE(review): the noise uses args.in_channels while the latents use the
    # model's in_channels; the two are expected to agree. Confirm.
    fixed_noise = torch.randn(
        (1, args.in_channels, latent_size, latent_size), device=device
    )
    fixed_class_ids = torch.tensor([selected_perturbation], device=device)
    fixed_cell_type_ids = torch.tensor([fixed_cell_type], device=device)

    if accelerator.is_main_process:
        tracker_config = OmegaConf.to_container(args, resolve=True)
        accelerator.init_trackers(
            project_name=args.task_name,
            config=tracker_config,
            init_kwargs={"wandb": {"name": f"{args.exp_name}"}},
        )
    progress_bar = tqdm(
        range(0, args.max_train_steps),
        initial=0,
        desc="Steps",
        disable=not accelerator.is_local_main_process,
    )

    global_step = 0
    # clip_grad_norm_ only runs on synchronizing steps. Start from zero so the
    # per-step logging below never references an undefined value.
    grad_norm = torch.zeros((), device=device)
    while True:
        model.train()
        for raw_image, y, ct in train_dataloader:
            raw_image = raw_image.to(device)
            y = y.to(device)
            ct = ct.to(device)
            if args.legacy:
                # With probability args.cfg_prob, replace a label by
                # args.num_classes (presumably the "null" class used for
                # classifier-free guidance).
                drop_ids = torch.rand(y.shape[0], device=y.device) < args.cfg_prob
                labels = torch.where(drop_ids, args.num_classes, y)
            else:
                labels = y
            x = _encode_to_latents(
                raw_image, vae, latents_scale, latents_bias, in_channels, latent_size, device
            )
            zs = _encoder_features(
                raw_image, encoders, encoder_types, architectures, accelerator
            )

            with accelerator.accumulate(model):
                model_kwargs = dict(y=labels, ct=ct)
                loss, proj_loss = loss_fn(model, x, model_kwargs, zs=zs)
                loss_mean = loss.mean()
                proj_loss_mean = proj_loss.mean()
                loss = loss_mean + proj_loss_mean * args.proj_coeff
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    grad_norm = accelerator.clip_grad_norm_(
                        model.parameters(), args.max_grad_norm
                    )
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)
                if accelerator.sync_gradients:
                    update_ema(ema, model, accelerator)

            # global_step only advances on synchronizing steps. Checkpointing
            # and sampling are gated on the same flag so they fire once per
            # optimizer step under gradient accumulation.
            synced = accelerator.sync_gradients
            if synced:
                progress_bar.update(1)
                global_step += 1

            if synced and global_step % args.checkpointing_steps == 0:
                if accelerator.is_main_process:
                    save_checkpoint(
                        model,
                        ema,
                        optimizer,
                        args,
                        global_step,
                        checkpoint_dir,
                        accelerator,
                    )
                    logger.info(
                        f"Saved checkpoint to {checkpoint_dir}/{global_step:07d}.pt"
                    )

            if synced and (global_step == 1 or global_step % args.sampling_steps == 0):
                model.eval()
                with torch.no_grad():
                    samples_fixed = generate_and_process_samples(
                        model,
                        fixed_noise,
                        fixed_class_ids,
                        fixed_cell_type_ids,
                        vae,
                        latent_size,
                        args.resolution,
                        latents_bias,
                        latents_scale,
                        args.path_type,
                        C=6,
                        device=device,
                        heun=False,
                    )
                    changing_noise = torch.randn(
                        (1, args.in_channels, latent_size, latent_size), device=device
                    )
                    samples_changing = generate_and_process_samples_multi_celltype(
                        model,
                        changing_noise,
                        fixed_class_ids,
                        all_cell_types,
                        vae,
                        latent_size,
                        args.resolution,
                        latents_bias,
                        latents_scale,
                        args.path_type,
                        C=6,
                        device=device,
                        heun=False,
                    )

                    local_metrics = _local_perturbation_metrics(
                        model,
                        selected_perturbation,
                        datamodule,
                        vae,
                        latent_size,
                        args,
                        latents_bias,
                        latents_scale,
                        device,
                        accelerator,
                    )
                    # Every process must join the gathers below, even without
                    # metrics, or the collective hangs. Missing metrics are
                    # sent as NaN rows and filtered out on the main process.
                    if local_metrics is None:
                        metrics_tensor = torch.full(
                            (1, 4),
                            float("nan"),
                            dtype=torch.float32,
                            device=accelerator.device,
                        )
                    else:
                        metrics_tensor = torch.tensor(
                            [local_metrics],
                            dtype=torch.float32,
                            device=accelerator.device,
                        )
                    perturbation_id_tensor = torch.tensor(
                        [selected_perturbation], device=accelerator.device
                    )
                    gathered_metrics = accelerator.gather(metrics_tensor)
                    gathered_perturbations = accelerator.gather(perturbation_id_tensor)

                    if accelerator.is_main_process:
                        valid = [
                            (row, int(pert.item()))
                            for row, pert in zip(gathered_metrics, gathered_perturbations)
                            if not torch.isnan(row).any()
                        ]
                        if valid:
                            fids = [row[0].item() for row, _ in valid]
                            average_fid = sum(fids) / len(fids)
                            accelerator.log(
                                {"metrics/average_fid_across_processes": average_fid},
                                step=global_step,
                            )
                            if average_fid < min_recorded_avg_fid:
                                min_recorded_avg_fid = average_fid
                                logger.info(
                                    f"New minimum average FID: {min_recorded_avg_fid:.8f} at step {global_step}. Saving checkpoint."
                                )
                                checkpoint_name = (
                                    f"min_AVG_FID_{min_recorded_avg_fid:.8f}.pt"
                                )
                                new_best_fid_checkpoint_path = os.path.join(
                                    checkpoint_dir, checkpoint_name
                                )
                                # NOTE(review): a full file path is passed where the
                                # periodic save passes checkpoint_dir. Confirm that
                                # save_checkpoint handles this when given
                                # min_recorded_avg_fid.
                                save_checkpoint(
                                    model,
                                    ema,
                                    optimizer,
                                    args,
                                    global_step,
                                    new_best_fid_checkpoint_path,
                                    accelerator,
                                    min_recorded_avg_fid,
                                )
                                logger.info(
                                    f"Saved FID-based checkpoint to {new_best_fid_checkpoint_path}"
                                )
                            for row, process_perturbation in valid:
                                accelerator.log(
                                    {
                                        f"metrics/perturbation_{process_perturbation}/fid": float(
                                            row[0].item()
                                        ),
                                        f"metrics/perturbation_{process_perturbation}/kid": float(
                                            row[1].item()
                                        ),
                                        f"metrics/perturbation_{process_perturbation}/fod": float(
                                            row[2].item()
                                        ),
                                        f"metrics/perturbation_{process_perturbation}/kod": float(
                                            row[3].item()
                                        ),
                                    },
                                    step=global_step,
                                )

                    samples_fixed_gathered = accelerator.gather(
                        samples_fixed.to(torch.float32)
                    )
                    samples_changing_gathered = accelerator.gather(
                        samples_changing.to(torch.float32)
                    ).squeeze()
                    rgb_samples_fixed = _to_rgb_batch(samples_fixed_gathered)
                    rgb_samples_changing = _to_rgb_batch(samples_changing_gathered)
                    num_cell_types = len(all_cell_types)
                    num_processes = accelerator.num_processes
                    # Gathered order is (process, cell_type). Reorder to
                    # (cell_type, process) before flattening into the grid.
                    rgb_samples_changing = rgb_samples_changing.reshape(
                        num_processes, num_cell_types, *rgb_samples_changing.shape[1:]
                    )
                    rgb_samples_changing = rgb_samples_changing.transpose(1, 0)
                    rgb_samples_changing = rgb_samples_changing.reshape(
                        -1, *rgb_samples_changing.shape[2:]
                    )
                    fixed_noise_caption = f"Fixed noise samples - Process {process_index}, Pert {selected_perturbation}"
                    multi_cell_caption = (
                        f"Multiple cell types for Pert {selected_perturbation}"
                    )
                    gt_samples = process_latents_through_vae(
                        gt_xs,
                        vae,
                        latent_size,
                        args.resolution,
                        latents_bias,
                        latents_scale,
                        C=6,
                    )
                    gt_samples = accelerator.gather(gt_samples.to(torch.float32))
                    if accelerator.is_main_process:
                        rgb_gt_samples = _to_rgb_batch(gt_samples)
                        accelerator.log(
                            {
                                "gt_samples": grid_image(rgb_gt_samples, nrow=8),
                                "samples_fixed_noise": grid_image(
                                    rgb_samples_fixed,
                                    nrow=8,
                                    caption=fixed_noise_caption,
                                ),
                                "samples_changing_noise_multi_cell": grid_image(
                                    rgb_samples_changing,
                                    nrow=8,
                                    caption=multi_cell_caption,
                                ),
                            },
                            step=global_step,
                        )
                    logger.info("Generating all sample types done.")
                    model.train()

            logs = {
                "loss": accelerator.gather(loss_mean).mean().detach().item(),
                "proj_loss": accelerator.gather(proj_loss_mean).mean().detach().item(),
                "grad_norm": accelerator.gather(grad_norm).mean().detach().item(),
            }
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            if global_step >= args.max_train_steps:
                break
        if global_step >= args.max_train_steps:
            break

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        logger.info("Done!")
    accelerator.end_training()
    return min_recorded_avg_fid
