import os, re, h5py, numpy as np, torch
from typing import List, Tuple, Optional
from torch.utils.data import Dataset

def quantize_argmax(x):
    """Split ``x`` along dim 0 into one-hot NumPy arrays of its argmax.

    For each slice ``x[i]``, the position holding the maximum along the last
    dimension becomes 1 and every other position becomes 0 (ties are resolved
    by ``torch.argmax``).

    Args:
        x: Tensor with at least one dimension.

    Returns:
        A list with one ``numpy.ndarray`` per index of dim 0; each array is
        shaped like ``x[i]`` and has ``x``'s dtype.
    """
    argmax_idx = torch.argmax(x, dim=-1, keepdim=True)
    one_hot = torch.zeros_like(x)
    one_hot.scatter_(-1, argmax_idx, 1.0)
    return [one_hot[i].cpu().numpy() for i in range(one_hot.shape[0])]


class MultiHDF5ArchitectureWeightDataset(Dataset):
    """In-memory dataset of (architecture one-hot, padded weight stack) pairs.

    Every ``architectures/<arch>/<seed>`` group of the HDF5 file is loaded
    eagerly into ``self.data``.  ``<arch>`` is an underscore-separated list of
    integer hidden-layer widths (e.g. ``"32_64"``); each seed group holds a
    ``weights`` sub-group of named arrays and a ``mean_return`` attribute.

    ``dataset[i]`` returns ``(arch_tensor, weight_tensor)`` where
    ``arch_tensor`` has shape ``(architecture_max_layer, len(ops_decoder))``
    and ``weight_tensor`` has shape
    ``(n_layers, 1, weight_max_size, weight_max_size + patch_size)``.
    """

    # Matches both 'classifier.0.weight' and 'classifier.classifier.0.weight'.
    # Used with fullmatch so keys that merely *start* like this
    # (e.g. 'classifier.0.weight_orig') are not mistaken for layers.
    _CLASSIFIER_WEIGHT_RE = re.compile(r"(?:classifier\.)?classifier\.(\d+)\.weight")

    def __init__(
        self,
        file_path: str,
        architecture_max_layer: int = 6,
        weight_max_size: int = 64,
        patch_size: int = 8,
        with_bias: bool = True,
    ):
        """Load every (architecture, seed) entry of ``file_path`` into memory.

        Args:
            file_path: Path to the HDF5 file; ``~`` is expanded.
            architecture_max_layer: Number of token rows in the architecture
                tensor (input + hidden widths + output). At most
                ``architecture_max_layer - 1`` weight matrices are returned.
            weight_max_size: Row count of each padded weight matrix.
            patch_size: Extra columns in each padded matrix; the padded width
                is ``weight_max_size + patch_size``.
            with_bias: Append each layer's stored bias as an extra column.

        Raises:
            FileNotFoundError: If ``file_path`` does not exist.
        """
        super().__init__()

        self.weight_max_size = weight_max_size
        self.patch_size = patch_size
        self.with_bias = with_bias
        self.architecture_max_layer = architecture_max_layer
        self.ops_decoder = ['input', 'output', '16', '32', '64']

        self.h5_path = os.path.expanduser(file_path)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not os.path.exists(self.h5_path):
            raise FileNotFoundError(f"H5 file not found: {file_path}")

        self.entries: List[Tuple[str, str]] = []
        self.data = {}          # "<arch>_<seed>" -> weights / arch / layer_count / reward
        self.arch_counts = {}   # arch name -> number of seeds loaded
        self.layer_counts = {}  # "<n>L" (n hidden layers) -> number of entries

        with h5py.File(self.h5_path, "r") as f:
            # Read weight_scale from metadata (required)
            self.weight_scale = float(f['metadata'].attrs['weight_scale'])

            for arch, arch_grp in f["architectures"].items():
                arch_layers = [int(width) for width in arch.split('_')]
                # Per-architecture values, shared by every seed below.
                layer_count = len(arch_layers)
                layer_key = f"{layer_count}L"
                self.arch_counts[arch] = 0

                # Weights live directly in the seed group (no checkpoint subdirectories).
                for seed, seed_grp in arch_grp.items():
                    self.entries.append((arch, seed))
                    self.arch_counts[arch] += 1
                    self.layer_counts[layer_key] = self.layer_counts.get(layer_key, 0) + 1

                    weights_grp = seed_grp["weights"]
                    self.data[f"{arch}_{seed}"] = {
                        "weights": {k: torch.tensor(weights_grp[k][:]) for k in weights_grp.keys()},
                        "arch": arch_layers,
                        "layer_count": layer_count,
                        "reward": float(seed_grp.attrs['mean_return']),
                    }

        # Print dataset info
        print(f"Loaded {len(self.entries)} entries from {self.h5_path}")
        print(f"  Weight scale: {self.weight_scale:.4f} (target std: 0.538)")

    def __len__(self):
        return len(self.entries)

    def _sorted_layer_weight_keys(self, wdict):
        """Return ``(layer_index, key)`` pairs for classifier weights, in layer order.

        Only even ``classifier.<i>`` indices are kept and mapped to ``i // 2``.
        This assumes the classifier alternates parameterised layers with
        parameter-free modules (e.g. activations) — TODO confirm.
        """
        indexed_keys = []
        for key in wdict:
            match = self._CLASSIFIER_WEIGHT_RE.fullmatch(key)
            if match is None:
                continue
            module_idx = int(match.group(1))
            if module_idx % 2 == 0:
                indexed_keys.append((module_idx // 2, key))
        # Numeric sort keeps the layer order correct (classifier.10 after classifier.2).
        indexed_keys.sort()
        return indexed_keys

    def _pad_layer(self, raw: torch.Tensor) -> torch.Tensor:
        """Zero-pad a 2-D matrix to ``(weight_max_size, weight_max_size + patch_size)``.

        Raises:
            ValueError: If ``raw`` is larger than the padded size in either dim.
        """
        H = self.weight_max_size
        W = self.weight_max_size + self.patch_size
        h, c = raw.shape
        if h > H or c > W:
            raise ValueError(
                f"Layer of shape {tuple(raw.shape)} does not fit the padded size ({H}, {W})"
            )
        padded = torch.zeros(H, W, dtype=raw.dtype)
        padded[:h, :c] = raw
        return padded

    def extract_weights_and_biases(self, idx: int) -> torch.Tensor:
        """Return the padded, stacked classifier weight matrices of entry ``idx``.

        Each weight matrix (presumably ``(out, in)`` as stored by ``nn.Linear``)
        gets its bias appended as an extra column when ``with_bias`` is set and
        a bias is stored. It is then zero-padded to
        ``(weight_max_size, weight_max_size + patch_size)``. At most
        ``architecture_max_layer - 1`` matrices are kept. The stack is
        multiplied by ``weight_scale`` when that differs from 1.0.

        Returns:
            Tensor of shape
            ``(n_layers, 1, weight_max_size, weight_max_size + patch_size)``.

        Raises:
            ValueError: If the entry has no classifier weights, or a layer is
                larger than the padded size.
        """
        arch_name, seed_name = self.entries[idx]
        entry_key = f"{arch_name}_{seed_name}"
        wdict = self.data[entry_key]["weights"]
        max_matrices = self.architecture_max_layer - 1

        padded_layers = []
        for norm_idx, key in self._sorted_layer_weight_keys(wdict):
            if norm_idx >= max_matrices:
                break

            w = wdict[key]
            # The regex guarantees the key ends in 'weight'; swap just that suffix.
            bias_key = key[: -len("weight")] + "bias"
            if self.with_bias and bias_key in wdict:
                b = wdict[bias_key]
                if b.ndim == 1:
                    b = b.unsqueeze(1)
                raw = torch.cat([w, b], dim=1)
            else:
                raw = w
            padded_layers.append(self._pad_layer(raw))

        if not padded_layers:
            raise ValueError(f"No classifier weight matrices found for entry '{entry_key}'")

        results = torch.stack(padded_layers, dim=0).unsqueeze(1)

        # weight_scale comes from the file's metadata (the load message reports
        # a target std of 0.538).
        if self.weight_scale != 1.0:
            results = results * self.weight_scale

        return results

    def extract_architectures(self, idx: int) -> torch.Tensor:
        """One-hot encode the token sequence ``input, <hidden widths>, output``.

        Returns:
            Float tensor of shape ``(architecture_max_layer, len(ops_decoder))``.
            Rows past the end of the sequence stay all-zero (there is no
            padding token). Sequences longer than ``architecture_max_layer``
            are truncated.

        Raises:
            KeyError: If a token within the kept rows is not in ``ops_decoder``.
        """
        arch_name, seed_name = self.entries[idx]
        raw_arch = self.data[f"{arch_name}_{seed_name}"]["arch"]
        ops = ["input"] + [str(width) for width in raw_arch] + ["output"]

        token_index = {op: i for i, op in enumerate(self.ops_decoder)}
        arch_tensor = torch.zeros(self.architecture_max_layer, len(self.ops_decoder))

        for row, token in enumerate(ops[: self.architecture_max_layer]):
            if token not in token_index:
                raise KeyError(
                    f"Unknown architecture token {token!r}; expected one of {self.ops_decoder}"
                )
            arch_tensor[row, token_index[token]] = 1.0

        return arch_tensor

    def __getitem__(self, idx):
        arch_tensor = self.extract_architectures(idx)
        weight_tensor = self.extract_weights_and_biases(idx)
        return arch_tensor, weight_tensor
