import random
import logging
import os
import copy
import time
from hydra.utils import to_absolute_path
from omegaconf import OmegaConf
from typing import Optional
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.utils import save_image
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore


def set_seeds(seed=42):
    """Seed the Python, NumPy and PyTorch random number generators.

    When CUDA is available the CUDA generators are seeded as well and cuDNN
    is switched to deterministic mode with benchmarking disabled.

    Args:
        seed: Seed value handed to every generator.
    """
    # CPU-side generators are always seeded, in a fixed order.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        # The cuDNN flags are only touched when a CUDA device is present.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    logging.info(f"Random seeds set to {seed}")


def _build_ckpt_dir(base_dir, dataset, algo_name, model_name, seed):
    base_dir = os.path.expanduser(base_dir)
    if not os.path.isabs(base_dir):
        base_dir = to_absolute_path(base_dir)

    timestamp = time.strftime("%Y%m%d-%H%M%S")
    run_dir = f"{model_name}-{dataset}"
    run_name = f"{algo_name}-seed{seed}-{timestamp}"
    checkpoint_dir = os.path.join(base_dir, run_dir, run_name)
    os.makedirs(checkpoint_dir, exist_ok=True)
    return checkpoint_dir


def save_model(model, model_save_dir, dataset, algo_name, model_name, seed, cfg, filename="generator.pt"):
    """Write the model's state dict and the run configuration to a new run folder.

    Args:
        model: Object exposing ``state_dict()``; only the state dict is saved.
        model_save_dir: Root directory handed to ``_build_ckpt_dir``.
        dataset: Dataset name used to build the run folder.
        algo_name: Algorithm name used to build the run folder.
        model_name: Model name used to build the run folder.
        seed: Seed used to build the run folder.
        cfg: Configuration saved next to the weights as ``config.yaml``.
        filename: File name of the weights file inside the run folder.
    """
    run_dir = _build_ckpt_dir(model_save_dir, dataset, algo_name, model_name, seed)
    weights_file = os.path.join(run_dir, filename)
    config_file = os.path.join(run_dir, "config.yaml")

    torch.save(model.state_dict(), weights_file)
    OmegaConf.save(config=cfg, f=config_file)
    logging.info(f"Model saved to {weights_file}")


def get_model(build_fn, checkpoint_path, device=None):
    """Build a model, load saved weights into it and return it in eval mode.

    Args:
        build_fn: Zero-argument callable returning a freshly built model.
        checkpoint_path: Path of a state dict readable by ``torch.load``.
        device: Optional device; when given the model is moved there before
            loading and the value is also used as ``map_location``.

    Returns:
        The model with the loaded state dict, switched to eval mode.
    """
    net = build_fn()
    net = net if device is None else net.to(device)

    weights = torch.load(checkpoint_path, map_location=device)
    net.load_state_dict(weights)

    net.eval()
    return net


def _to_unit_rgb(images):
    """Map images from (-1, 1) to (0, 1) and repeat a single channel to three."""
    unit = torch.clamp((images + 1) / 2.0, 0, 1)
    if unit.shape[1] == 1:
        unit = unit.repeat(1, 3, 1, 1)
    return unit


def evaluate_metrics(
        generator,
        dataloader: torch.utils.data.DataLoader,
        batch_size: int,
        z_dim: int,
        fid: Optional[FrechetInceptionDistance] = None,
        inception_score: Optional[InceptionScore] = None,
        num_samples: Optional[int] = None
    ):
    """
    Calculates GAN metrics (FID and/or IS).

    The generator is switched to eval mode and is left in eval mode; callers
    that continue training must call ``generator.train()`` themselves.

    Args:
        generator: The generator model. Called with a ``(batch, z_dim)`` normal
            noise tensor; if it returns a list or tuple, the first element is
            used as the image batch.
        dataloader: DataLoader yielding ``(images, _)`` pairs of real images
            (should be test set). Only iterated when ``fid`` is given.
        batch_size: Batch size for generating fake images.
        z_dim: Latent dimension for generator.
        fid: The FrechetInceptionDistance metric object, or None to skip FID.
        inception_score: The InceptionScore metric object, or None to skip IS.
        num_samples: Maximum number of real and of fake samples to use. With
            FID, at most this many real images are fed (the last batch is
            trimmed) and the same number of fakes is generated. If None, all
            real images from the dataloader are used for FID; without FID,
            ``len(dataloader.dataset)`` fake images are generated.

    Returns:
        dict: A dictionary containing 'fid' and/or 'is_mean', 'is_std'.
    """
    device = next(generator.parameters()).device
    generator.eval()

    use_fid = fid is not None
    use_is = inception_score is not None

    if use_fid:
        fid.reset()
    if use_is:
        inception_score.reset()

    # Feed real images to FID, never more than `num_samples` of them.
    real_count = 0
    if use_fid:
        if num_samples is None or num_samples > 0:
            with torch.no_grad():
                for real_batch, _ in dataloader:
                    if num_samples is not None:
                        # Trim the final batch so real and fake counts match.
                        real_batch = real_batch[:num_samples - real_count]

                    real_rgb = _to_unit_rgb(real_batch.to(device))
                    fid.update(real_rgb, real=True)
                    real_count += real_rgb.size(0)

                    if num_samples is not None and real_count >= num_samples:
                        break

        if num_samples is not None and real_count < num_samples:
            logging.debug(
                f"Provided {real_count} real samples, which is less than {num_samples} required for FID.\n"
                "FID score calculation will generate fake samples to match the number of real samples."
            )

    # Determine how many fake samples to generate.
    if use_fid:
        # real_count never exceeds num_samples, so it is the matching target.
        target_num_samples = real_count
    elif num_samples is not None:
        target_num_samples = num_samples
    else:
        # IS only: fall back to the size of the dataset behind the loader.
        target_num_samples = len(dataloader.dataset)

    fake_count = 0
    with torch.no_grad():
        while fake_count < target_num_samples:
            current_batch_size = min(batch_size, target_num_samples - fake_count)
            z = torch.randn(current_batch_size, z_dim, device=device)
            fake_imgs = generator(z)

            # Some generators return several outputs; the first is the image.
            if isinstance(fake_imgs, (list, tuple)):
                fake_imgs = fake_imgs[0]

            fake_rgb = _to_unit_rgb(fake_imgs)

            if use_fid:
                fid.update(fake_rgb, real=False)
            if use_is:
                inception_score.update(fake_rgb)

            fake_count += current_batch_size

    results = {}
    if use_fid:
        results['fid'] = fid.compute().item()
        fid.reset()
    if use_is:
        is_mean, is_std = inception_score.compute()
        results['is_mean'] = is_mean.item()
        results['is_std'] = is_std.item()
        inception_score.reset()

    return results


def save_image_grid(generator, noise_z, out_path):
    """
    Saves a grid of generated images.

    The generator is switched to eval mode (and left there), run on
    ``noise_z`` without gradients, and its output is mapped from (-1, 1) to
    (0, 1) before being written with ``save_image``.

    Args:
        generator: Callable model; if it returns a list or tuple, the first
            element is used as the image batch.
        noise_z: Latent batch fed to the generator; its first dimension
            determines the grid layout.
        out_path: Destination image path. Missing parent directories are
            created; a bare file name is written to the current directory.
    """
    save_dir = os.path.dirname(out_path)
    # dirname is '' for a bare file name, and os.makedirs('') would raise.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    generator.eval()
    with torch.no_grad():
        # Generate images and normalize to (0, 1)
        gen_imgs = generator(noise_z)

        if isinstance(gen_imgs, (list, tuple)):
            gen_imgs = gen_imgs[0]

        gen_imgs = torch.clamp((gen_imgs + 1) / 2.0, 0, 1)

        # Roughly square grid; never ask for fewer than one image per row.
        num_samples = noise_z.size(0)
        nrow = max(1, int(num_samples ** 0.5))
        # NOTE(review): normalize=True is passed although the images were
        # already mapped to (0, 1); presumably this rescales the grid a second
        # time. Confirm this is intended.
        save_image(gen_imgs, out_path, nrow=nrow, normalize=True)


class ModelEMA:
    """Keeps an exponential moving average (EMA) copy of a model's state.

    ``shadow`` is a deep copy of the model in eval mode whose parameters do
    not require gradients. Each ``update`` call blends the current model
    state into it: ``shadow = decay * shadow + (1 - decay) * model``.
    """

    def __init__(self, model, decay=0.999):
        """
        Args:
            model: Model to track; it is deep-copied, not modified.
            decay: Weight kept from the previous shadow value on each update.
        """
        self.decay = decay
        self.shadow = copy.deepcopy(model)
        self.shadow.eval()
        for param in self.shadow.parameters():
            param.requires_grad_(False)

        # Keep the shadow on the model's device; modules without parameters
        # have no device to follow, so they are left where the copy put them.
        first_param = next(model.parameters(), None)
        if first_param is not None:
            self.shadow.to(first_param.device)

    def update(self, model):
        """Blend the current state of ``model`` into the shadow copy.

        Every floating-point entry of the state dict (any float dtype) is
        averaged; all other entries, such as integer counters, are copied.

        Args:
            model: Model whose ``state_dict()`` keys are all present in the
                shadow's state dict.
        """
        with torch.no_grad():
            msd = model.state_dict()
            ssd = self.shadow.state_dict()

            for key, source in msd.items():
                target = ssd[key]
                # No-op when both tensors already live on the same device.
                source = source.detach().to(device=target.device)
                if source.is_floating_point():
                    # shadow = decay * shadow + (1 - decay) * new_param
                    target.mul_(self.decay).add_(source, alpha=1.0 - self.decay)
                else:
                    target.copy_(source)

# ---- DiffAugment ----

def rand_brightness(x):
    """Add one random offset from [-0.5, 0.5) to every value of each sample."""
    per_sample_shape = (x.size(0), 1, 1, 1)
    offset = torch.rand(per_sample_shape, dtype=x.dtype, device=x.device) - 0.5
    return x + offset

def rand_saturation(x):
    """Scale each sample's deviation from its channel mean by a factor in [0, 2)."""
    channel_mean = x.mean(dim=1, keepdim=True)
    per_sample_shape = (x.size(0), 1, 1, 1)
    factor = torch.rand(per_sample_shape, dtype=x.dtype, device=x.device) * 2
    return (x - channel_mean) * factor + channel_mean

def rand_contrast(x):
    """Scale each sample's deviation from its overall mean by a factor in [0.5, 1.5)."""
    sample_mean = x.mean(dim=[1, 2, 3], keepdim=True)
    per_sample_shape = (x.size(0), 1, 1, 1)
    factor = torch.rand(per_sample_shape, dtype=x.dtype, device=x.device) + 0.5
    return (x - sample_mean) * factor + sample_mean

def rand_translation(x, ratio=0.125):
    """Shift each sample by an independent random offset using an affine grid.

    Works for any floating dtype of ``x``: the random offsets and the affine
    matrix are created in ``x.dtype`` so that ``grid_sample`` receives an
    input and a grid of the same dtype.

    Args:
        x: 4-D image batch ``(B, C, H, W)``.
        ratio: Controls the maximum offset; ``ratio=0`` leaves ``x`` unchanged
            up to interpolation round-off.

    Returns:
        Tensor of the same shape as ``x``; borders are filled by reflection.
    """
    B, _, H, W = x.size()
    shift_x, shift_y = int(W * ratio + 0.5), int(H * ratio + 0.5)

    # Offsets are uniform in [-shift/size, shift/size).
    # NOTE(review): they are applied in affine_grid's normalized coordinates,
    # so the pixel shift is presumably about half of `ratio` times the image
    # size. Confirm this strength is intended.
    translation_x = torch.rand(B, dtype=x.dtype, device=x.device) * (2 * shift_x / W) - (shift_x / W)
    translation_y = torch.rand(B, dtype=x.dtype, device=x.device) * (2 * shift_y / H) - (shift_y / H)

    # Identity affine matrix plus a per-sample translation column.
    theta = torch.zeros(B, 2, 3, dtype=x.dtype, device=x.device)
    theta[:, 0, 0] = 1
    theta[:, 1, 1] = 1
    theta[:, 0, 2] = translation_x
    theta[:, 1, 2] = translation_y

    grid = F.affine_grid(theta, x.size(), align_corners=False)
    return F.grid_sample(x, grid, padding_mode='reflection', align_corners=False)

def rand_rotation(x):
    """Rotate each sample by an independent random angle in [-10, 10) degrees.

    Works for any floating dtype of ``x``: the angles and the affine matrix
    are created in ``x.dtype`` so that ``grid_sample`` receives an input and
    a grid of the same dtype.

    Args:
        x: 4-D image batch ``(B, C, H, W)``.

    Returns:
        Tensor of the same shape as ``x``; borders are filled by reflection.
    """
    batch = x.size(0)
    # Uniform in [-0.5, 0.5) scaled to 20 degrees, converted to radians.
    angle = (torch.rand(batch, dtype=x.dtype, device=x.device) - 0.5) * 20 * np.pi / 180
    cos, sin = torch.cos(angle), torch.sin(angle)

    # Per-sample 2x3 rotation matrix with no translation.
    theta = torch.zeros(batch, 2, 3, dtype=x.dtype, device=x.device)
    theta[:, 0, 0] = cos
    theta[:, 0, 1] = -sin
    theta[:, 1, 0] = sin
    theta[:, 1, 1] = cos

    grid = F.affine_grid(theta, x.size(), align_corners=False)
    return F.grid_sample(x, grid, padding_mode='reflection', align_corners=False)

def rand_cutout(x, ratio=0.2):
    """Zero out one randomly placed rectangle per sample, across all channels.

    Args:
        x: 4-D batch; the rectangle is placed over dimensions 2 and 3.
        ratio: Side length of the rectangle as a fraction of each dimension.

    Returns:
        ``x`` multiplied by a 0/1 mask of shape ``(B, 1, dim2, dim3)``.
    """
    batch, dim2, dim3 = x.size(0), x.size(2), x.size(3)
    dev = x.device
    cut2 = int(dim2 * ratio + 0.5)
    cut3 = int(dim3 * ratio + 0.5)

    # One random rectangle centre per sample along each spatial dimension.
    centre2 = torch.randint(0, dim2 + (1 - cut2 % 2), size=[batch, 1, 1], device=dev)
    centre3 = torch.randint(0, dim3 + (1 - cut3 % 2), size=[batch, 1, 1], device=dev)

    sample_idx, idx2, idx3 = torch.meshgrid(
        torch.arange(batch, dtype=torch.long, device=dev),
        torch.arange(cut2, dtype=torch.long, device=dev),
        torch.arange(cut3, dtype=torch.long, device=dev),
        indexing='ij'
    )
    # Centre the rectangle on the drawn position and clip it to the image.
    idx2 = torch.clamp(idx2 + centre2 - cut2 // 2, min=0, max=dim2 - 1)
    idx3 = torch.clamp(idx3 + centre3 - cut3 // 2, min=0, max=dim3 - 1)

    keep = torch.ones(batch, dim2, dim3, dtype=x.dtype, device=dev)
    keep[sample_idx, idx2, idx3] = 0
    return x * keep.unsqueeze(1)

# Maps each policy name to the augmentation functions it enables, in the
# order in which they are tried.
AUGMENT_FNS = dict(
    color=[rand_brightness, rand_saturation, rand_contrast],
    translation=[rand_translation],
    cutout=[rand_cutout],
    rotation=[rand_rotation],
)

# Comma-separated policy names used when no policy is passed explicitly.
DEFAULT_POLICY = "cutout,translation"


def DiffAugment(x, policy=DEFAULT_POLICY, channels_first=True, prob=0.5):
    """Apply the augmentations named in ``policy`` to a batch.

    Each augmentation function of every listed policy is applied with
    probability ``prob``; the draw is made once per function for the whole
    batch. An empty policy returns ``x`` untouched.

    Args:
        x: Image batch, channels-first unless ``channels_first`` is False.
        policy: Comma-separated keys of ``AUGMENT_FNS``.
        channels_first: If False, ``x`` is permuted to channels-first for the
            augmentations and permuted back afterwards.
        prob: Probability of applying each individual augmentation function.

    Returns:
        The augmented batch as a contiguous tensor, clamped to [-1, 1] when
        the policy string contains 'color'.
    """
    if not policy:
        return x

    channels_last = not channels_first
    if channels_last:
        x = x.permute(0, 3, 1, 2)

    for policy_name in policy.split(','):
        for augment in AUGMENT_FNS[policy_name]:
            apply_it = torch.rand(1).item() < prob
            if apply_it:
                x = augment(x)

    # Colour augmentations can push values outside the image range.
    if 'color' in policy:
        x = torch.clamp(x, -1, 1)

    if channels_last:
        x = x.permute(0, 2, 3, 1)
    return x.contiguous()
