import os
import numpy as np
import torch
import torch.nn as nn

from src.dataset import quantize_argmax


def get_input_output_dims_for_env(env_name):
    """Look up the policy input/output sizes for a supported environment.

    Args:
        env_name: Environment identifier, e.g. ``'PickCube-v1'``.

    Returns:
        A ``(input_dim, output_dim)`` tuple of ints.

    Raises:
        ValueError: If ``env_name`` is not one of the known environments.
    """
    known_envs = (
        ('PickCube-v1', (42, 8)),
        ('PushCube-v1', (35, 8)),
        ('StackCube-v1', (48, 8)),
        ('StackCubeEasy-v1', (48, 8)),
        ('AnymalC-Reach-v1', (35, 12)),
    )
    for candidate, dims in known_envs:
        if env_name == candidate:
            return dims
    raise ValueError(f"Unknown environment: {env_name}")


def extract_architecture(env_name, architecture):
    """Decode one sampled architecture into a list of MLP layer widths.

    Every node of the quantized architecture is mapped, via its argmax index,
    to one of the tokens ``['input', 'output', '16', '32', '64']``. Tokens are
    read in order and translated to widths; decoding stops at the first
    ``'output'`` token, so anything after it is ignored.

    A decoded sequence is only returned when it describes a usable policy:
    it must start with the ``'input'`` token, end with an ``'output'`` token
    and contain at least two entries. Otherwise a warning is printed and
    ``None`` is returned so the caller can skip the sample.

    Args:
        env_name: Environment identifier accepted by
            ``get_input_output_dims_for_env``.
        architecture: Tensor describing one sample; it must support
            ``.clone()`` and ``.unsqueeze(0)``.
            NOTE(review): presumably (num_nodes, num_ops) one-hot-like scores
            in [0, 1] -- confirm against ``quantize_argmax``.

    Returns:
        ``[input_dim, hidden_1, ..., output_dim]`` as a list of ints, or
        ``None`` if the decoded architecture is invalid.

    Raises:
        ValueError: If ``env_name`` is not a known environment.
    """
    # The sample is quantized as-is: values are taken to be [0, 1] one-hot
    # scores already, so no inverse scaling from [-1, 1] is applied first.
    arch_op = quantize_argmax(architecture.clone().unsqueeze(0))[-1]

    input_dim, output_dim = get_input_output_dims_for_env(env_name)
    ops_decoder = ['input', 'output', '16', '32', '64']

    raw_tokens = [ops_decoder[np.argmax(node)] for node in arch_op]
    print(f"Raw tokens: {raw_tokens}")

    arch_ops = []
    reached_output = False
    for token in raw_tokens:
        if token == 'input':
            arch_ops.append(input_dim)
        elif token == 'output':
            arch_ops.append(output_dim)
            reached_output = True
            break  # nodes after the output token are padding
        else:
            arch_ops.append(int(token))

    print("testing architecture", arch_ops)

    # Hidden widths (16/32/64) never match an environment's observation or
    # action size, so a network that does not begin at 'input' or never
    # reaches 'output' cannot be run in the environment.
    starts_at_input = bool(raw_tokens) and raw_tokens[0] == 'input'
    if len(arch_ops) < 2 or not starts_at_input or not reached_output:
        print(f"WARNING: Invalid architecture {arch_ops}, skipping...")
        return None

    return arch_ops


def extract_weights(arch_ops, weights, weight_scale=1.0):
    """Extract per-layer weights and biases from diffusion output.

    For layer ``j`` the top-left ``(out_dim, in_dim + 1)`` corner of
    ``weights[j, 0]`` is read; the first ``in_dim`` columns are the weight
    matrix and the last column is the bias vector.

    Args:
        arch_ops: List of layer dimensions [input_dim, hidden1, ..., output_dim]
        weights: Tensor of shape (num_layers, 1, H, W) with weights
        weight_scale: Scale factor to divide by (inverse of training scale)

    Returns:
        Dict mapping ``'layer_<j>'`` to a dict with keys ``'in_dim'``,
        ``'out_dim'``, ``'weights'`` (NumPy array of shape
        ``(out_dim, in_dim)``) and ``'bias'`` (NumPy array of shape
        ``(out_dim,)``).

    Raises:
        ValueError: If ``weights[j, 0]`` is too small to hold an
            ``(out_dim, in_dim + 1)`` block for some layer.
        IndexError: If ``weights`` has fewer layers than ``arch_ops`` needs.
    """
    policy_weights = {}
    for j, (in_dim, out_dim) in enumerate(zip(arch_ops[:-1], arch_ops[1:])):
        in_dim = int(in_dim)
        out_dim = int(out_dim)

        block = weights[j, 0, :out_dim, :in_dim + 1]
        # Slicing past the end silently truncates; catch that here instead of
        # handing a mis-shaped matrix to the layer that will consume it.
        if tuple(block.shape) != (out_dim, in_dim + 1):
            raise ValueError(
                f"Layer {j} needs a ({out_dim}, {in_dim + 1}) weight block "
                f"but only {tuple(block.shape)} is available"
            )

        # detach() so tensors that still track gradients can be converted.
        weight_matrix = block.detach().cpu().numpy()
        # Inverse scale to get raw weights (undo training normalization)
        weight_matrix = weight_matrix / weight_scale

        policy_weights[f'layer_{j}'] = {
            'in_dim': in_dim,
            'out_dim': out_dim,
            'weights': weight_matrix[:, :-1],
            'bias': weight_matrix[:, -1],
        }
    return policy_weights


def convert_policy_weights_to_mlp(policy_weights):
    """Build an ``nn.Sequential`` MLP from a per-layer weight dictionary.

    Args:
        policy_weights: Dict keyed by ``'layer_<index>'``; each value is a
            dict providing ``'in_dim'``, ``'out_dim'``, ``'weights'`` and
            ``'bias'``. Layers are ordered by the integer after the
            underscore, not by dict insertion order.

    Returns:
        A ``(model, layer_sizes)`` tuple: the assembled ``nn.Sequential`` and
        the list ``[in_dim of first layer, out_dim of every layer...]``.
    """
    ordered_names = sorted(policy_weights, key=lambda name: int(name.split('_')[1]))
    specs = [policy_weights[name] for name in ordered_names]

    layer_sizes = [specs[0]['in_dim']] + [spec['out_dim'] for spec in specs]

    final_position = len(specs) - 1
    modules = []
    for position, spec in enumerate(specs):
        fc = nn.Linear(spec['in_dim'], spec['out_dim'])
        # Overwrite the default initialisation with the supplied parameters.
        fc.weight.data = torch.tensor(spec['weights'], dtype=torch.float32)
        fc.bias.data = torch.tensor(spec['bias'], dtype=torch.float32)
        modules.append(fc)

        # ReLU between layers only (matches GHN training activation);
        # the final layer's output is left linear.
        if position != final_position:
            modules.append(nn.ReLU())

    return nn.Sequential(*modules), layer_sizes


def generate_policy(env_name, architecture, weights, save_dir, weight_scale=1.0):
    """Turn sampled architectures and weights into saved MLP policies.

    For each sample the architecture is decoded, the matching weights are
    extracted, an ``nn.Sequential`` is built and the whole module is written
    with ``torch.save`` to
    ``<save_dir>/policy_<hidden sizes joined by '_'>/policy_<k>.pt``, where
    ``k`` counts the policies saved for that hidden-layer layout during this
    call. Samples whose architecture decodes to ``None`` are skipped.

    Args:
        env_name: Environment identifier passed to ``extract_architecture``.
        architecture: 3-D array-like indexed by sample on the first axis; its
            second axis must be one longer than the second axis of
            ``weights``.
        weights: Per-sample weights, indexed by sample on the first axis.
        save_dir: Root directory for the saved policies; sub-directories are
            created as needed.
        weight_scale: Forwarded to ``extract_weights``.
    """
    num_samples, _, _ = architecture.shape
    assert architecture.shape[1] == weights.shape[1] + 1

    saved_per_arch = {}

    for sample in range(num_samples):
        print(f"Generating policy {sample}...")
        layer_dims = extract_architecture(env_name, architecture[sample])
        if layer_dims is None:
            print(f"Skipping policy {sample} due to invalid architecture")
            continue

        mlp, layer_sizes = convert_policy_weights_to_mlp(
            extract_weights(layer_dims, weights[sample], weight_scale)
        )

        # Policies are grouped on disk by their hidden-layer widths only.
        arch_name = '_'.join(map(str, layer_sizes[1:-1]))
        arch_dir = os.path.join(save_dir, f'policy_{arch_name}')
        os.makedirs(arch_dir, exist_ok=True)

        # Next free index for this layout within the current call.
        policy_idx = saved_per_arch.get(arch_name, 0)
        saved_per_arch[arch_name] = policy_idx + 1

        policy_path = os.path.join(arch_dir, f'policy_{policy_idx}.pt')
        torch.save(mlp, policy_path)
        print(f"  Saved: {policy_path} (architecture: {layer_sizes})")

    print("Finished generating policies!")
