import os
import json
import random
import torch
import numpy as np
import yaml
import argparse
import h5py

from src.diffusion import create_diffusion
from src.model import NNiT
from src.sample_util import (
    extract_architecture,
    extract_weights,
    convert_policy_weights_to_mlp,
)

from src.eval_util import evaluate_policy_with_video
import src.envs.custom_envs  # Register custom environments



def load_checkpoint(checkpoint_path):
    """Read a checkpoint file and return the state dict to load into the model.

    The file is deserialized with storages left where ``torch.load`` first
    materializes them (no remapping to another device). If the loaded
    object contains an ``"ema"`` entry it is returned; otherwise a
    ``"model"`` entry is returned if present; otherwise the loaded object
    itself is returned unchanged.
    """
    def _keep_storage(storage, _location):
        # Identity remap: hand the storage back without moving it.
        return storage

    checkpoint = torch.load(checkpoint_path, map_location=_keep_storage)

    # "ema" takes priority over "model" when both are stored.
    for key in ("ema", "model"):
        if key in checkpoint:
            return checkpoint[key]
    return checkpoint


def load_architectures(arch_file, num_samples=None, seed=None):
    """Load architectures from a JSON file, optionally drawing a random subset.

    Args:
        arch_file: Path to a JSON file whose top-level value is a list of
            architectures.
        num_samples: If given and smaller than the number of architectures
            in the file, that many are drawn without replacement. Otherwise
            the full list is returned.
        seed: Seed for the subset draw. When provided, a private
            ``RandomState`` is used so NumPy's process-global RNG is left
            untouched. When ``None``, the draw consumes the global RNG.

    Returns:
        The full list, or the sampled architectures in their original file
        order (selected indices are sorted before lookup).
    """
    with open(arch_file, 'r') as f:
        all_archs = json.load(f)

    if num_samples is None or num_samples >= len(all_archs):
        return all_archs

    # RandomState(seed) produces the same stream as np.random.seed(seed)
    # followed by np.random.<fn>, but without reseeding global state that
    # the caller may already have configured.
    rng = np.random if seed is None else np.random.RandomState(seed)
    indices = rng.choice(len(all_archs), num_samples, replace=False)
    return [all_archs[i] for i in sorted(indices)]


def print_arch_stats(architectures):
    """Print how many architectures exist for each hidden-layer count.

    Emits a total line followed by one line per distinct ``len(arch)``,
    in ascending order, with the count and its rounded percentage share.
    """
    total = len(architectures)

    # Tally architectures by their number of hidden layers.
    depth_counts = {}
    for hidden_layers in architectures:
        depth = len(hidden_layers)
        depth_counts[depth] = depth_counts.get(depth, 0) + 1

    print(f"Architectures: {total} total")
    for depth, count in sorted(depth_counts.items()):
        share = count / total * 100
        print(f"  {depth} hidden: {count} ({share:.0f}%)")


def encode_architectures(archs, max_layers, n_vocab, device):
    """Encode hidden-layer lists as a stacked batch of one-hot token grids.

    Each architecture becomes the token sequence ``input, <sizes...>,
    output``. Row ``i`` of its ``(max_layers, n_vocab)`` grid holds 1.0 in
    the column of token ``i``; rows beyond the sequence stay all-zero.
    Values are left in {0, 1} -- no rescaling is applied here.

    Args:
        archs: Iterable of hidden-layer size lists; each size must
            stringify to one of ``'16'``, ``'32'``, ``'64'``.
        max_layers: Number of rows per grid.
        n_vocab: Number of columns per grid.
        device: Device on which the tensors are allocated.

    Returns:
        Tensor of shape ``(len(archs), max_layers, n_vocab)``.
    """
    vocabulary = ('input', 'output', '16', '32', '64')
    column_of = {token: col for col, token in enumerate(vocabulary)}

    def _encode_one(hidden_sizes):
        sequence = ['input', *map(str, hidden_sizes), 'output']
        grid = torch.zeros(max_layers, n_vocab, device=device)
        for row, token in enumerate(sequence):
            grid[row, column_of[token]] = 1.0
        return grid

    return torch.stack([_encode_one(hidden_sizes) for hidden_sizes in archs])


def build_metrics_json(config, mode, all_results):
    """Assemble the metrics report written alongside the sampled policies.

    The report contains the run configuration, aggregate statistics over
    every evaluated policy, the same statistics restricted to the ten
    highest-return policies, those top entries themselves (best first),
    and the full unsorted result list.

    Args:
        config: Run configuration; reads ``config['data']['env_name']`` and
            the ``num_sampling_steps``, ``eval_steps`` and ``seed`` entries
            of ``config['sample']``.
        mode: Sampling mode label stored verbatim in the report.
        all_results: List of per-policy dicts carrying ``mean_return``,
            ``success_once`` and ``success_at_end``.

    Returns:
        A dict with keys ``config``, ``summary``, ``top_10_summary``,
        ``top_10`` and ``all_policies``.
    """
    def _return_of(entry):
        return entry['mean_return']

    def _aggregate(entries):
        # Shared statistics block used for both the global and top-10 views.
        returns = [_return_of(entry) for entry in entries]
        return {
            'mean_return': float(np.mean(returns)),
            'std_return': float(np.std(returns)),
            'min_return': float(np.min(returns)),
            'max_return': float(np.max(returns)),
            'mean_success_once': float(np.mean([entry['success_once'] for entry in entries])),
            'mean_success_at_end': float(np.mean([entry['success_at_end'] for entry in entries])),
        }

    # Highest return first; the leading ten form the "top 10" view.
    ranked = sorted(all_results, key=_return_of, reverse=True)
    best_ten = ranked[:10]

    sample_cfg = config['sample']
    run_info = {
        'mode': mode,
        'env_id': config['data']['env_name'],
        'num_samples': len(all_results),
        'num_sampling_steps': sample_cfg['num_sampling_steps'],
        'eval_steps': sample_cfg['eval_steps'],
        'seed': sample_cfg['seed'],
    }

    return {
        'config': run_info,
        'summary': {'total_policies': len(all_results), **_aggregate(all_results)},
        'top_10_summary': _aggregate(best_ten),
        'top_10': best_ten,
        'all_policies': all_results,
    }


def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs with one value (FiT-style).

    Seeds ``random``, ``numpy.random`` and ``torch`` in that order; when
    CUDA is available, every CUDA device generator is seeded as well.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def _save_and_evaluate_policy(mlp, save_dir, arch_name, policy_idx, env_name,
                              eval_steps, device, seed, num_envs=50):
    """Save one sampled policy, evaluate it, and return its result row.

    The policy module is written to
    ``<save_dir>/policy_<arch_name>/policy_<policy_idx>/policy_<policy_idx>.pt``
    and the same directory/prefix (without extension) is handed to
    ``evaluate_policy_with_video`` as the video path.

    Returns:
        A dict with ``arch_folder`` and ``policy_name`` plus every entry of
        the metrics dict returned by the evaluator.
    """
    arch_folder = f'policy_{arch_name}'
    policy_name = f'policy_{policy_idx}'

    # makedirs also creates the architecture folder if it is missing.
    policy_dir = os.path.join(save_dir, arch_folder, policy_name)
    os.makedirs(policy_dir, exist_ok=True)

    torch.save(mlp, os.path.join(policy_dir, f"{policy_name}.pt"))

    video_path = os.path.join(policy_dir, policy_name)
    print(f"  Evaluating: {arch_folder}/{policy_name}...")
    metrics = evaluate_policy_with_video(
        mlp, env_name, video_path, eval_steps, device, seed=seed, num_envs=num_envs
    )

    print(f"    success_once={metrics['success_once']:.0%}, "
          f"success_at_end={metrics['success_at_end']:.0%}, "
          f"return={metrics['mean_return']:.2f}")

    return {'arch_folder': arch_folder, 'policy_name': policy_name, **metrics}


def _print_results_summary(metrics_json, mode, save_dir):
    """Print the overall and top-10 summaries of a metrics report to stdout."""
    def _print_stats(stats):
        print(f"Mean Return:         {stats['mean_return']:.2f} +/- {stats['std_return']:.2f}")
        print(f"Min/Max Return:      {stats['min_return']:.2f} / {stats['max_return']:.2f}")
        print(f"Mean Success Once:   {stats['mean_success_once']:.1%}")
        print(f"Mean Success at End: {stats['mean_success_at_end']:.1%}")

    print("\n" + "=" * 70)
    print("RESULTS SUMMARY (All Policies)")
    print("=" * 70)
    print(f"Mode: {mode}")
    print(f"Total policies: {metrics_json['summary']['total_policies']}")
    _print_stats(metrics_json['summary'])

    print("\n" + "-" * 70)
    print("TOP 10 SUMMARY")
    print("-" * 70)
    _print_stats(metrics_json['top_10_summary'])

    print("\nTOP 10 POLICIES:")
    for rank, policy in enumerate(metrics_json['top_10'], start=1):
        print(f"  {rank}. {policy['arch_folder']}/{policy['policy_name']}: "
              f"return={policy['mean_return']:.2f}, success_at_end={policy['success_at_end']:.0%}")

    print(f"\nSaved to: {save_dir}/")
    print("  - Policy files: policy_*/policy_*/policy_*.pt")
    print("  - Videos: policy_*/policy_*/policy_*.mp4")
    print("  - Metrics: metrics.json")


def main(config):
    """Sample policies from a trained NNiT model, evaluate them, and save results.

    Two modes are supported, selected by the presence of
    ``config['sample']['a2w_architecture_json']``:

    * ``a2w``: architectures are loaded from that JSON file and
      ``num_weights_per_arch`` weight samples are drawn per architecture
      via ``diffusion.monl_conditional_p_sample_loop``.
    * ``joint``: ``num_samples`` (architecture, weight) pairs are drawn via
      ``diffusion.monl_joint_p_sample_loop``.

    Every sample whose architecture can be extracted is converted to an
    MLP, saved, evaluated, and recorded. If at least one policy was
    evaluated, ``metrics.json`` is written to ``config['sample']['save_dir']``
    and a summary is printed.

    Args:
        config: Nested configuration dict with ``sample``, ``data`` and
            ``diffusion`` sections. ``config['sample']['num_envs']`` is
            optional and defaults to 50 evaluation environments.
    """
    sample_cfg = config['sample']
    diffusion_cfg = config['diffusion']

    # Setup seeding for reproducibility (FiT-style); sampling needs no gradients.
    seed = sample_cfg['seed']
    set_seed(seed)
    torch.set_grad_enabled(False)
    device = sample_cfg.get('device', 'cuda' if torch.cuda.is_available() else 'cpu')

    # Detect mode: joint or a2w (architecture-to-weights)
    is_a2w_mode = 'a2w_architecture_json' in sample_cfg
    mode = "a2w" if is_a2w_mode else "joint"

    print(f"Sampling Mode: {mode} | Device: {device} | Seed: {seed}")

    # weight_scale is stored in the dataset's HDF5 metadata and is passed to
    # extract_weights for every sample.
    with h5py.File(config['data']['path'], 'r') as f:
        weight_scale = float(f['metadata'].attrs['weight_scale'])
    print(f"Weight scale: {weight_scale:.4f} (will divide sampled weights by this)")

    # Build model
    model = NNiT(
        architecture_max_layer=diffusion_cfg['architecture_max_layer'],
        architecture_n_vocab=diffusion_cfg['architecture_n_vocab'],
        weight_max_size=diffusion_cfg['weight_max_size'],
        patch_size=diffusion_cfg['patch_size'],
        hidden_size=diffusion_cfg['hidden_size'],
        depth=diffusion_cfg['depth'],
        num_heads=diffusion_cfg['num_heads'],
        mlp_ratio=diffusion_cfg['mlp_ratio'],
        learn_sigma=diffusion_cfg['learn_sigma'],
        use_swiglu=diffusion_cfg['use_swiglu'],
        use_swiglu_large=diffusion_cfg['use_swiglu_large'],
    ).to(device)

    # Load checkpoint
    print(f"Loading: {sample_cfg['checkpoint']}")
    state_dict = load_checkpoint(sample_cfg['checkpoint'])
    model.load_state_dict(state_dict)
    model.eval()

    # Create diffusion process
    num_sampling_steps = sample_cfg['num_sampling_steps']
    diffusion = create_diffusion(
        timestep_respacing=str(num_sampling_steps),
        learn_sigma=diffusion_cfg['learn_sigma']
    )

    # Get dimensions
    arch_max_layer = diffusion_cfg['architecture_max_layer']
    arch_n_vocab = diffusion_cfg['architecture_n_vocab']
    weight_max_size = diffusion_cfg['weight_max_size']
    patch_size = diffusion_cfg['patch_size']

    save_dir = sample_cfg['save_dir']
    os.makedirs(save_dir, exist_ok=True)

    env_name = config['data']['env_name']
    eval_steps = sample_cfg['eval_steps']
    # Number of evaluation environments; 50 unless overridden in the config.
    num_envs = sample_cfg.get('num_envs', 50)

    all_results = []

    if is_a2w_mode:
        # === A2W Mode: Conditional sampling p(w|a) ===
        print(f"Loading: {sample_cfg['a2w_architecture_json']}")
        architectures = load_architectures(
            sample_cfg['a2w_architecture_json'],
            num_samples=sample_cfg['num_samples'],
            seed=seed
        )
        print_arch_stats(architectures)

        # Encode architectures
        arch_tensors = encode_architectures(
            architectures,
            arch_max_layer,
            arch_n_vocab,
            device
        )

        # Repeat each architecture for num_weights_per_arch, so samples for
        # architecture k occupy a contiguous run of the batch.
        num_weights_per_arch = sample_cfg['num_weights_per_arch']
        n_archs = arch_tensors.shape[0]
        arch_tensors_repeated = arch_tensors.repeat_interleave(num_weights_per_arch, dim=0)
        n_total = arch_tensors_repeated.shape[0]

        shape = {
            "architecture": arch_tensors_repeated.shape,
            "weight": (n_total, arch_max_layer - 1, 1, weight_max_size, weight_max_size + patch_size)
        }

        print(f"Sampling {num_weights_per_arch} weights per architecture ({n_archs} archs = {n_total} total, {num_sampling_steps} steps)...")

        samples = diffusion.monl_conditional_p_sample_loop(
            model,
            arch_tensors_repeated,
            shape,
            clip_denoised=False,
            model_kwargs={},
            progress=True,
            device=device
        )

        all_arch_samples = samples["architecture"]
        all_weight_samples = samples["weight"]

        # Process each architecture and sample
        print("\nGenerating and evaluating policies...")
        for arch_idx, arch_hidden_layers in enumerate(architectures):
            # Folders are named after the requested architecture.
            arch_name = '_'.join(str(x) for x in arch_hidden_layers)
            os.makedirs(os.path.join(save_dir, f'policy_{arch_name}'), exist_ok=True)

            for sample_idx in range(num_weights_per_arch):
                # Position of this sample in the repeat_interleave'd batch.
                global_idx = arch_idx * num_weights_per_arch + sample_idx

                # Extract and convert to policy
                current_arch = extract_architecture(env_name, all_arch_samples[global_idx])
                if current_arch is None:
                    print(f"  Skipping policy_{arch_name}/policy_{sample_idx} - invalid architecture")
                    continue

                policy_weights = extract_weights(current_arch, all_weight_samples[global_idx], weight_scale)
                mlp, _ = convert_policy_weights_to_mlp(policy_weights)
                mlp = mlp.to(device)

                all_results.append(_save_and_evaluate_policy(
                    mlp, save_dir, arch_name, sample_idx, env_name,
                    eval_steps, device, seed, num_envs=num_envs
                ))

    else:
        # === Joint Mode: Joint sampling p(w,a) ===
        n = sample_cfg['num_samples']

        shape = {
            "architecture": (n, arch_max_layer, arch_n_vocab),
            "weight": (n, arch_max_layer - 1, 1, weight_max_size, weight_max_size + patch_size)
        }

        print(f"Sampling {n} policies ({num_sampling_steps} steps)...")

        samples = diffusion.monl_joint_p_sample_loop(
            model,
            shape,
            clip_denoised=False,
            model_kwargs={},
            progress=True,
            device=device
        )

        all_arch_samples = samples["architecture"]
        all_weight_samples = samples["weight"]

        # Per-architecture running index, so policies sharing an architecture
        # get distinct policy_<k> folders.
        arch_counters = {}

        print("\nGenerating and evaluating policies...")
        for i in range(n):
            print(f"Processing policy {i+1}/{n}...")

            current_arch = extract_architecture(env_name, all_arch_samples[i])
            if current_arch is None:
                print(f"  Skipping policy {i} - invalid architecture")
                continue

            policy_weights = extract_weights(current_arch, all_weight_samples[i], weight_scale)
            mlp, layer_sizes = convert_policy_weights_to_mlp(policy_weights)
            mlp = mlp.to(device)

            # Create arch folder name from hidden layers (first and last
            # entries of layer_sizes are dropped).
            arch_name = '_'.join(str(x) for x in layer_sizes[1:-1])

            # Get next index for this architecture
            policy_idx = arch_counters.get(arch_name, 0)
            arch_counters[arch_name] = policy_idx + 1

            all_results.append(_save_and_evaluate_policy(
                mlp, save_dir, arch_name, policy_idx, env_name,
                eval_steps, device, seed, num_envs=num_envs
            ))

    if not all_results:
        print("\nNo valid policies generated!")
        return

    # Build and save metrics.json
    metrics_json = build_metrics_json(config, mode, all_results)
    metrics_path = os.path.join(save_dir, 'metrics.json')
    with open(metrics_path, 'w') as f:
        json.dump(metrics_json, f, indent=2)

    _print_results_summary(metrics_json, mode, save_dir)


if __name__ == "__main__":
    # CLI entry point: resolve ``--config NAME`` to ``configs/NAME.yaml``,
    # load it with the safe YAML loader, and pass the resulting dict to main().
    parser = argparse.ArgumentParser(description="Sample from trained NiT model with evaluation")
    parser.add_argument("--config", type=str, default="sample_joint",
                        help="Configuration name (without .yaml extension)")
    args = parser.parse_args()

    # Construct config path
    config_path = f'configs/{args.config}.yaml'

    # Check if config file exists
    if not os.path.exists(config_path):
        print(f"Error: Configuration file '{config_path}' not found.")
        # The ``exit`` builtin is injected by the ``site`` module and is not
        # guaranteed to exist (e.g. ``python -S``); SystemExit always is.
        raise SystemExit(1)

    # Load configuration from YAML file
    print(f"Loading configuration from: {config_path}")
    with open(config_path, 'r', encoding='utf-8') as file:
        config = yaml.safe_load(file)

    # Call main with the configuration dictionary
    main(config)