import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm
from pathlib import Path
import torchvision.transforms.functional as TF
from torchvision.utils import make_grid
import random
from torchmetrics.image.fid import FrechetInceptionDistance
from copy import deepcopy
import os
import json
import argparse

class _HFImageDataset(torch.utils.data.Dataset):
    """Map-style torch dataset over a Hugging Face image-classification split.

    Defined at module level (not inside ``get_cl_dataset``) so that instances
    can be pickled when ``DataLoader`` workers are started with the ``spawn``
    method; function-local classes cannot be pickled.
    """

    def __init__(self, hf_ds, transform, num_classes):
        self.ds = hf_ds
        self.transform = transform
        self.num_classes = int(num_classes)
        # Cache the label column once so samples can be grouped by class
        # without decoding and transforming every image.
        self.targets = [int(y) for y in hf_ds["label"]]

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        rec = self.ds[int(idx)]
        # rec["image"] is handed to the transform as-is; rec["label"] has
        # already been remapped by the caller when a class subset was chosen.
        return self.transform(rec["image"]), int(rec["label"])


def _dataset_labels(dataset):
    """Return the integer label of every sample in ``dataset``, in index order.

    Uses the dataset's ``targets`` attribute when it has one; otherwise falls
    back to iterating the dataset (slow: every sample is loaded).
    """
    targets = getattr(dataset, "targets", None)
    if targets is None:
        return [int(label) for _, label in tqdm(dataset)]
    if torch.is_tensor(targets):
        return targets.tolist()
    return [int(t) for t in targets]


def _build_group_loaders(dataset, group_size, batch_size, num_workers):
    """Build one shuffled ``DataLoader`` per class group of ``dataset``.

    A sample with label ``y`` belongs to group ``y // group_size``. Groups are
    derived from the labels actually present, so a trailing partial group is
    handled and no loader is created for a group without samples.
    """
    indices_per_group = {}
    for idx, label in enumerate(_dataset_labels(dataset)):
        indices_per_group.setdefault(label // group_size, []).append(idx)

    return {
        g: DataLoader(
            Subset(dataset, indices),
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=False,
        )
        for g, indices in sorted(indices_per_group.items())
    }


def get_cl_dataset(name='mnist', batch_size=64, normalize=True, greyscale=False, group_size=2, n_classes=10,
                   num_workers=4):
    """Load a dataset and split it into per-class-group loaders.

    Args:
        name: One of ``'mnist'``, ``'fmnist'``, ``'cifar10'``, ``'imagenet32'``
            (case-insensitive).
        batch_size: Batch size of the training loaders. Test loaders always
            use a batch size of 512.
        normalize: If true, images are normalised with mean 0.5 / std 0.5 per
            channel. For MNIST / FashionMNIST the 2-pixel padding is applied
            only together with normalisation.
        greyscale: CIFAR-10 only. Converts to one channel; this transform
            always normalises, regardless of ``normalize``.
        group_size: Number of consecutive class ids per group (task).
        n_classes: ImageNet only. Size of the random class subset to keep;
            ``None`` or ``>= 1000`` keeps all classes. The subset is drawn
            with the global ``random`` state, so seed it for reproducibility.
        num_workers: Worker processes used by every returned ``DataLoader``.

    Returns:
        ``(train_loaders, test_loaders, train_loader, test_loader)`` where the
        first two are dicts mapping group id -> ``DataLoader`` and the last
        two cover the full train / test splits.

    Raises:
        ValueError: If ``group_size`` is not positive or ``name`` is unknown.
        RuntimeError: If the ImageNet train split does not contain 1000 labels.
    """
    if group_size is None or group_size <= 0:
        raise ValueError(f"group_size must be a positive integer, got {group_size!r}")

    # Anonymous, configurable locations
    data_root = os.environ.get('DATA_ROOT', './data')
    hf_cache = os.environ.get('HF_DATASETS_CACHE', './cache')
    key = name.lower()

    if key in ('mnist', 'fmnist'):
        if normalize:
            transform = transforms.Compose([
                transforms.Pad(2),  # Padding to make it 32x32
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,))
            ])
        else:
            transform = transforms.Compose([transforms.ToTensor()])
        dataset_cls = datasets.MNIST if key == 'mnist' else datasets.FashionMNIST
        train_dataset = dataset_cls(root=data_root, train=True, download=True, transform=transform)
        test_dataset = dataset_cls(root=data_root, train=False, download=True, transform=transform)

    elif key == 'cifar10':
        if greyscale:
            # The greyscale pipeline takes precedence and always normalises.
            transform = transforms.Compose([
                transforms.Grayscale(num_output_channels=1),
                transforms.ToTensor(),
                transforms.Normalize((0.5,), (0.5,))
            ])
        elif normalize:
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])
        else:
            transform = transforms.Compose([transforms.ToTensor()])
        train_dataset = datasets.CIFAR10(root=data_root, train=True, download=True, transform=transform)
        test_dataset = datasets.CIFAR10(root=data_root, train=False, download=True, transform=transform)

    elif key == 'imagenet32':
        from datasets import load_dataset  # HF datasets (doesn't shadow torchvision.datasets)
        hf_repo = "benjamin-paine/imagenet-1k-32x32"   # alt: "sradc/imagenet_resized_64x64"
        if normalize:
            transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])
        else:
            transform = transforms.Compose([transforms.ToTensor()])
        train_hf = load_dataset(hf_repo, split="train", cache_dir=hf_cache)
        test_hf = load_dataset(hf_repo, split="validation", cache_dir=hf_cache)

        if n_classes is not None and n_classes < 1000:
            all_train_labels = sorted(train_hf.unique("label"))
            if len(all_train_labels) != 1000:
                raise RuntimeError("Expected 1000 ImageNet-1K classes in train split.")

            # Sample the subset ONCE and reuse it for both splits.
            chosen_labels = sorted(random.sample(all_train_labels, n_classes))
            chosen_set = set(chosen_labels)

            # Compact label map: old_label -> new_label in [0, n_classes-1]
            label_map = {old: new for new, old in enumerate(chosen_labels)}

            def keep_subset(example):
                return example["label"] in chosen_set

            def remap_label(example):
                example["label"] = label_map[int(example["label"])]
                return example

            train_hf = train_hf.filter(keep_subset).map(remap_label)
            test_hf = test_hf.filter(keep_subset).map(remap_label)
            effective_num_classes = n_classes
        else:
            # No subsetting; keep the original labels.
            effective_num_classes = 1000

        train_dataset = _HFImageDataset(train_hf, transform, effective_num_classes)
        test_dataset = _HFImageDataset(test_hf, transform, effective_num_classes)

    else:
        raise ValueError(
            f"Unknown dataset {name!r}; expected one of 'mnist', 'fmnist', 'cifar10', 'imagenet32'."
        )

    print("Building DataLoaders for each class in train dataset...")
    train_loaders = _build_group_loaders(train_dataset, group_size, batch_size, num_workers)

    print("Building DataLoaders for each class in test dataset...")
    test_loaders = _build_group_loaders(test_dataset, group_size, 512, num_workers)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=False)
    test_loader = DataLoader(test_dataset, batch_size=512, shuffle=False,
                             num_workers=num_workers, pin_memory=False)
    print(f"Dataset {name}: {len(train_dataset)} train samples, {len(test_dataset)} test samples.")
    return train_loaders, test_loaders, train_loader, test_loader

def load_config_from_json(config_path):
    """Load a JSON configuration file into an ``argparse.Namespace``.

    Args:
        config_path: Path of the JSON file. Its top-level value must be an
            object; each key becomes an attribute of the returned namespace.

    Returns:
        An ``argparse.Namespace`` built from the JSON object.

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
        ValueError: If the file does not contain valid JSON.
        TypeError: If the top-level JSON value is not an object.
    """
    # Open directly instead of checking existence first: avoids a race
    # between the check and the open.
    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            config_dict = json.load(f)
    except FileNotFoundError as e:
        raise FileNotFoundError(f"Configuration file not found: {config_path}") from e
    except json.JSONDecodeError as e:
        raise ValueError(f"Error decoding JSON from {config_path}: {e}") from e

    if not isinstance(config_dict, dict):
        # Namespace(**value) would fail with an opaque message for lists/scalars.
        raise TypeError(
            f"Top-level JSON value in {config_path} must be an object, "
            f"got {type(config_dict).__name__}"
        )

    return argparse.Namespace(**config_dict)

def train_one_task(model, train_loader, class_id, optimizer,
                   ewc=None,
                   gr=None,
                   kl=False,
                   num_epochs=10, save_path=None, device='cuda', wandb=None):
    """Train ``model`` on one task, optionally with EWC and generative replay.

    Args:
        model: Object providing ``diffusion_loss(images, labels)`` (returning
            ``(timesteps, noise, noisy_images, model_pred)``) and
            ``sample(...)``. ``model.ewc_lambda`` is read when ``ewc`` is
            given and ``model.gr_kl`` when the distillation term is active.
        train_loader: Iterable of ``(images, labels)`` batches.
        class_id: Task index; used for logging offsets and output paths.
        optimizer: Optimizer stepped once per batch.
        ewc: Optional object whose ``loss(model)`` is added, scaled by
            ``model.ewc_lambda``.
        gr: Optional replay object; ``gr.replay()`` must return
            ``(x_old, y_old)`` which are appended to every batch.
        kl: If true (and ``gr`` is given), adds an MSE term between the
            model's prediction and ``gr.teacher.unet`` on the replayed part.
        num_epochs: Number of passes over ``train_loader``.
        save_path: If set, a sample grid is written every 50 epochs under
            ``save_path/task_<class_id>/epoch_<epoch>/``.
        device: Device the batches are moved to.
        wandb: Optional wandb-like module used for logging.
    """
    unique_labels = set()
    for epoch in tqdm(range(num_epochs)):
        n_batches = 0
        for images, labels in tqdm(train_loader):
            unique_labels.update(labels.tolist())
            images = images.to(device)
            labels = labels.to(device)

            replay_size = 0
            if gr is not None:
                # combine with generated old data
                x_old, y_old = gr.replay()
                x_old = x_old.to(device)
                y_old = y_old.to(device)
                replay_size = x_old.size(0)
                images = torch.cat([images, x_old], dim=0)
                labels = torch.cat([labels, y_old], dim=0)

            optimizer.zero_grad()
            timesteps, noise, noisy_images, model_pred = model.diffusion_loss(images, labels)

            if gr is not None:
                # Replayed samples occupy the tail of the batch. Slice with an
                # explicit split index: negative slicing by `replay_size`
                # would select the wrong part when the replay batch is empty.
                split = model_pred.size(0) - replay_size
                # NOTE(review): the denoising loss covers only the fresh
                # samples, so replayed samples influence training solely
                # through the `kl` term below - confirm this is intended
                # when gr is used with kl=False.
                ddim_loss = F.mse_loss(model_pred[:split], noise[:split], reduction="mean")
            else:
                ddim_loss = F.mse_loss(model_pred, noise, reduction="mean")

            loss = ddim_loss
            if ewc is not None:
                loss_ewc = ewc.loss(model)
                loss = loss + model.ewc_lambda * loss_ewc

            if kl and gr is not None:
                if replay_size > 0:
                    with torch.no_grad():
                        eps_teacher = gr.teacher.unet(noisy_images[split:], timesteps[split:], y_old).sample
                    loss_kl = F.mse_loss(model_pred[split:], eps_teacher)
                else:
                    # Nothing was replayed: an MSE over an empty slice is NaN.
                    loss_kl = model_pred.new_zeros(())
                loss = loss + model.gr_kl * loss_kl

            loss.backward()
            optimizer.step()
            n_batches += 1

        # Logged values are those of the last batch of the epoch; skipped
        # when the loader produced no batch (the losses would be undefined).
        if wandb is not None and n_batches > 0:
            wandb.log({
                'loss/ddim': ddim_loss.item(),
                'loss/ewc': loss_ewc.item() if ewc is not None else 0.0,
                'loss/kl': loss_kl.item() if (kl and gr is not None) else 0.0,
                'loss/total': loss.item(),
                'epoch': epoch + num_epochs * class_id,
            })

        # visualize every 50 epochs (including epoch 0)
        if save_path is not None and epoch % 50 == 0 and unique_labels:
            out_dir = Path(save_path) / f"task_{class_id}" / f"epoch_{epoch:05d}"
            out_dir.mkdir(parents=True, exist_ok=True)
            cols = 8
            all_tensors = []
            # Sorted so that grid rows appear in a stable label order.
            for c in sorted(unique_labels):
                samples = model.sample(
                    batch_size=cols,
                    labels=[c] * cols,
                    num_inference_steps=50,
                    device=device,
                    guidance_scale=0.0,  # pure conditional
                )
                # presumably samples are tensors in [-1, 1]; rescaled here to [0, 1]
                all_tensors.extend((im + 1.0) * 0.5 for im in samples)

            grid = make_grid(torch.stack(all_tensors, dim=0), nrow=cols, padding=2)  # 8 per row
            grid_pil = TF.to_pil_image(grid.clamp(0, 1))
            if wandb is not None:
                wandb.log({f"samples/task{class_id}": wandb.Image(grid_pil, caption=f"Task {class_id} Epoch {epoch}")})
            out_file = out_dir / f"epoch_{epoch:05d}_grid.png"
            grid_pil.save(out_file)



class FIDEvaluator:
    def __init__(self, device=None):
        self.device = torch.device(device) if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # normalize=True => inputs should be float in [0,1]
        self.fid = FrechetInceptionDistance(feature=2048).to(self.device)
        self.to_tensor = transforms.ToTensor()

    @staticmethod
    def _to01_and_rgb(x: torch.Tensor) -> torch.Tensor:
        """x: (B,C,H,W) float in [0,1] or [-1,1]; returns float in [0,1] with 3 channels."""
        if x.dtype.is_floating_point:
            if x.min() < 0.0:  # convert [-1,1] -> [0,1]
                x = (x + 1.0) * 0.5
        else:
            x = x.float() / 255.0
        if x.size(1) == 1:
            x = x.repeat(1, 3, 1, 1)
        return x.clamp(0.0, 1.0).to(torch.uint8)

    @torch.no_grad()
    def fid_loader_vs_model(
        self,
        real_loader,
        model,
        num_inference_steps: int = 50,
        seed: int | None = 123,
        max_real: int | None = None,       # optional cap on #real images processed
    ) -> float:
        self.fid.reset()
        dev = self.device
        seed_base = seed if seed is not None else 0
        bidx = 0
        seen = 0
        for imgs_real, labels_real in real_loader:
            if max_real is not None and seen >= max_real:
                break
            if max_real is not None and seen + imgs_real.size(0) > max_real:
                keep = max_real - seen
                imgs_real = imgs_real[:keep]
                labels_real = labels_real[:keep]

            # real -> [0,1], 3ch
            imgs_real = imgs_real.to(dev)
            imgs_real = (imgs_real + 1.0) * 127.5
            imgs_real = imgs_real.to(torch.uint8)
            if imgs_real.size(1) == 1:
                imgs_real = imgs_real.repeat(1, 3, 1, 1)
            
            self.fid.update(imgs_real, real=True)
            seed_b = seed_base + bidx if seed is not None else None
            imgs_pil = model.sample(
                batch_size=labels_real.numel(),
                labels=labels_real.tolist(),             # <- match labels
                num_inference_steps=num_inference_steps,
                device=dev,
                seed=seed_b,
            )
            imgs_fake = ((imgs_pil + 1.0) * 127.5).to(torch.uint8)  # [0,255] uint8
            if imgs_fake.size(1) == 1:
                imgs_fake = imgs_fake.repeat(1, 3, 1, 1)
            self.fid.update(imgs_fake, real=False)

            bidx += 1
            seen += imgs_real.size(0)

        return float(self.fid.compute().cpu().item())
    

@torch.no_grad()
def freeze_model(model):
    """Return a frozen copy of ``model`` to serve as a teacher.

    The copy is independent of ``model`` (deep copy), has gradients disabled
    on all of its parameters, and is switched to eval mode. ``model`` itself
    is left untouched.
    """
    frozen = deepcopy(model)
    for weight in frozen.parameters():
        weight.requires_grad_(False)
    frozen.eval()
    return frozen