import os
import re
import copy
import pickle
import torch
from typing import List, Dict, Any, Tuple, Optional

from distrl.models.diffusion_models.edm2.original.training.phema import solve_posthoc_coefficients

# Deduce nimg based on kimg (= nimg//1000)
def kimg_to_nimg(kimg):
    """Return the image count (nimg) that corresponds to a kimg value.

    The result is the largest multiple of 1024 that does not exceed
    ``kimg * 1000 + 999``.

    Raises:
        AssertionError: If that multiple does not floor-divide by 1000 back
            to ``kimg``, i.e. no multiple of 1024 lies in
            ``[kimg * 1000, kimg * 1000 + 999]`` (for example ``kimg == 42``).
    """
    upper_bound = kimg * 1000 + 999
    whole_chunks = upper_bound // 1024
    nimg = whole_chunks * 1024
    # Round-trip check: the chosen nimg must map back to the same kimg.
    assert nimg // 1000 == kimg
    return nimg

# Construct the full path of a network pickle
def pkl_path(dir, prefix, nimg, std):
    """Build the path of a network pickle inside ``dir``.

    The filename is ``<prefix>-<nimg // 1000, zero-padded to 7 digits>-<std
    with 3 decimals>.pkl``.

    Returns:
        The joined path, or None when ``dir`` is None.
    """
    # The filename is formatted before the None check, so formatting errors
    # surface even when no directory is given.
    suffix = f'-{nimg//1000:07d}-{std:.3f}.pkl'
    filename = prefix + suffix
    return None if dir is None else os.path.join(dir, filename)

# List input pickles for post-hoc EMA reconstruction
def list_input_pickles(in_dir, in_prefix=None, in_std=None):
    """Collect the network pickles found directly inside ``in_dir``.

    Only regular files named ``<prefix>-<digits>-<digits>.<digits>.pkl`` are
    considered. The first number is read as kimg and converted with
    ``kimg_to_nimg``; the second is read as the relative standard deviation.

    Args:
        in_dir: Directory containing the input pickles.
        in_prefix: Required filename prefix. None = anything goes.
        in_std: Iterable of accepted relative standard deviations, compared
            by exact float equality. None = anything goes.

    Returns:
        A list of dicts with keys "path", "nimg" and "std", sorted by
        (nimg, std).

    Raises:
        ValueError: If ``in_dir`` is not an existing directory.
    """
    if not os.path.isdir(in_dir):
        raise ValueError(f'Input directory does not exist: {in_dir}')

    accepted_stds = None if in_std is None else set(in_std)
    name_pattern = re.compile(r'(.*)-(\d+)-(\d+\.\d+)\.pkl')

    found = []
    with os.scandir(in_dir) as entries:
        for entry in entries:
            match = name_pattern.fullmatch(entry.name)
            if match is None or not entry.is_file():
                continue

            # Decode every field first; the filters below run afterwards.
            prefix, kimg_text, std_text = match.groups()
            nimg = kimg_to_nimg(int(kimg_text))
            std = float(std_text)

            prefix_ok = in_prefix is None or prefix == in_prefix
            std_ok = accepted_stds is None or std in accepted_stds
            if prefix_ok and std_ok:
                found.append({"path": entry.path, "nimg": nimg, "std": std})

    found.sort(key=lambda item: (item["nimg"], item["std"]))
    return found

# Parse a comma separated list of relative standard deviations
def parse_std_list(s):
    """Parse a comma separated list of relative standard deviations.

    The special token '...' is interpreted as an evenly spaced interval: the
    step is the difference between the two floats preceding it, and the
    interval ends at the float following it.
    Example: '0.01,0.02,...,0.05' returns [0.01, 0.02, 0.03, 0.04, 0.05]

    Whitespace around any token (including '...') is ignored, so
    '0.01, 0.02, ..., 0.05' is equivalent to the example above.

    Args:
        s: The string to parse. A list is returned unchanged, without any
            validation.

    Returns:
        A sorted list of unique floats, each strictly between 0 and 0.289.
        Values generated for '...' are computed arithmetically and are
        therefore subject to ordinary floating-point rounding.

    Raises:
        ValueError: If a token is not a float, a '...' token is misplaced or
            does not describe a non-empty evenly spaced interval, or any
            resulting value lies outside (0, 0.289).
    """
    if isinstance(s, list):
        return s

    # Parse raw values; None marks the position of a '...' token.
    raw = []
    for v in s.split(','):
        token = v.strip()
        if token == '...':
            raw.append(None)
            continue
        try:
            raw.append(float(token))
        except ValueError as err:
            raise ValueError(f"Invalid float value: {v}") from err

    # Fill in '...' tokens
    out = []
    for i, v in enumerate(raw):
        if v is not None:
            out.append(v)
            continue
        if i < 2 or raw[i - 2] is None or raw[i - 1] is None:
            raise ValueError("'...' must be preceded by at least two floats")
        if i + 1 >= len(raw) or raw[i + 1] is None:
            raise ValueError("'...' must be followed by at least one float")
        if raw[i - 2] == raw[i - 1]:
            raise ValueError("The floats preceding '...' must not be equal")

        start = raw[i - 1]
        step = start - raw[i - 2]
        target = raw[i + 1]
        approx_num = (target - start) / step
        num = round(approx_num)

        if num <= 0:
            raise ValueError("'...' must correspond to a non-empty interval")
        # Allow a small tolerance for floating-point error in the step count.
        if abs(num - approx_num) > 1e-4:
            raise ValueError("'...' must correspond to an evenly spaced interval")

        out.extend(start + step * (j + 1) for j in range(num))

    # Validate; the set also drops the interval end point when the generated
    # value and the explicitly listed value are the same float.
    out = sorted(set(out))
    if not all(0.000 < v < 0.289 for v in out):
        raise ValueError('Relative standard deviation must be positive and less than 0.289')
    return out

def reconstruct_ema_model(
    in_pkls: List[Dict[str, Any]],
    out_std: float,
    out_nimg: Optional[int] = None,
    device: str = 'cuda',
    dtype: torch.dtype = torch.float16
) -> torch.nn.Module:
    """Reconstruct an EMA model with the specified standard deviation.

    The output parameters are a weighted sum of the parameters of the 'ema'
    networks stored in the input pickles, using the weights returned by
    ``solve_posthoc_coefficients``. The sum is accumulated in float32.
    Buffers are copied from the last input pickle that is used.

    Args:
        in_pkls: List of input pickles, each a dict with path, nimg, std
        out_std: Desired relative standard deviation for reconstruction
        out_nimg: Training time of the snapshot to reconstruct. None = highest input time
        device: Device to place the model on
        dtype: Data type the reconstructed model is converted to before it
            is returned

    Returns:
        Reconstructed model

    Raises:
        ValueError: If ``out_nimg`` is given but matches none of the input
            pickles, or if no input pickle satisfies ``0 < nimg <= out_nimg``.
    """
    # Validate input pickles
    if out_nimg is None:
        out_nimg = max((pkl["nimg"] for pkl in in_pkls), default=0)
    elif not any(out_nimg == pkl["nimg"] for pkl in in_pkls):
        raise ValueError('Reconstruction time must match one of the input pickles')

    # Only snapshots taken at or before the reconstruction time contribute.
    in_pkls = [pkl for pkl in in_pkls if 0 < pkl["nimg"] <= out_nimg]
    if not in_pkls:
        raise ValueError('No valid input pickles found')

    in_nimg = [pkl["nimg"] for pkl in in_pkls]
    in_std = [pkl["std"] for pkl in in_pkls]

    # Compute post-hoc coefficients for the single output std. The result is
    # indexed below as coefs[input_index, output_index].
    coefs = solve_posthoc_coefficients(in_nimg, in_std, out_nimg, [out_std])

    out_model = None
    last_index = len(in_pkls) - 1

    # The parameters are modified in place; doing so outside no_grad raises
    # for parameters that still have requires_grad=True.
    with torch.no_grad():
        for i, pkl in enumerate(in_pkls):
            # SECURITY: pickle.load can execute arbitrary code. Only use
            # pickles from a trusted source.
            with open(pkl["path"], 'rb') as f:
                in_pkl_data = pickle.load(f)
            in_net = in_pkl_data['ema'].to(torch.float32)

            # Initialize the output model as a zeroed copy of the first net.
            if out_model is None:
                out_model = copy.deepcopy(in_net)
                for pj in out_model.parameters():
                    pj.zero_()

            # Accumulate weights
            weight = coefs[i, 0]  # Only one output std
            for pi, pj in zip(in_net.parameters(), out_model.parameters()):
                pj += pi * weight

            # Copy buffers from the last input model
            if i == last_index:
                for pi, pj in zip(in_net.buffers(), out_model.buffers()):
                    pj.copy_(pi)

    # Convert to the requested dtype and move to the requested device
    return out_model.to(dtype).to(device)

def parse_reconstruct_ema_config(config_str: str) -> Tuple[str, List[float]]:
    """Split a DISTRL_DEBUG_RECONSTRUCT_EMA string into its path and std list.

    The string is split at the first '::'. Both halves are stripped of
    surrounding whitespace, and the second half is parsed with
    ``parse_std_list``, which applies its own validation to the values.

    Args:
        config_str: String in the format "<model pool path>::<std list>"

    Returns:
        Tuple of (model_pool_path, std_values)

    Raises:
        ValueError: If ``config_str`` contains no '::' separator, or if
            ``parse_std_list`` rejects the std list.
    """
    path_part, separator, std_part = config_str.partition('::')
    if not separator:
        raise ValueError("Invalid DISTRL_DEBUG_RECONSTRUCT_EMA format. Expected '<model pool path>::0.242,0.432,...'")

    model_pool_path = path_part.strip()
    return model_pool_path, parse_std_list(std_part.strip())
