import os
import torch
import wandb
from accelerate import Accelerator
from labml import monit
from src.evaluation.utils import IMAGENET2012_CLASSES
from utils.dataset_mapper import DatasetMapper
from src.generate_synthetic_data.data import build_imagenet_filtered_dataset_and_class_indices


def batchify(data, batch_size):
    """Lazily yield consecutive slices of ``data`` of length ``batch_size``.

    The last slice is shorter when ``len(data)`` is not a multiple of
    ``batch_size``. An empty ``data`` yields nothing. ``data`` only needs to
    support ``len()`` and slicing.
    """
    total = len(data)
    yield from (
        data[offset:offset + batch_size]
        for offset in range(0, total, batch_size)
    )


def validate_and_log(sampler, accelerator: Accelerator, mapper: DatasetMapper, data: torch.Tensor, labels: torch.Tensor, epoch: int, global_step: int, args, final: bool=False):
    """Generate ``args.ipc`` images for a single class, save them and log them to wandb.

    Args:
        sampler: Generation pipeline. It is called with ``class_labels=...``
            when ``args.model_type`` is ``'dit'`` and with ``prompt=...`` when
            it is ``'sd'``; the returned object must expose ``.images``, a
            sequence of objects with a ``save(path)`` method. For ``'dit'`` it
            must also provide ``get_label_ids``.
        accelerator: Only ``accelerator.device`` is read, to place the RNG.
        mapper: Resolves ``labels[0]`` to a class name via ``get_class_name``.
        data: Forwarded unchanged to the sampler as ``reference_images``.
        labels: Labels belonging to ``data``; all entries must be equal.
        epoch: Shown as ``epoch + 1`` in file names, captions and log keys.
        global_step: Recorded in image captions and in the wandb log entry.
        args: Namespace providing ``img_output``, ``seed``, ``ipc``,
            ``model_type``, ``guidance_scale``, ``guidance_step_size``,
            ``num_inference_steps``, ``feature_layer_idx`` and ``time_travel``.
        final: When true, images go to ``<img_output>/final/<id>``; otherwise
            to ``<img_output>/epoch<epoch + 1>/<id>``.

    Raises:
        ValueError: If ``labels`` is empty or not uniform, or if
            ``args.model_type`` is neither ``'dit'`` nor ``'sd'``.
    """
    # Validate inputs before touching the filesystem or the sampler.
    model_type = args.model_type.lower()
    if model_type not in ('dit', 'sd'):
        raise ValueError(f"Unsupported model_type '{args.model_type}'; expected 'dit' or 'sd'.")
    if labels.numel() == 0:
        raise ValueError('Batch labels are empty; cannot determine the class.')
    if not torch.all(labels == labels[0]):
        raise ValueError('Batch labels are not the same, but they should be.')
    class_name = mapper.get_class_name(labels[0].item(), simplified=False)

    save_dir = args.img_output
    os.makedirs(save_dir, exist_ok=True)

    # Compare against None so that an explicit seed of 0 is still honoured.
    generator = None
    if args.seed is not None:
        generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)

    # At least 1 so that batchify never receives a zero step when ipc == 0.
    batch_size = max(1, min(10, args.ipc))

    # Reverse lookup: class name -> ImageNet id; fall back to the name itself.
    imagenet_id = next((id_ for id_, name in IMAGENET2012_CLASSES.items() if name == class_name), None)
    if imagenet_id is None:
        print(f"Warning: No ImageNet ID found for class '{class_name}', using original name")
        imagenet_id = class_name
    print(f'Start generating images - class: {class_name} (ID: {imagenet_id}), target count: {args.ipc}')

    # Every prompt is the same class name, so the DiT label id is resolved once.
    dit_class_id = None
    if model_type == 'dit':
        primary_name = class_name.split(',')[0].strip()
        try:
            dit_class_id = sampler.get_label_ids([primary_name])[0]
        except Exception as exc:  # best effort: keep generating with a default class
            print(f"Warning: '{class_name}' not found in ImageNet labels ({exc}). Using a default class.")
            dit_class_id = 0

    # The output directory does not depend on the batch; it previously was only
    # defined for final=True, which made non-final calls fail.
    run_tag = 'final' if final else f'epoch{epoch + 1}'
    class_dir = os.path.join(save_dir, run_tag, imagenet_id)
    os.makedirs(class_dir, exist_ok=True)

    prompts = [class_name] * args.ipc
    generated_images = []
    total_generated = 0
    for i, batch_prompts in enumerate(batchify(prompts, batch_size)):
        print(f'Processing batch {i + 1}, batch size: {len(batch_prompts)}')
        if model_type == 'dit':
            class_ids = [dit_class_id] * len(batch_prompts)
            batch_images = sampler(class_labels=class_ids, reference_images=data, num_inference_steps=args.num_inference_steps, guidance_scale=args.guidance_scale, guidance_step_size=args.guidance_step_size, feature_layer_idx=args.feature_layer_idx, time_travel=args.time_travel, generator=generator).images
        else:
            batch_images = sampler(prompt=batch_prompts, reference_images=data, num_inference_steps=args.num_inference_steps, negative_prompts='worst quality, normal quality, low quality, low res, blurry, text, watermark, logo, banner, extra digits, cropped, jpeg artifacts, signature, username, error, sketch, duplicate, ugly, monochrome, horror, geometry, mutation, disgusting', guidance_scale=args.guidance_scale, guidance_step_size=args.guidance_step_size, feature_layer_idx=args.feature_layer_idx, time_travel=args.time_travel, generator=generator).images
        for j, image in enumerate(batch_images):
            image_idx = i * batch_size + j
            image_path = os.path.join(class_dir, f'{imagenet_id}_{run_tag}_img{image_idx + 1}.png')
            image.save(image_path)
            total_generated += 1
            wandb_image = wandb.Image(image, caption=f'Epoch {epoch + 1} Step {global_step}: {class_name} ({imagenet_id}) Image {image_idx + 1}')
            generated_images.append(wandb_image)
    print(f'Finished generating images - class: {class_name}, actual count: {total_generated}')
    wandb.log({f'Epoch_{epoch + 1}/{imagenet_id}': generated_images, 'epoch': epoch + 1, 'global_step': global_step})
    print(f'Generated images for {imagenet_id}: {class_name}')
    torch.cuda.empty_cache()


def _encode_class_latents(sampler, dataset, idx_list, device, chunk=32):
    """Encode ``dataset[i][0]`` for every ``i`` in ``idx_list`` with ``sampler.vae``.

    Images are encoded ``chunk`` at a time on ``device``; each result is scaled
    by ``sampler.vae.config.scaling_factor`` and moved to the CPU. Returns the
    chunks concatenated along dim 0.
    """
    latents_chunks = []
    for start in range(0, len(idx_list), chunk):
        batch_idxs = idx_list[start:start + chunk]
        batch = torch.stack([dataset[i][0] for i in batch_idxs]).to(device)
        with torch.no_grad():
            enc_out = sampler.vae.encode(batch)
            if hasattr(enc_out, 'latent_dist'):
                lat = enc_out.latent_dist.sample()
            elif hasattr(enc_out, 'latents'):
                lat = enc_out.latents
            else:
                raise RuntimeError('Unexpected VAE.encode output')
            lat = lat * sampler.vae.config.scaling_factor
        latents_chunks.append(lat.detach().cpu())
        # Drop the device batch before the next chunk is allocated.
        del batch
        torch.cuda.empty_cache()
    return torch.cat(latents_chunks, dim=0)


def distill_dataset(sampler, train_dataloader, accelerator, args=None):
    """Run the final image generation pass for every class of the dataset.

    For ``args.dataset == 'imagenet'`` (case-insensitive) the per-class images
    returned by ``build_imagenet_filtered_dataset_and_class_indices`` are
    encoded to latents and passed to ``validate_and_log``. For any other
    dataset a ``'cifar10'`` mapper is used and every ``(images, labels)``
    batch of ``train_dataloader`` is passed to ``validate_and_log``.
    Generation runs on the main process only; all processes synchronise via
    ``accelerator.wait_for_everyone()``.

    Args:
        sampler: Generation pipeline; ``sampler.vae`` is used for encoding in
            the ImageNet branch.
        train_dataloader: Iterable of ``(images, labels)``; only read in the
            non-ImageNet branch.
        accelerator: Provides ``device``, ``is_main_process``,
            ``wait_for_everyone`` and ``end_training``.
        args: Run configuration. Required despite the ``None`` default.

    Raises:
        ValueError: If ``args`` is ``None``.
        RuntimeError: If ``sampler.vae.encode`` returns an object with neither
            ``latent_dist`` nor ``latents``.
    """
    if args is None:
        # Fail with a clear message instead of an AttributeError on None.
        raise ValueError('distill_dataset requires `args`; got None.')

    is_imagenet = args.dataset.lower() == 'imagenet'
    if is_imagenet:
        mapper = DatasetMapper('imagenet', args.subset)
    else:
        mapper = DatasetMapper('cifar10')
    print(f"Using {args.dataset} {('subset ' + args.subset if args.subset else '')}")

    is_main = accelerator.is_main_process
    # Counts validate_and_log invocations so logged steps are distinguishable
    # (previously every call reported step 0).
    global_step = 0
    accelerator.wait_for_everyone()
    if is_imagenet:
        filtered_dataset, class_indices = build_imagenet_filtered_dataset_and_class_indices(args)
        if is_main:
            for class_label, idx_list in class_indices.items():
                if len(idx_list) == 0:
                    # torch.cat on an empty list would raise; nothing to encode.
                    print(f'Warning: no samples for class {class_label}, skipping.')
                    continue
                with monit.section(f'Encode latents for class {class_label}', is_silent=not is_main):
                    class_latents = _encode_class_latents(sampler, filtered_dataset, idx_list, accelerator.device)
                mapped_label = mapper.label_mapping.get(class_label, class_label)
                labels = torch.tensor([mapped_label])
                global_step += 1
                with monit.section('validation', is_silent=not is_main):
                    validate_and_log(sampler=sampler, accelerator=accelerator, mapper=mapper, data=class_latents, labels=labels, epoch=1, global_step=global_step, args=args, final=True)
    else:
        for step, (images, labels) in enumerate(train_dataloader, 1):
            accelerator.wait_for_everyone()
            if not is_main:
                continue
            global_step = step
            labels = mapper.convert_labels(labels)
            with monit.section('validation', is_silent=not is_main):
                validate_and_log(sampler=sampler, accelerator=accelerator, mapper=mapper, data=images, labels=labels, epoch=1, global_step=global_step, args=args, final=True)
    accelerator.wait_for_everyone()
    if is_main:
        wandb.finish()
    accelerator.end_training()