import os
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as T
import torch.nn.functional as F
from fit_classifiers import build_model as build_clf_model


class WithIndex(torch.utils.data.Dataset):
    """Dataset wrapper that attaches the sample index to every item.

    The shape of the returned item depends on what the wrapped dataset yields:

    * ``dict``          -> a shallow copy with an extra ``"idx"`` entry
    * ``list``/``tuple`` -> a tuple with ``idx`` appended as the last element
    * anything else     -> ``(item, None, idx)``
    """

    def __init__(self, base_ds):
        self.base_ds = base_ds

    def __len__(self):
        return len(self.base_ds)

    def __getitem__(self, idx: int):
        sample = self.base_ds[idx]

        if isinstance(sample, dict):
            # Build a new dict so the one owned by the wrapped dataset is
            # never mutated.
            return {**sample, "idx": idx}

        if isinstance(sample, (list, tuple)):
            return tuple(sample) + (idx,)

        return (sample, None, idx)



def get_dataset(name, root="./dataset", train=False, resize=False):
    """Build an index-returning dataset together with its basic metadata.

    Args:
        name: Dataset identifier (case-insensitive): ``"cifar10"``,
            ``"cifar100"``, ``"mnist"`` or ``"tinyimagenet"``.
        root: Directory where the data lives (CIFAR/MNIST are downloaded
            there if missing).
        train: Select the training split instead of the evaluation split.
            For ``"tinyimagenet"`` this picks the ``train`` or ``val``
            sub-directory of ``<root>/tiny-imagenet-200``.
        resize: If true, images are resized and center-cropped to 224x224
            before tensor conversion and normalization.

    Returns:
        ``(dataset, num_classes, input_shape)`` where ``dataset`` is wrapped
        in :class:`WithIndex` and ``input_shape`` is ``(C, H, W)``.

    Raises:
        ValueError: If ``name`` is not one of the supported datasets.
    """
    # name -> (mean, std, native input shape, number of classes)
    specs = {
        "cifar10": ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010), (3, 32, 32), 10),
        "cifar100": ((0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762), (3, 32, 32), 100),
        "mnist": ((0.1307,), (0.3081,), (1, 28, 28), 10),
        "tinyimagenet": ((0.4802, 0.4481, 0.3975), (0.2302, 0.2265, 0.2262), (3, 64, 64), 200),
    }

    name = name.lower()
    if name not in specs:
        raise ValueError(f"Unknown dataset {name}")
    mean, std, input_shape, num_classes = specs[name]

    # The same pipeline is shared by every dataset; resizing only prepends
    # two steps and overrides the spatial size.
    steps = [T.ToTensor(), T.Normalize(mean, std)]
    if resize:
        steps = [T.Resize(224), T.CenterCrop(224)] + steps
        input_shape = (input_shape[0], 224, 224)
    tf = T.Compose(steps)

    if name == "cifar10":
        base = torchvision.datasets.CIFAR10(root=root, train=train, download=True, transform=tf)
    elif name == "cifar100":
        base = torchvision.datasets.CIFAR100(root=root, train=train, download=True, transform=tf)
    elif name == "mnist":
        base = torchvision.datasets.MNIST(root=root, train=train, download=True, transform=tf)
    else:
        # Honor `train` here as well (previously "val" was always loaded).
        # NOTE(review): assumes each split directory contains one
        # sub-directory per class, as ImageFolder requires - confirm the
        # on-disk layout of tiny-imagenet-200.
        split = "train" if train else "val"
        base = torchvision.datasets.ImageFolder(
            os.path.join(root, "tiny-imagenet-200", split), transform=tf
        )

    return WithIndex(base), num_classes, input_shape



def parse_batch_spec(spec):
    """Parse a comma-separated batch selection such as ``"0,3-5,9"``.

    Each entry is either a single integer or an inclusive range ``a-b``
    (the bounds may be given in either order). Empty entries are skipped.

    Args:
        spec: The specification; it is converted with ``str()`` first.
            ``None`` or a blank string means "no selection".

    Returns:
        ``None`` when ``spec`` is ``None``/blank, otherwise the set of
        selected integer indices (possibly empty, e.g. for ``","``).

    Raises:
        ValueError: If an entry is neither an integer nor a valid range.
    """
    if spec is None or str(spec).strip() == "":
        return None

    selected = set()
    for token in str(spec).split(","):
        token = token.strip()
        if not token:
            continue

        lo_text, dash, hi_text = token.partition("-")
        try:
            if dash:
                lo, hi = int(lo_text), int(hi_text)
                if lo > hi:
                    lo, hi = hi, lo
                selected.update(range(lo, hi + 1))
            else:
                selected.add(int(token))
        except ValueError:
            # Replace the low-level int()/unpacking error with one that
            # names the offending entry.
            raise ValueError(
                f"Invalid batch spec entry {token!r} in {spec!r}"
            ) from None

    # A set has no order, so there is nothing to gain from sorting first.
    return selected


def build_model(arch: str, num_classes: int, device):
    """Create a classifier and a feature extractor that shares its layers.

    Args:
        arch: Architecture name (case-insensitive): ``"resnet18"``,
            ``"resnet50"``, ``"wide_resnet50_2"`` or ``"vgg16"``.
        num_classes: Size of the classification head.
        device: Target device. It is forwarded to ``build_clf_model`` for the
            architectures built there; the torchvision ResNet-18 is moved to
            it explicitly so every branch treats ``device`` the same way.

    Returns:
        ``(model, feat_extractor)``. The extractor is an ``nn.Sequential``
        built from the model's own sub-modules followed by ``nn.Flatten``,
        so both objects share parameters.

    Raises:
        ValueError: If ``arch`` is not supported.
    """
    arch = arch.lower()

    if arch == "resnet18":
        model = torchvision.models.resnet18(weights=None, num_classes=num_classes)
        # Previously `device` was ignored for this branch only.
        if device is not None:
            model = model.to(device)
    elif arch in ("resnet50", "wide_resnet50_2", "vgg16"):
        # NOTE(review): build_clf_model is defined in fit_classifiers;
        # presumably it places the model on `device` - confirm.
        model = build_clf_model(arch=arch, num_classes=num_classes, device=device, pretrained=False)
    else:
        raise ValueError(f"Unsupported arch: {arch}")

    if arch == "vgg16":
        feat_extractor = nn.Sequential(model.features, model.avgpool, nn.Flatten())
    else:
        # ResNet-style models: every child except the last one.
        feat_extractor = nn.Sequential(*list(model.children())[:-1], nn.Flatten())

    return model, feat_extractor


@torch.no_grad()
def infer_feat_dim(fe: nn.Module, img_shape):
    """Return the size of the last output dimension of ``fe``.

    A single all-zero image of shape ``img_shape`` is pushed through ``fe``.
    The probe runs in eval mode so that it has no side effects on layers
    whose state changes in training mode (e.g. BatchNorm running statistics);
    the previous train/eval flag of every sub-module is restored afterwards.

    Args:
        fe: Feature extractor to probe.
        img_shape: ``(C, H, W)`` of one input image.

    Returns:
        ``fe(dummy).shape[-1]``.
    """
    C, H, W = img_shape

    # Modules without parameters (e.g. a bare nn.Flatten) have no device of
    # their own; fall back to buffers and finally to the CPU.
    anchor = next(fe.parameters(), None)
    if anchor is None:
        anchor = next(fe.buffers(), None)
    device = anchor.device if anchor is not None else torch.device("cpu")

    dummy = torch.zeros(1, C, H, W, device=device)

    saved_modes = [(module, module.training) for module in fe.modules()]
    fe.eval()
    try:
        return fe(dummy).shape[-1]
    finally:
        # Restore per module: sub-modules may have had mixed modes.
        for module, was_training in saved_modes:
            module.training = was_training


@torch.no_grad()
def eval_acc(model, dataset, device):
    """Compute and print the top-1 accuracy of ``model`` on ``dataset``.

    The model is switched to eval mode and left there. Each dataset item must
    unpack into three values ``(x, y, _)``; the third one is ignored.

    Args:
        model: Classifier; ``model(x)`` must return per-class scores with the
            class axis at dimension 1.
        dataset: Dataset yielding ``(input, label, extra)`` items.
        device: Device the batches are moved to.

    Returns:
        Accuracy in ``[0, 1]``; ``0.0`` if the dataset yields no samples.
    """
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=512, shuffle=False, num_workers=2, pin_memory=True
    )
    model.eval()

    correct = 0
    total = 0
    for x, y, _ in loader:
        x, y = x.to(device), y.to(device)
        pred = model(x).argmax(1)
        correct += (pred == y).sum().item()
        total += y.numel()

    # Guard against an empty dataset instead of raising ZeroDivisionError.
    acc = correct / max(1, total)
    print(f"[clf] accuracy={acc * 100:.2f}%")
    return acc



def g_ball(u, gamma, norm_type):
    """Map an unconstrained tensor to a perturbation of size ``gamma``.

    * ``"linf"``: ``gamma * tanh(u)``, so every element lies strictly inside
      ``(-gamma, gamma)``.
    * ``"l2"``: every sample (first dimension) is rescaled by
      ``gamma / max(||u_i||_2, 1e-6)``; the clamp avoids a division by zero,
      so an all-zero sample stays zero.

    Args:
        u: Tensor whose first dimension indexes samples.
        gamma: Radius / scale factor.
        norm_type: ``"linf"`` or ``"l2"``.

    Returns:
        A tensor with the same shape as ``u``.

    Raises:
        ValueError: If ``norm_type`` is not supported.
    """
    if norm_type == "linf":
        return gamma * u.tanh()

    if norm_type == "l2":
        # reshape (not view) so non-contiguous inputs, e.g. permuted
        # tensors, are accepted as well.
        flat = u.reshape(u.size(0), -1)
        norm = flat.norm(p=2, dim=1, keepdim=True).clamp_min(1e-6)
        return (gamma * flat / norm).reshape(u.shape)

    raise ValueError(f"not supported norm_type: {norm_type}")

def compute_pr_on_clean_correct(
    model, gmm, loader, out_shape, device,
    num_samples=100, batch_indices=None,
    temperature=1.0,
    use_soft_sampling=False,
    chunk_size=None
):
    """Average ``gmm.evaluate_pr`` over the samples ``model`` classifies correctly.

    For each selected batch, the model's clean predictions are compared with
    the labels; only the correctly classified samples are passed to
    ``gmm.evaluate_pr`` (called with ``reduction='none'``) and its returned
    values are summed. A one-line summary is printed at the end.

    Args:
        model: Classifier returning per-class scores (class axis 1).
        gmm: Object exposing ``evaluate_pr(x, y, model, ...)``.
        loader: Iterable of ``(x, y, _)`` batches.
        out_shape: Unused by this function.
        device: Device the batches are moved to.
        num_samples, temperature, use_soft_sampling, chunk_size: Forwarded
            unchanged to ``gmm.evaluate_pr``.
        batch_indices: Optional container of batch positions to evaluate;
            ``None`` evaluates every batch.

    Returns:
        ``(pr, total_used, clean_acc)``: mean value over the used samples,
        the number of used samples, and the clean accuracy over the batches
        that were visited. Both ratios are ``0.0`` when nothing was seen.
    """
    seen_count = 0
    correct_count = 0
    used_count = 0
    pr_accum = 0.0

    with torch.no_grad():
        for batch_no, (x, y, _) in enumerate(loader):
            wanted = batch_indices is None or batch_no in batch_indices
            if not wanted:
                continue

            x = x.to(device)
            y = y.to(device)

            hits = model(x).argmax(1) == y
            n_hits = hits.sum().item()

            correct_count += n_hits
            seen_count += x.size(0)

            # Nothing to evaluate when the whole batch is misclassified.
            if n_hits == 0:
                continue

            x_ok, y_ok = x[hits], y[hits]
            scores = gmm.evaluate_pr(
                x_ok, y_ok, model,
                num_samples=num_samples,
                use_soft_sampling=use_soft_sampling,
                temperature=temperature,
                reduction='none',
                chunk_size=chunk_size
            )

            pr_accum += scores.sum().item()
            used_count += x_ok.size(0)

    pr = pr_accum / max(1, used_count)
    clean_acc = correct_count / max(1, seen_count)

    if use_soft_sampling:
        method_name = "soft (Gumbel-Softmax)"
    else:
        method_name = "hard (categorical)"
    chunk_info = "" if chunk_size is None else f", chunk_size={chunk_size}"
    print(f"[PR@clean] used={used_count} / seen={seen_count} "
          f"(clean acc={clean_acc*100:.2f}%), num_samples={num_samples}{chunk_info} → PR={pr:.4f} [method: {method_name}]")

    return pr, used_count, clean_acc


def slug_gamma(g):
    """Format ``g`` with four decimals and use ``p`` as the decimal point.

    Example: ``0.5`` becomes ``"0p5000"``.
    """
    fixed = format(g, ".4f")
    return fixed.replace(".", "p")


def initialize_gmm_parameters(gmm, init_mode='spread'):
    """Re-initialize ``gmm.mu`` in place according to ``init_mode``.

    Nothing is modified unless ``gmm.mu`` exists and is an ``nn.Parameter``
    of shape ``(K, D)``. Supported modes:

    * ``'spread'``  - N(0, 0.5) noise; additionally, when ``K <= 8`` and
      ``D >= 3``, the first three coordinates of component ``k`` are set to
      -1/+1 following the 3-bit binary representation of ``k``.
    * ``'random'``  - N(0, 1) noise.
    * ``'grid'``    - when ``D >= 2``, the first two coordinates are placed
      on a ``ceil(sqrt(K))``-sided grid starting at -1; other coordinates
      are left untouched.
    * ``'uniform'`` - uniform noise in [-1, 1].

    Any other mode leaves ``mu`` unchanged. A log line is printed in all
    cases.
    """
    mu = getattr(gmm, 'mu', None)

    with torch.no_grad():
        if isinstance(mu, nn.Parameter):
            values = mu.data
            K, D = mu.shape

            if init_mode == 'spread':
                values.normal_(0, 0.5)
                if K <= 8 and D >= 3:
                    # Bits are read most-significant first, matching a
                    # zero-padded 3-digit binary string of k.
                    for k in range(K):
                        for d, shift in enumerate((2, 1, 0)):
                            values[k, d] = 1.0 if (k >> shift) & 1 else -1.0

            elif init_mode == 'random':
                values.normal_(0, 1.0)

            elif init_mode == 'grid':
                if D >= 2:
                    side = int(np.ceil(K ** (1/2)))
                    for k in range(K):
                        row, col = divmod(k, side)
                        values[k, 0] = (row / side) * 2 - 1
                        values[k, 1] = (col / side) * 2 - 1

            elif init_mode == 'uniform':
                values.uniform_(-1, 1)

    print(f"[init] GMM means initialized with mode='{init_mode}'")

class TemperatureScheduler:
    """Linearly anneal four GMM temperatures during a warm-up phase.

    For ``epoch < warmup_epochs`` every temperature is interpolated linearly
    from its ``initial_*`` to its ``final_*`` value; afterwards the final
    values are used as-is. Each call to :meth:`step` pushes the result to
    ``gmm.set_temperatures``.
    """

    def __init__(self, gmm, initial_T_pi=1.0, final_T_pi=1.0,
                 initial_T_mu=1.0, final_T_mu=1.0,
                 initial_T_sigma=1.0, final_T_sigma=1.0,
                 initial_T_shared=1.0, final_T_shared=1.0,
                 warmup_epochs=50):
        self.gmm = gmm
        self.warmup_epochs = warmup_epochs

        self.initial_T_pi, self.final_T_pi = initial_T_pi, final_T_pi
        self.initial_T_mu, self.final_T_mu = initial_T_mu, final_T_mu
        self.initial_T_sigma, self.final_T_sigma = initial_T_sigma, final_T_sigma
        self.initial_T_shared, self.final_T_shared = initial_T_shared, final_T_shared

    def step(self, epoch):
        """Apply the temperatures for ``epoch`` and return them.

        Returns:
            ``(T_pi, T_mu, T_sigma, T_shared)``.
        """
        # (start, end) pairs in the order pi, mu, sigma, shared.
        endpoints = (
            (self.initial_T_pi, self.final_T_pi),
            (self.initial_T_mu, self.final_T_mu),
            (self.initial_T_sigma, self.final_T_sigma),
            (self.initial_T_shared, self.final_T_shared),
        )

        if epoch < self.warmup_epochs:
            frac = epoch / self.warmup_epochs
            temps = tuple(start + frac * (end - start) for start, end in endpoints)
        else:
            temps = tuple(end for _, end in endpoints)

        T_pi, T_mu, T_sigma, T_shared = temps
        self.gmm.set_temperatures(T_pi=T_pi, T_mu=T_mu, T_sigma=T_sigma, T_shared=T_shared)
        return temps


@torch.no_grad()
def check_mode_collapse(gmm, loader, device, num_batches=10):
    """Report how evenly the mixture components of ``gmm`` are used.

    Up to ``num_batches`` batches are passed through ``gmm.forward``; the
    softmax of ``out['cache']['pi_logits']`` is averaged over all samples and
    summarized (max/min/std/entropy). A report is printed, with a warning
    when the largest mean weight exceeds 0.5 or the standard deviation of the
    mean weights exceeds 0.15.

    The module is evaluated in eval mode; its previous train/eval mode is
    restored on exit, even if an error occurs.

    Args:
        gmm: Module exposing ``forward(x=..., y=...)`` and an integer ``K``.
        loader: Iterable of ``(x, y, _)`` batches.
        device: Device the batches are moved to.
        num_batches: Maximum number of batches to inspect.

    Returns:
        Dict with keys ``mean_pi`` (NumPy array), ``max_pi``, ``min_pi``,
        ``std_pi``, ``entropy`` and ``entropy_ratio`` (entropy divided by
        ``log(K)``).
    """
    # Remember the caller's mode: forcing train() afterwards would silently
    # flip a module that was being used for evaluation.
    was_training = getattr(gmm, 'training', True)
    gmm.eval()
    try:
        pi_distributions = []

        for i, (x, y, _) in enumerate(loader):
            if i >= num_batches:
                break
            x, y = x.to(device), y.to(device)

            out = gmm.forward(x=x, y=y)
            pi_logits = out['cache']['pi_logits']
            pi_probs = F.softmax(pi_logits, dim=-1)
            pi_distributions.append(pi_probs.cpu())

        all_pi = torch.cat(pi_distributions, dim=0)
        mean_pi = all_pi.mean(dim=0)
        max_pi = mean_pi.max().item()
        min_pi = mean_pi.min().item()
        std_pi = mean_pi.std().item()
        # The small epsilon keeps log() finite for unused components.
        entropy = -(mean_pi * torch.log(mean_pi + 1e-8)).sum().item()
        max_entropy = np.log(gmm.K)

        print(f"\n{'='*60}")
        print(f"MODE COLLAPSE CHECK (K={gmm.K})")
        print(f"{'='*60}")
        print(f"Average π per component: {mean_pi.numpy()}")
        print(f"Max π: {max_pi:.4f} | Min π: {min_pi:.4f} | Std: {std_pi:.4f}")
        print(f"Entropy: {entropy:.4f} / {max_entropy:.4f} ({entropy/max_entropy*100:.1f}%)")

        if max_pi > 0.5:
            print(f"WARNING: Potential mode collapse!")
        elif std_pi > 0.15:
            print(f"WARNING: High variance in usage")
        else:
            print(f"Component usage looks balanced")
        print(f"{'='*60}\n")

        return {
            'mean_pi': mean_pi.numpy(),
            'max_pi': max_pi,
            'min_pi': min_pi,
            'std_pi': std_pi,
            'entropy': entropy,
            'entropy_ratio': entropy / max_entropy
        }
    finally:
        gmm.train(was_training)


def build_decoder_from_flag(backend: str, latent_dim: int, out_shape: tuple, device):
    """Build a decoder that maps a flat latent vector to an image.

    Both backends reshape a vector to ``(C, s, s)`` and upsample it to
    ``(H, W)`` with bicubic interpolation, where ``s`` is 4 for targets up to
    32 px, 7 for targets up to 64 px and ``min(H, W) // 32`` otherwise.

    * ``"bicubic"``: parameter-free; the latent itself must have
      ``C * s * s`` entries (checked in ``forward``; ``latent_dim`` is not
      used by this backend).
    * ``"bicubic_trainable"``: a linear layer first projects ``latent_dim``
      to ``C * s * s``.

    Args:
        backend: ``"bicubic"`` or ``"bicubic_trainable"``.
        latent_dim: Latent size consumed by the trainable backend.
        out_shape: Output image shape ``(C, H, W)``.
        device: Device the decoder is moved to.

    Returns:
        The decoder module.

    Raises:
        ValueError: If ``backend`` is unknown.
    """
    C, H, W = out_shape

    def calc_init_size(target_size):
        if target_size <= 32:
            return 4
        elif target_size <= 64:
            return 7
        else:
            return target_size // 32

    if backend == "bicubic":
        class BicubicDecoder(nn.Module):
            def __init__(self):
                super().__init__()
                self.init_size = calc_init_size(min(H, W))
                self.init_dim = C * self.init_size * self.init_size

            def forward(self, z):
                B = z.size(0)
                assert z.size(1) == self.init_dim, f"Expected latent_dim={self.init_dim}, got {z.size(1)}"

                z = z.view(B, C, self.init_size, self.init_size)
                return F.interpolate(z, size=(H, W), mode='bicubic', align_corners=False)

    elif backend == "bicubic_trainable":
        class BicubicDecoder(nn.Module):
            def __init__(self):
                super().__init__()
                self.init_size = calc_init_size(min(H, W))
                init_dim = C * self.init_size * self.init_size
                self.latent_to_spatial = nn.Linear(latent_dim, init_dim)

            def forward(self, z):
                B = z.size(0)
                h = self.latent_to_spatial(z)
                h = h.view(B, C, self.init_size, self.init_size)
                return F.interpolate(h, size=(H, W), mode='bicubic', align_corners=False)

    else:
        # List only the backends this function actually implements.
        raise ValueError(
            f"Unknown decoder backend: '{backend}'. Choose from:\n"
            f"  Frozen: bicubic\n"
            f"  Trainable: bicubic_trainable"
        )

    decoder = BicubicDecoder().to(device)
    # Label the log line with the backend that was really built.
    print(f"[Decoder '{backend}'] {sum(p.numel() for p in decoder.parameters()):,} params")

    return decoder