"""
Shared utilities.
"""

import os
import shutil
import subprocess
import tempfile
import urllib.parse
import urllib.request
import logging
import tarfile
from typing import List, Tuple, Any, Callable

import torch
import torch.nn as nn
import torch.nn.functional as F


def _reserve_temp_path(suffix):
    """Create an empty temp file that outlives this call and return its path."""
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as handle:
        return handle.name


def _discard_temp_path(tmp_path):
    """Best-effort removal of a partially written temp file."""
    try:
        os.unlink(tmp_path)
    except OSError:
        # Failing to clean up must not hide the error that got us here.
        pass


def _temp_cleanup(tmp_path):
    """Build the zero-argument cleanup callback handed back to the caller."""
    def cleanup():
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)
    return cleanup


def resolve_remote_path(path, suffix='.pt', desc='file'):
    """Resolve *path* to a local file, fetching it first when it is remote.

    The following forms are tried in order:

    1. An existing local file (``~`` and environment variables are expanded).
    2. A ``file://`` URL.
    3. An ``http://`` / ``https://`` URL, downloaded into a temporary file.
    4. Any other string containing ``:`` that does not start with a
       drive-letter prefix such as ``C:``; it is passed to ``scp`` as the
       source and copied into a temporary file.

    Args:
        path: Local path, URL, or scp-style ``user@host:/path`` location.
        suffix: Suffix for the temporary file. For scp sources the extension
            of the remote file name is used when it has one.
        desc: Human-readable noun used in log and error messages.

    Returns:
        A ``(local_path, cleanup)`` tuple. ``cleanup`` is ``None`` when the
        file already lives on disk; otherwise it is a zero-argument callable
        that deletes the temporary copy if it still exists.

    Raises:
        FileNotFoundError: *path* is empty, a ``file://`` URL points at
            nothing, or *path* matches none of the supported forms.
        RuntimeError: the download failed, ``scp`` is not installed, or
            ``scp`` exited with a non-zero status. The underlying exception
            is attached as ``__cause__`` where one exists.
    """
    if not path:
        raise FileNotFoundError(f'Empty {desc} path')
    expanded = os.path.expandvars(os.path.expanduser(path))

    # Local
    if os.path.isfile(expanded):
        return expanded, None

    parsed = urllib.parse.urlparse(expanded)

    # file://
    if parsed.scheme == 'file':
        local = os.path.abspath(urllib.request.url2pathname(parsed.path))
        if os.path.isfile(local):
            return local, None
        raise FileNotFoundError(f'file URL not found: {path}')

    # http(s)://
    if parsed.scheme in ('http', 'https'):
        tmp_path = _reserve_temp_path(suffix)
        try:
            with urllib.request.urlopen(expanded) as resp, open(tmp_path, 'wb') as f:
                shutil.copyfileobj(resp, f)
        except Exception as e:
            _discard_temp_path(tmp_path)
            raise RuntimeError(f'Failed to download {desc} from {expanded}: {e}') from e
        except BaseException:
            # e.g. KeyboardInterrupt: do not leave a half-written file behind.
            _discard_temp_path(tmp_path)
            raise
        logging.info('Downloaded %s from %s to %s', desc, expanded, tmp_path)
        return tmp_path, _temp_cleanup(tmp_path)

    # SCP-like user@host:/abs/path. A leading "X:" is treated as a Windows
    # drive letter rather than a one-letter host name.
    has_drive_prefix = len(expanded) >= 2 and expanded[1] == ':' and expanded[0].isalpha()
    if ':' in expanded and not has_drive_prefix:
        if shutil.which('scp') is None:
            raise RuntimeError('scp not available; install openssh-client or copy the file locally.')
        remote_name = os.path.basename(expanded.split(':', 1)[1])
        tmp_path = _reserve_temp_path(os.path.splitext(remote_name)[1] or suffix)
        try:
            subprocess.run(['scp', '-q', expanded, tmp_path], check=True)
        except subprocess.CalledProcessError as e:
            _discard_temp_path(tmp_path)
            raise RuntimeError(
                f'Failed to fetch remote {desc} via scp from {expanded} (exit {e.returncode}).'
            ) from e
        except BaseException:
            # OSError while spawning scp, Ctrl-C, ...: propagate unchanged,
            # but without leaking the temp file.
            _discard_temp_path(tmp_path)
            raise
        logging.info('Fetched %s via scp from %s to %s', desc, expanded, tmp_path)
        return tmp_path, _temp_cleanup(tmp_path)

    raise FileNotFoundError(f'{desc.capitalize()} not found: {path}')



def extract_state_dict(ckpt):
    """Pull the parameter mapping out of a loaded checkpoint object.

    For a dict checkpoint the result is, in order of preference:

    1. the value under the first of ``model_state``, ``state_dict``,
       ``model_state_dict`` or ``model`` that is itself a dict;
    2. the first dict value containing at least one entry that has a
       ``dim`` attribute;
    3. the checkpoint itself.

    Anything that is not a dict is returned unchanged.
    """
    if not isinstance(ckpt, dict):
        return ckpt

    for name in ('model_state', 'state_dict', 'model_state_dict', 'model'):
        candidate = ckpt.get(name)
        if isinstance(candidate, dict):
            return candidate

    def holds_tensors(value):
        # "Has a dim attribute" is the tensor test used here.
        return isinstance(value, dict) and any(
            hasattr(item, 'dim') for item in value.values()
        )

    # No nested candidate means the checkpoint is treated as the state dict.
    return next((value for value in ckpt.values() if holds_tensors(value)), ckpt)


def infer_arch_from_state_dict(sd):
    """Classify a state dict by its parameter names.

    Checks run in this order and the first match wins:

    * ``'transformer'`` - a ``patch_proj.`` or ``encoder.`` prefixed key, or
      a key named exactly ``pos_emb``;
    * ``'conv_player'`` - a ``player_emb.weight`` key;
    * ``'resnet'`` - an ``in_conv.``, ``blocks.`` or ``policy_conv.`` prefixed key;
    * ``'conv'`` - a ``trunk.`` prefixed key;
    * ``'logistic'`` - weight and bias for both ``policy_head`` and
      ``value_head`` are present.

    Returns ``'conv'`` when nothing matches.
    """
    names = set(sd.keys())

    if any(n.startswith(('patch_proj.', 'encoder.')) or n == 'pos_emb' for n in names):
        return 'transformer'

    if 'player_emb.weight' in names:
        return 'conv_player'

    if any(n.startswith(('in_conv.', 'blocks.', 'policy_conv.')) for n in names):
        return 'resnet'

    if any(n.startswith('trunk.') for n in names):
        return 'conv'

    head_params = (
        'policy_head.weight',
        'policy_head.bias',
        'value_head.weight',
        'value_head.bias',
    )
    if names.issuperset(head_params):
        return 'logistic'

    return 'conv'


def infer_in_channels(arch, sd):
    """Read the input width (axis 1 of a weight) for the given architecture.

    Probe order:

    * ``'resnet'``: ``in_conv.weight``;
    * ``'conv'`` / ``'conv_player'``: ``trunk.0.weight``, then the first
      4-D ``trunk.*.weight`` entry;
    * ``'transformer'``: ``patch_proj.weight``, then ``proj.weight``;
    * any architecture, when the above found nothing: the first 4-D
      ``*.weight`` entry in iteration order.

    Returns the size as an ``int``, or ``None`` when nothing matches or any
    lookup raises.
    """
    def axis1(name):
        return int(sd[name].shape[1])

    def first_conv_weight(prefix=''):
        # An empty prefix matches every name, giving the generic fallback scan.
        for name, tensor in sd.items():
            named_like_weight = name.startswith(prefix) and name.endswith('.weight')
            if named_like_weight and hasattr(tensor, 'dim') and tensor.dim() == 4:
                return int(tensor.shape[1])
        return None

    try:
        if arch == 'resnet' and 'in_conv.weight' in sd:
            return axis1('in_conv.weight')

        if arch in ('conv', 'conv_player'):
            if 'trunk.0.weight' in sd:
                return axis1('trunk.0.weight')
            from_trunk = first_conv_weight('trunk.')
            if from_trunk is not None:
                return from_trunk

        if arch == 'transformer':
            for name in ('patch_proj.weight', 'proj.weight'):
                if name in sd:
                    return axis1(name)

        return first_conv_weight()
    except Exception:
        # Best-effort inference: unexpected checkpoint contents yield None.
        return None


def infer_d_model(arch, sd, default=128):
    """Read the hidden width of a convolutional model from its weights.

    * ``'resnet'``: axis 0 of ``in_conv.weight``, else axis 1 of
      ``policy_conv.weight``, else axis 1 of ``value_conv.weight``.
    * ``'conv'`` / ``'conv_player'``: axis 0 of ``trunk.0.weight``, else
      axis 0 of the first 4-D ``trunk.*.weight`` entry.

    Returns *default* for any other architecture, when no probe matches, or
    when a lookup raises.
    """
    # (parameter name, axis holding the width), in priority order.
    resnet_probes = (
        ('in_conv.weight', 0),
        ('policy_conv.weight', 1),
        ('value_conv.weight', 1),
    )
    try:
        if arch == 'resnet':
            for name, axis in resnet_probes:
                if name in sd:
                    return int(sd[name].shape[axis])

        if arch in ('conv', 'conv_player'):
            if 'trunk.0.weight' in sd:
                return int(sd['trunk.0.weight'].shape[0])
            trunk_widths = (
                int(tensor.shape[0])
                for name, tensor in sd.items()
                if name.startswith('trunk.')
                and name.endswith('.weight')
                and hasattr(tensor, 'dim')
                and tensor.dim() == 4
            )
            # Lazy: only the first matching entry is ever evaluated.
            return next(trunk_widths, default)
    except Exception:
        return default
    return default


def infer_transformer_hparams(sd, default_d_model=128):
    """Infer ``(d_model, num_layers, dim_ff)`` from a transformer state dict.

    Args:
        sd: Mapping of parameter names to objects exposing ``.shape``.
        default_d_model: Width used when neither ``patch_proj.weight`` nor
            ``proj.weight`` yields a usable shape.

    Returns:
        A ``(d_model, num_layers, dim_ff)`` tuple where

        * ``d_model`` is axis 0 of ``patch_proj.weight`` (or ``proj.weight``);
        * ``num_layers`` is one more than the largest ``N`` seen in
          ``encoder.layers.N.*`` keys, or 2 when there are none;
        * ``dim_ff`` is axis 0 of ``encoder.layers.0.linear1.weight``, or
          ``4 * d_model`` when that entry is missing or unreadable.
    """
    try:
        if 'patch_proj.weight' in sd:
            d_model = int(sd['patch_proj.weight'].shape[0])
        elif 'proj.weight' in sd:
            d_model = int(sd['proj.weight'].shape[0])
        else:
            d_model = default_d_model
    except Exception:
        d_model = default_d_model

    layer_idxs = [
        int(parts[2])
        for parts in (k.split('.') for k in sd.keys() if k.startswith('encoder.layers.'))
        if len(parts) > 2 and parts[2].isdigit()
    ]
    num_layers = (max(layer_idxs) + 1) if layer_idxs else 2

    # Guarded like the d_model lookup above, so a malformed entry degrades
    # to the fallback instead of raising.
    ff_key = 'encoder.layers.0.linear1.weight'
    try:
        dim_ff = int(sd[ff_key].shape[0]) if ff_key in sd else 4 * d_model
    except Exception:
        dim_ff = 4 * d_model

    return d_model, num_layers, dim_ff


def guess_hw_from_state(arr):
    """Guess the spatial ``(height, width)`` of a board/state array.

    Only ``arr.shape`` is inspected, so the data is never copied or
    converted and an empty leading axis on a 4-D input is fine.

    Layout rules:

    * 2-D: the shape is ``(H, W)``.
    * 3-D: if the first axis has at most 8 entries it is taken as the
      channel axis, giving ``(shape[1], shape[2])``; otherwise, if the last
      axis has at most 8 entries, the result is ``(shape[0], shape[1])``;
      if neither is small, ``(shape[1], shape[2])`` is assumed.
    * 4-D: the leading axis is dropped and the 3-D rules apply.

    Args:
        arr: Any object with a ``shape`` attribute that can be turned into a
            tuple of ints.

    Returns:
        A ``(height, width)`` tuple of ``int``.

    Raises:
        ValueError: the input is not 2-, 3- or 4-dimensional.
    """
    full_shape = tuple(arr.shape)

    # The leading axis of a 4-D input carries no layout information.
    dims = full_shape[1:] if len(full_shape) == 4 else full_shape

    if len(dims) == 2:
        return int(dims[0]), int(dims[1])

    if len(dims) == 3:
        first, middle, last = dims
        # Channels-last is chosen only when the first axis is too large to
        # be channels (> 8) and the last axis is small (<= 8).
        if first > 8 and last <= 8:
            return int(first), int(middle)
        return int(middle), int(last)

    raise ValueError(f'Unsupported state shape for inferring HxW: {full_shape}')


def infer_alphazero_from_state_dict(sd):
    """Infer AlphaZero-style model params from a state_dict.

    Returns a dict that may contain ``embed_dim``, ``in_channels``,
    ``action_size`` and ``channels``. Recognized parameters, each read only
    when present and carrying a ``shape`` attribute:

    * ``proj.weight``        -> ``embed_dim`` (axis 0), ``in_channels`` (axis 1)
    * ``policy_head.weight`` -> ``action_size`` (axis 0)
    * ``conv1.weight``       -> ``channels`` (axis 0), ``in_channels`` (axis 1)
    * ``policy_fc.weight``   -> ``action_size`` (axis 0)

    Entries are applied top to bottom, so the CNN-style keys (``conv1``,
    ``policy_fc``) override values taken from the transformer-style keys.
    """
    # Each row: parameter name, then the output field for axis 0, axis 1, ...
    layout = (
        ('proj.weight', ('embed_dim', 'in_channels')),
        ('policy_head.weight', ('action_size',)),
        ('conv1.weight', ('channels', 'in_channels')),
        ('policy_fc.weight', ('action_size',)),
    )

    params = {}
    for weight_name, fields in layout:
        if weight_name not in sd:
            continue
        weight = sd[weight_name]
        if not hasattr(weight, 'shape'):
            continue
        for axis, field in enumerate(fields):
            params[field] = int(weight.shape[axis])
    return params
