import datetime
import os
import pathlib
from typing import Dict

import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
import yaml
from scipy import linalg
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models import Inception_V3_Weights, inception_v3
from torchvision.utils import make_grid
from tqdm import tqdm
from .inject_transform import InjectSignalTransform

import wandb

from .diffusion_utils import GaussianDiffusionSampler, GaussianDiffusionTrainer
from .scheduler import GradualWarmupScheduler
from .unet_model import UNet
from .mlp_model import MLPDenoiser
from .discretized_mlp_model import DiscretizedMLPDenoiser


def sample_images(model, device, modelConfig: Dict):
    """Generate a small, reproducibly seeded batch of images with the diffusion model.

    Args:
        model: Denoising network handed to ``GaussianDiffusionSampler``.
        device: Device on which the sampler and the noise tensor are placed.
        modelConfig: Configuration dict. Reads ``beta_1``, ``beta_T``, ``T``,
            ``batch_size``, ``in_channels``, ``img_size`` and the optional
            ``eval_random_seed`` (defaults to 42).

    Returns:
        The sampler's output for a noise batch of ``min(8, batch_size)``
        images. No rescaling is applied here; callers that need display
        range must convert the result themselves.
    """
    seed = modelConfig.get("eval_random_seed", 42)

    sampler = GaussianDiffusionSampler(
        model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"]
    ).to(device)

    noise_shape = [
        min(8, modelConfig["batch_size"]),
        modelConfig["in_channels"],
        modelConfig["img_size"],
        modelConfig["img_size"],
    ]

    # Run the *whole* sampling pass inside the forked RNG scope, not just the
    # initial noise draw: any random numbers the sampler pulls from torch's
    # global generator are then covered by the fixed seed, and the caller's
    # RNG state (e.g. the training stream) is restored untouched on exit.
    # NOTE(review): presumably the sampler draws per-step noise from the
    # global torch RNG; if it is fully deterministic this is a no-op.
    with torch.random.fork_rng():
        torch.manual_seed(seed)
        noise = torch.randn(noise_shape, device=device)
        return sampler(noise)


def load_model(modelConfig: Dict, device: torch.device):
    """Build the denoiser selected by ``modelConfig["model_type"]`` and optionally load weights.

    Args:
        modelConfig: Configuration dict. ``model_type`` selects the
            architecture: ``"unet"`` builds a ``UNet``, ``"mlp"`` builds an
            ``MLPDenoiser``, and any other value falls through to
            ``DiscretizedMLPDenoiser.from_config``. If the optional
            ``training_load_weight`` entry is present and not ``None``, the
            state dict stored at ``save_weight_dir/training_load_weight`` is
            loaded into the model.
        device: Device the model is moved to and checkpoints are mapped onto.

    Returns:
        The constructed model, on ``device``.
    """
    model_type = modelConfig["model_type"]

    if model_type == "unet":
        net_model = UNet(
            T=modelConfig["T"],
            ch=modelConfig["channel"],
            ch_mult=modelConfig["channel_mult"],
            attn=modelConfig["attn"],
            num_res_blocks=modelConfig["num_res_blocks"],
            dropout=modelConfig["dropout"],
            in_channels=modelConfig["in_channels"],
            out_channels=modelConfig["out_channels"],
        ).to(device)
    elif model_type == "mlp":
        net_model = MLPDenoiser(
            num_blocks=modelConfig["num_blocks"],
            growth_rate=modelConfig["growth_rate"],
            dropout=modelConfig["dropout"],
            in_channels=modelConfig["in_channels"],
            out_channels=modelConfig["out_channels"],
            img_size=modelConfig["img_size"],
        ).to(device)
    else:  # discretized_mlp
        net_model = DiscretizedMLPDenoiser.from_config(modelConfig).to(device)

    # The checkpoint name is optional: a missing key is treated like an
    # explicit None instead of raising KeyError.
    weight_name = modelConfig.get("training_load_weight")
    if weight_name is not None:
        print(f"Loading model from {modelConfig['save_weight_dir']}/{weight_name}")
        weight_path = os.path.join(modelConfig["save_weight_dir"], weight_name)
        # weights_only=True restricts unpickling to tensors/plain containers.
        state_dict = torch.load(weight_path, map_location=device, weights_only=True)
        net_model.load_state_dict(state_dict)
        print("Loaded the model")

    return net_model


def find_closest_images(
    generated_images, dataset, device, batch_size=80, num_workers=4
):
    """Find, for each generated image, the nearest dataset image by mean squared error.

    The dataset is streamed in batches; for every batch all pairwise
    distances to the generated images are computed at once via broadcasting.

    Args:
        generated_images: Tensor indexed as ``[G, C, H, W]``.
        dataset: Dataset yielding image tensors or ``(image, label, ...)``
            tuples whose first element is the image.
        device: Device on which distances are computed.
        batch_size: Number of dataset images compared per step.
        num_workers: Worker processes used by the internal ``DataLoader``.

    Returns:
        A tuple ``(closest_images, min_distances)``: the best-matching dataset
        image for every generated image (after any channel adaptation), and
        the corresponding mean squared differences (not square-rooted).

    Raises:
        ValueError: If some generated image could not be matched to any
            dataset image (empty dataset/input or non-finite distances).
    """
    loader = DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )

    num_generated = len(generated_images)
    num_channels = generated_images.shape[1]
    min_distances = torch.full((num_generated,), float("inf"), device=device)
    closest_images = None  # allocated once the first batch fixes the image shape

    # [G, C, H, W] -> [G, 1, C, H, W] so it broadcasts against [1, B, C, H, W]
    gen_imgs_expanded = generated_images.unsqueeze(1)

    for batch in loader:
        if isinstance(batch, (tuple, list)):
            batch = batch[0]  # Handle (image, label) tuples
        batch = batch.to(device)  # [B, C, H, W]

        # Bring the dataset images to the generated channel count
        if batch.shape[1] != num_channels:
            if batch.shape[1] == 1:
                # Replicate the single channel to however many channels the
                # generated images have (not only the RGB case)
                batch = batch.repeat(1, num_channels, 1, 1)
            else:
                batch = batch.mean(dim=1, keepdim=True)

        # Pairwise mean squared differences: [G, B, C, H, W] -> [G, B]
        diffs = gen_imgs_expanded - batch.unsqueeze(0)
        batch_distances = torch.mean(diffs**2, dim=[2, 3, 4])
        batch_min_dists, batch_min_indices = torch.min(batch_distances, dim=1)

        if closest_images is None:
            closest_images = torch.zeros(
                (num_generated, *batch.shape[1:]),
                dtype=batch.dtype,
                device=batch.device,
            )

        # Overwrite only the entries for which this batch holds a closer image
        update_mask = batch_min_dists < min_distances
        min_distances[update_mask] = batch_min_dists[update_mask]
        closest_images[update_mask] = batch[batch_min_indices[update_mask]]

    # An entry still at +inf was never updated, i.e. no match was found for it
    if closest_images is None or bool(torch.isinf(min_distances).any()):
        raise ValueError(
            "Could not find a closest dataset image for every generated image "
            "(empty dataset/input or non-finite distances)"
        )

    return closest_images, min_distances


def _frechet_distance(mu_a, sigma_a, mu_b, sigma_b):
    """Return the Frechet distance between two Gaussians given as (mean, covariance)."""
    diff = mu_a - mu_b

    # Product might be almost singular
    covmean, _ = linalg.sqrtm(sigma_a.dot(sigma_b), disp=False)
    if not np.isfinite(covmean).all():
        print(
            "FID calculation produces singular product; adding epsilon to diagonal of cov matrices"
        )
        offset = np.eye(sigma_a.shape[0]) * 1e-6
        covmean = linalg.sqrtm((sigma_a + offset).dot(sigma_b + offset))

    # Numerical error might give slight imaginary component
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    return (
        diff.dot(diff) + np.trace(sigma_a) + np.trace(sigma_b) - 2 * np.trace(covmean)
    )


def calculate_fid(
    model, device, dataset, modelConfig: Dict, num_samples=100, batch_size=50
):
    """Estimate the FID between dataset images and images sampled from the model.

    Features come from a torchvision Inception v3 whose final ``fc`` layer is
    replaced by an identity. Up to ``num_samples`` real images are drawn from a
    shuffled pass over ``dataset`` and ``num_samples`` images are generated
    with a ``GaussianDiffusionSampler``.

    Args:
        model: Denoising network handed to the sampler.
        device: Device used for Inception and for sampling.
        dataset: Dataset yielding image tensors or ``(image, ...)`` tuples.
        modelConfig: Configuration dict. Reads ``beta_1``, ``beta_T``, ``T``,
            ``in_channels`` and ``img_size``.
        num_samples: Number of real and of generated images to compare.
        batch_size: Batch size for both feature extraction and sampling.

    Returns:
        The FID estimate as a NumPy scalar.

    Raises:
        ValueError: If fewer than two real or generated samples are available,
            since a covariance cannot be estimated from a single sample.
    """
    if num_samples < 2:
        raise ValueError("calculate_fid needs num_samples >= 2 to estimate covariances")

    # Inception v3 with the classification head removed, used as a fixed
    # feature extractor
    inception = inception_v3(
        weights=Inception_V3_Weights.IMAGENET1K_V1, transform_input=False
    )
    inception.fc = torch.nn.Identity()
    inception.eval()
    inception.to(device)

    # Dataloader for real images
    dataloader = DataLoader(
        dataset, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=False
    )

    # Sampler for generated images
    sampler = GaussianDiffusionSampler(
        model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"]
    ).to(device)

    def get_features(images):
        """Return Inception features for a batch of images as a NumPy array."""
        # Handle grayscale images
        if images.shape[1] == 1:
            images = images.repeat(1, 3, 1, 1)

        # Resize to the 299x299 resolution Inception v3 is run at
        images = F.interpolate(
            images, size=(299, 299), mode="bilinear", align_corners=False
        )
        # Map to [0, 1]; assumes the inputs are in [-1, 1]
        # NOTE(review): torchvision documents these weights with ImageNet
        # mean/std preprocessing, and transform_input=False disables the
        # built-in remapping, so the network sees plain [0, 1] values here.
        # Confirm this is the intended protocol before changing it; doing so
        # would shift every reported score.
        images = (images + 1) / 2
        return inception(images).detach().cpu().numpy()

    # Extract features from real images
    real_features = []
    total_real_images = 0

    print("Extracting features from real images...")
    with torch.no_grad():
        for batch in tqdm(dataloader):
            if isinstance(batch, (tuple, list)):
                batch = batch[0]  # Handle (image, label) tuples
            images = batch.to(device)
            real_features.append(get_features(images))

            total_real_images += images.shape[0]
            if total_real_images >= num_samples:
                break

    # The last batch may overshoot num_samples, so trim the surplus
    real_features = np.concatenate(real_features, axis=0)[:num_samples]

    # Generate images and extract features
    fake_features = []

    print("Generating images and extracting features...")
    with torch.no_grad():
        for start in tqdm(range(0, num_samples, batch_size)):
            current_batch_size = min(batch_size, num_samples - start)
            noise = torch.randn(
                [
                    current_batch_size,
                    modelConfig["in_channels"],
                    modelConfig["img_size"],
                    modelConfig["img_size"],
                ],
                device=device,
            )
            fake_images = sampler(noise)
            fake_features.append(get_features(fake_images))

    fake_features = np.concatenate(fake_features, axis=0)

    if len(real_features) < 2 or len(fake_features) < 2:
        raise ValueError(
            "calculate_fid needs at least two real and two generated samples"
        )

    # Mean and covariance of both feature distributions
    mu_real = np.mean(real_features, axis=0)
    sigma_real = np.cov(real_features, rowvar=False)

    mu_fake = np.mean(fake_features, axis=0)
    sigma_fake = np.cov(fake_features, rowvar=False)

    return _frechet_distance(mu_real, sigma_real, mu_fake, sigma_fake)


def _log_samples_and_fid(net_model, device, modelConfig, dataset, epoch, step):
    """Log a generated-vs-closest-real comparison grid and an FID estimate to wandb.

    The model is switched to eval mode for the evaluation and put back into
    train mode afterwards, even if the evaluation raises.
    """
    net_model.eval()
    try:
        with torch.no_grad():
            sampled = sample_images(net_model, device, modelConfig)

            # Find closest dataset images and get distances
            closest_dataset_images, distances = find_closest_images(
                sampled, dataset, device
            )
            avg_distance = distances.mean().item()

            # One row of generated images above one row of dataset matches;
            # "* 0.5 + 0.5" rescales the tensors for display
            comparison_grid = make_grid(
                torch.cat(
                    [sampled * 0.5 + 0.5, closest_dataset_images * 0.5 + 0.5],
                    dim=0,
                ),
                nrow=len(sampled),
                padding=2,
            )

            wandb.log(
                {
                    "comparison": wandb.Image(
                        comparison_grid.detach().cpu(),
                        caption=f"Top: Generated, Bottom: Closest real image (epoch {epoch + 1})",
                    ),
                    "avg_l2_distance": avg_distance,
                },
                step=step,
            )

            # FID score
            fid_score = calculate_fid(
                net_model,
                device,
                dataset,
                modelConfig,
                num_samples=100,
                batch_size=100,
            )
            print(f"FID Score at epoch {epoch + 1}: {fid_score:.4f}")
            wandb.log({"fid_score": fid_score}, step=step)
    finally:
        net_model.train()


def train(modelConfig: Dict, dataloader: DataLoader):
    """Train a diffusion denoiser and periodically checkpoint and evaluate it.

    Creates ``trained_models/<model_type>/<run name>/``, writes the config
    there as ``config.yaml``, trains for ``modelConfig["epoch"]`` epochs with
    AdamW and a warmup + cosine learning-rate schedule stepped once per epoch,
    and saves the model ``state_dict`` every ``checkpoint_freq`` epochs and
    after the final epoch. When ``use_wandb`` is set, batch and epoch losses
    are logged, and every ``sample_freq`` epochs (never after the first one)
    sample images and an FID estimate are logged as well.

    Args:
        modelConfig: Configuration dict. Besides the keys consumed by
            ``load_model`` it reads ``device``, ``checkpoint_freq``,
            ``model_type``, ``dataset_name``, ``lr``, ``epoch``,
            ``multiplier``, ``beta_1``, ``beta_T``, ``T``, ``grad_clip`` and
            the optional ``random_seed``, ``attn``, ``subset_size``,
            ``use_wandb`` and ``sample_freq``. The ``injected_signal_power``
            entry, if present, is removed from the dict (in place).
        dataloader: Yields image batches or ``(image, label, ...)`` tuples.
    """
    # Set random seed for reproducibility
    seed = modelConfig.get("random_seed", 42)
    torch.manual_seed(seed)

    device = torch.device(modelConfig["device"])
    checkpoint_freq = modelConfig["checkpoint_freq"]
    model_type = modelConfig["model_type"]
    num_epochs = modelConfig["epoch"]
    use_attn = (
        "attn"
        if model_type == "unet" and len(modelConfig.get("attn", [])) > 0
        else "noattn"
    )

    # Drop this key before the config is logged and saved below.
    # NOTE(review): the value itself is not used in this function; presumably
    # it is consumed by the data pipeline. Confirm the removal is intended.
    modelConfig.pop("injected_signal_power", 0.0)

    # Save directory
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    save_name = f"{model_type}_{modelConfig['dataset_name']}_{modelConfig.get('subset_size', 'full')}_{use_attn}_{timestamp}"
    save_dir = os.path.join("trained_models", model_type, save_name)
    os.makedirs(save_dir, exist_ok=True)

    # Initialize wandb if enabled in config
    use_wandb = modelConfig.get("use_wandb", False)
    if use_wandb:
        wandb.init(
            project=f"{model_type}-{modelConfig['dataset_name']}",
            name=save_name,
            config=modelConfig,
        )

    # Save config as YAML
    config_path = os.path.join(save_dir, "config.yaml")
    with open(config_path, "w") as f:
        yaml.dump(modelConfig, f)

    # Load the model and optimizers
    net_model = load_model(modelConfig, device)
    total_params = sum(p.numel() for p in net_model.parameters())
    print(f"\n\n\nTotal number of parameters: {total_params:,}\n\n\n")

    optimizer = torch.optim.AdamW(
        net_model.parameters(), lr=modelConfig["lr"], weight_decay=1e-4
    )
    cosineScheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer=optimizer, T_max=num_epochs, eta_min=0, last_epoch=-1
    )
    warmUpScheduler = GradualWarmupScheduler(
        optimizer=optimizer,
        multiplier=modelConfig["multiplier"],
        warm_epoch=num_epochs // 10,
        after_scheduler=cosineScheduler,
    )
    trainer = GaussianDiffusionTrainer(
        net_model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"]
    ).to(device)

    # The dataloader length is only needed for wandb step numbers, so it is
    # not queried when wandb is off (a loader may not define a length)
    steps_per_epoch = len(dataloader) if use_wandb else 0

    # Start training
    for e in range(num_epochs):
        epoch_loss = 0.0
        num_batches = 0

        with tqdm(dataloader, dynamic_ncols=True) as tqdmDataLoader:
            for batch in tqdmDataLoader:
                # Handle (image, label) tuples
                if isinstance(batch, (tuple, list)):
                    batch = batch[0]

                # Optimization step
                optimizer.zero_grad()
                x_0 = batch.to(device)
                loss = trainer(x_0).sum() / 1000.0
                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    net_model.parameters(), modelConfig["grad_clip"]
                )
                optimizer.step()

                # Read the scalar loss and the learning rate once per batch
                # and reuse them for every consumer below
                loss_value = loss.item()
                current_lr = optimizer.param_groups[0]["lr"]

                # Track metrics
                epoch_loss += loss_value
                num_batches += 1

                tqdmDataLoader.set_postfix(
                    ordered_dict={
                        "epoch": e,
                        "loss: ": loss_value,
                        "img shape: ": x_0.shape,
                        "LR": current_lr,
                    }
                )
                if use_wandb:
                    wandb.log(
                        {
                            "batch_loss": loss_value,
                            "learning_rate": current_lr,
                        },
                        step=e * steps_per_epoch + num_batches,
                    )

        warmUpScheduler.step()

        # Log epoch metrics to wandb
        if use_wandb and num_batches > 0:
            epoch_step = e * steps_per_epoch + num_batches
            wandb.log(
                {
                    "epoch": e,
                    "epoch_loss": epoch_loss / num_batches,
                },
                step=epoch_step,
            )

            # Every few epochs, generate and log sample images
            if e > 0 and (e + 1) % modelConfig.get("sample_freq", 10) == 0:
                _log_samples_and_fid(
                    net_model, device, modelConfig, dataloader.dataset, e, epoch_step
                )

        # Save checkpoint based on checkpoint frequency
        if (e + 1) % checkpoint_freq == 0 or e == num_epochs - 1:
            checkpoint_path = os.path.join(save_dir, f"ckpt_epoch_{e+1}.pt")
            torch.save(net_model.state_dict(), checkpoint_path)
            print(f"Checkpoint saved at {checkpoint_path}")

    # Close wandb run
    if use_wandb:
        wandb.finish()
