import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import networkx as nx
import numpy as np
from pathlib import Path
import pickle
from tqdm import tqdm
import random
from dataclasses import dataclass, field
from typing import Literal

from diversity import compute_coulomb_loss_ensemble
from distances_spectral import (
    FeatureConfig,
    batched_features_for_ensemble,
    pairwise_distance,
    print_direction_info,
    DEFAULT_FEATURE_FLAGS,
    PATTERN_NAMES,
)
from postprocess_new import (
    SelectionConfig,
    SelectionResult,
    run_selection,
    DEFAULT_TEMPERATURES,
)

# Overall graph budget for one run. UnifiedConfig divides it by num_ensembles
# (rounding up) to derive the default graphs_per_ensemble.
TOTAL_GRAPHS = 10_000


@dataclass
class UnifiedConfig:
    mode: Literal["both", "train_only", "postprocess_only"] = "both"
    input_probs_path: str | None = None
    num_ensembles: int = 80
    n_vertices: int = 16
    graphs_per_ensemble: int | None = None
    latent_dim: int = 16
    hidden_dim: int = 256
    num_hidden: int = 6
    num_samples_ste: int = 25
    batch_size: int = 100
    num_iterations: int = 5000
    learning_rate: float = 0.00023995817958282425
    lr_decay: bool = False
    gamma_min: float = 0.2
    gamma_max: float = 0.2
    epsilon: float = 1e-3
    feature_flags: dict[str, bool] | None = None
    projection_dim: int = 4
    selection_metrics: list[str] = field(default_factory=lambda: ["all"])
    k_select: int = 100
    selection_objective: Literal["energy", "average"] = "average"
    selection_budget: int | None = None
    temperatures: list[float] = field(default_factory=lambda: DEFAULT_TEMPERATURES.copy())
    include_rigid: bool = True
    start_mode: Literal["anchors", "random", "dissimilar"] = "dissimilar"
    postprocess_strategy: Literal["default", "iterative_survival"] = "iterative_survival"
    sampling_budget: int = 100000
    n_top_matrices: int = 1000
    use_continuous_temp: bool = True
    temp_range: tuple[float, float] = (0.01, 10.0)
    rigid_prob: float = 0.1
    memory_efficient: bool = False
    chunk_size: int | None = None
    orca_path: str = "evaluate/orca/orca"
    output_dir: str = "data"
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    seed: int = 0
    direction_seed: int = 42
    n_jobs: int = -1

    def __post_init__(self):
        if self.graphs_per_ensemble is None:
            self.graphs_per_ensemble = int(np.ceil(TOTAL_GRAPHS / self.num_ensembles))
        if self.memory_efficient and self.chunk_size is None:
            bytes_per_graph = self.n_vertices * self.n_vertices * 4
            target_chunk_memory = 2 * 1024 * 1024 * 1024
            self.chunk_size = max(10, min(500, target_chunk_memory // bytes_per_graph))

    @property
    def sample_interval(self) -> int:
        samples_needed = int(np.ceil(self.graphs_per_ensemble / self.batch_size))
        return max(1, self.num_iterations // samples_needed)

    @property
    def samples_needed(self) -> int:
        return int(np.ceil(self.graphs_per_ensemble / self.batch_size))


# Shared module-level default configuration; run_unified_pipeline falls back
# to this instance when it is called without a config.
DEFAULT_CONFIG = UnifiedConfig()


class EnsembleLinear(nn.Module):
    """Linear layer with an independent weight matrix per ensemble member.

    Parameters are stored stacked along a leading ensemble axis:
    ``weight`` has shape ``(num_ensembles, in_features, out_features)`` and
    ``bias`` (when enabled) has shape ``(num_ensembles, 1, out_features)``.

    Args:
        in_features: Size of each input vector.
        out_features: Size of each output vector.
        num_ensembles: Number of independent members.
        bias: If False, no bias is created and ``self.bias`` is ``None``.
    """

    def __init__(self, in_features, out_features, num_ensembles, bias=True):
        super().__init__()
        self.num_ensembles = num_ensembles
        self.in_features = in_features
        self.out_features = out_features
        # torch.empty replaces the legacy torch.Tensor(sizes) constructor; the
        # storage is uninitialised either way and is filled by reset_parameters().
        self.weight = nn.Parameter(torch.empty(num_ensembles, in_features, out_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(num_ensembles, 1, out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw weight and bias from their uniform initialisation ranges."""
        # NOTE(review): kaiming_uniform_ sees the full 3-D tensor, so PyTorch's
        # fan-in rule counts the trailing axis as a receptive field and uses
        # fan_in = in_features * out_features. Likewise the bias bound below is
        # derived from weight[0] with shape (in, out), which yields
        # 1 / sqrt(out_features). Both are kept as-is so the initial scale (and
        # any hyper-parameters tuned against it) is unchanged -- confirm whether
        # a per-member fan_in = in_features was intended.
        nn.init.kaiming_uniform_(self.weight, a=5**0.5)
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[0])
            bound = 1 / (fan_in**0.5)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        """Apply every member's affine map to its own slice of ``input``.

        Args:
            input: Tensor of shape ``(num_ensembles, batch, in_features)``.

        Returns:
            Tensor of shape ``(num_ensembles, batch, out_features)``.
        """
        output = torch.bmm(input, self.weight)
        if self.bias is not None:
            # Out-of-place add: avoids mutating the bmm result in place.
            output = output + self.bias
        return output


class EnsembleGraphGenerator(nn.Module):
    """Ensemble of MLPs mapping latent codes to edge-probability matrices.

    Every ensemble member owns its own stack of ``EnsembleLinear`` layers plus
    a learned embedding per particle. The output of ``forward`` is, for each
    member and each batch element, a symmetric ``n_vertices x n_vertices``
    matrix of probabilities with a zero diagonal.
    """

    def __init__(self, num_ensembles, latent_dim, hidden_dim, num_hidden, n_vertices, num_particles):
        super().__init__()
        self.num_ensembles = num_ensembles
        self.n_vertices = n_vertices
        self.n_out = n_vertices * n_vertices
        # Learned per-member, per-particle offset added after the latent projection.
        self.id_embedding = nn.Parameter(torch.randn(num_ensembles, num_particles, hidden_dim))
        self.z_project = EnsembleLinear(latent_dim, hidden_dim, num_ensembles)
        hidden_stack = [
            EnsembleLinear(hidden_dim, hidden_dim, num_ensembles)
            for _ in range(num_hidden)
        ]
        self.layers = nn.ModuleList(hidden_stack)
        self.output_layer = EnsembleLinear(hidden_dim, self.n_out, num_ensembles)

    def forward(self, z, particle_ids):
        """Return edge probabilities of shape ``(E, B, n_vertices, n_vertices)``.

        Args:
            z: Latent codes whose first dimension is the batch size ``B``; the
                same codes are shared by all ``E`` members.
            particle_ids: Indices into the particle axis of ``id_embedding``.
        """
        n = self.n_vertices
        members = self.num_ensembles
        batch = z.shape[0]

        # Broadcast the shared latent codes to every member (no copy).
        shared_z = z.unsqueeze(0).expand(members, batch, -1)
        particle_offsets = self.id_embedding[:, particle_ids, :]
        hidden = F.relu(self.z_project(shared_z) + particle_offsets)
        for hidden_layer in self.layers:
            hidden = F.relu(hidden_layer(hidden))

        logits = self.output_layer(hidden).view(members, batch, n, n)
        # Average with the transpose so the resulting matrix is symmetric.
        logits = (logits + logits.transpose(-1, -2)) / 2

        # Multiplying by (1 - I) zeroes the diagonal entries.
        off_diagonal = 1 - torch.eye(n, device=logits.device).unsqueeze(0).unsqueeze(0)
        return torch.sigmoid(logits) * off_diagonal


def train_ensemble(
    config: UnifiedConfig,
    feature_config: FeatureConfig,
) -> tuple[list[list[np.ndarray]], list[list[np.ndarray]]]:
    """Train the generator ensemble and periodically snapshot its outputs.

    A single ``EnsembleGraphGenerator`` is optimised with Adam on the sum of
    the per-member losses returned by ``compute_coulomb_loss_ensemble``. The
    latent codes and particle ids are drawn once and reused every iteration.

    Args:
        config: Pipeline configuration (model sizes, iteration count, sampling
            cadence, loss hyper-parameters, device).
        feature_config: Feature configuration forwarded to the loss and to
            ``print_direction_info``.

    Returns:
        ``(collected_graphs, collected_probs)``. Both are lists indexed by
        ensemble member; each entry is a list of arrays of shape
        ``(batch_size, n_vertices, n_vertices)`` holding, respectively, the
        Bernoulli-sampled adjacency matrices and the probabilities they were
        sampled from.
    """
    device = config.device
    model = EnsembleGraphGenerator(
        num_ensembles=config.num_ensembles,
        latent_dim=config.latent_dim,
        hidden_dim=config.hidden_dim,
        num_hidden=config.num_hidden,
        n_vertices=config.n_vertices,
        num_particles=config.batch_size
    ).to(device)
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)
    z_fixed = torch.randn(config.batch_size, config.latent_dim, device=device)
    particle_ids = torch.arange(config.batch_size, device=device)
    collected_graphs: list[list[np.ndarray]] = [[] for _ in range(config.num_ensembles)]
    collected_probs: list[list[np.ndarray]] = [[] for _ in range(config.num_ensembles)]

    # Loop-invariant values: sample_interval is a computed property, so read it once.
    sample_interval = config.sample_interval
    last_iteration = config.num_iterations - 1

    print(f"Training ensemble of {config.num_ensembles} generators on {device}")
    print(f"Target: {config.graphs_per_ensemble} graphs per ensemble")
    print(f"Sampling every {sample_interval} iterations")
    print()
    print_direction_info(feature_config, config.num_ensembles, torch.device(device))

    pbar = tqdm(range(config.num_iterations), desc="Training")
    for iteration in pbar:
        # Linear ramp from gamma_min towards gamma_max over the run.
        gamma = config.gamma_min + (
            (config.gamma_max - config.gamma_min) * iteration / config.num_iterations
        )
        if config.lr_decay:
            effective_lr = config.learning_rate / (1 + gamma)
            for param_group in optimizer.param_groups:
                param_group['lr'] = effective_lr
        optimizer.zero_grad()
        adj_stacked = model(z_fixed, particle_ids)

        # Straight-through estimator: the forward value is a hard Bernoulli
        # sample, while gradients flow through the probabilities.
        adj_expanded = adj_stacked.unsqueeze(0).expand(config.num_samples_ste, -1, -1, -1, -1)
        adj_hard = torch.bernoulli(adj_expanded)
        adj_ste = adj_expanded + (adj_hard - adj_expanded).detach()

        ensemble_losses = []
        for ens_idx in range(config.num_ensembles):
            ens_samples = adj_ste[:, ens_idx].reshape(-1, config.n_vertices, config.n_vertices)
            loss = compute_coulomb_loss_ensemble(
                ens_samples,
                gamma=gamma,
                ensemble_id=ens_idx,
                epsilon=config.epsilon,
                just_dist=False,
                num_ensembles=config.num_ensembles,
                feature_config=feature_config,
                proj_dim=config.projection_dim,
            )
            ensemble_losses.append(loss)

        total_loss = torch.stack(ensemble_losses).sum()
        total_loss.backward()
        optimizer.step()

        # Drawn every iteration (not only when a snapshot is taken) so the
        # random-number stream, and hence seeded runs, stay unchanged.
        with torch.no_grad():
            adj_discrete = torch.bernoulli(adj_stacked)

        # Iteration 0 is never snapshotted; the final iteration always is
        # (subject to the per-ensemble cap below).
        is_sample_step = iteration % sample_interval == 0 or iteration == last_iteration
        if iteration != 0 and is_sample_step:
            for ens_idx in range(config.num_ensembles):
                if len(collected_graphs[ens_idx]) * config.batch_size < config.graphs_per_ensemble:
                    collected_graphs[ens_idx].append(adj_discrete[ens_idx].cpu().numpy())
                    collected_probs[ens_idx].append(adj_stacked[ens_idx].detach().cpu().numpy())

        # One host sync on the already-summed loss instead of one .item() call
        # per ensemble member.
        avg_loss = total_loss.item() / config.num_ensembles
        pbar.set_postfix({"loss": f"{avg_loss:.4f}"})

    return collected_graphs, collected_probs


def _write_probs_chunk(output_dir: Path, chunk_idx: int, chunk: list[np.ndarray]) -> Path:
    """Pickle one chunk of probability matrices and return the file written."""
    chunk_path = output_dir / f"probs_chunk_{chunk_idx}.pkl"
    with open(chunk_path, "wb") as f:
        pickle.dump(chunk, f)
    return chunk_path


def _save_probs_chunked(
    collected_graphs: list[list[np.ndarray]],
    collected_probs: list[list[np.ndarray]],
    config: UnifiedConfig,
    output_dir: Path,
) -> list[np.ndarray] | str:
    """Stream probability matrices to chunk files, then merge them if small enough."""
    print("Memory-efficient mode: saving probability matrices incrementally...")
    probs_path = output_dir / f"diverse_probs_{config.n_vertices}.pkl"
    chunk_size = config.chunk_size or 100
    per_ensemble_limit = config.graphs_per_ensemble
    current_chunk = []
    all_chunks = []
    total_saved = 0

    for ens_idx in tqdm(range(config.num_ensembles), desc="Processing ensembles"):
        prob_batches = collected_probs[ens_idx]
        if not prob_batches:
            print(f"WARNING: Ensemble {ens_idx} has no collected probs")
            continue
        ens_count = 0
        for batch in prob_batches:
            if ens_count >= per_ensemble_limit:
                break
            # Take at most the number of matrices this ensemble may still contribute.
            for prob_matrix in batch[: per_ensemble_limit - ens_count]:
                current_chunk.append(prob_matrix)
                ens_count += 1
                total_saved += 1
                if len(current_chunk) >= chunk_size:
                    all_chunks.append(
                        _write_probs_chunk(output_dir, len(all_chunks), current_chunk)
                    )
                    current_chunk = []
        # Drop the references held in the caller's lists once the ensemble is
        # written (this mutates the arguments in place).
        collected_graphs[ens_idx] = []
        collected_probs[ens_idx] = []

    if current_chunk:
        # The trailing partial chunk uses the same naming scheme as full chunks.
        all_chunks.append(_write_probs_chunk(output_dir, len(all_chunks), current_chunk))
        current_chunk = []

    # Size estimate assumes 4 bytes per matrix entry.
    bytes_per_prob = config.n_vertices * config.n_vertices * 4
    total_bytes = bytes_per_prob * total_saved
    can_fit_in_memory = total_bytes < 10 * 1024 * 1024 * 1024

    if not can_fit_in_memory:
        print(f"Estimated size: {total_bytes / (1024**3):.1f} GB - keeping as chunks")
        chunks_meta = {
            "chunk_paths": [str(p) for p in all_chunks],
            "total_count": total_saved,
            "n_vertices": config.n_vertices,
        }
        chunks_meta_path = output_dir / f"diverse_probs_{config.n_vertices}_chunks.pkl"
        with open(chunks_meta_path, "wb") as f:
            pickle.dump(chunks_meta, f)
        print(f"Saved {total_saved} probability matrices in {len(all_chunks)} chunks")
        print(f"Chunks metadata: {chunks_meta_path}")
        return str(chunks_meta_path)

    print(f"Combining {len(all_chunks)} chunks into final file...")
    all_probs_combined = []
    for chunk_path in tqdm(all_chunks, desc="Loading chunks"):
        with open(chunk_path, "rb") as f:
            chunk_data = pickle.load(f)
        all_probs_combined.extend(chunk_data)
        # The chunk file is only a staging area once its content is merged.
        chunk_path.unlink()
    with open(probs_path, "wb") as f:
        pickle.dump(all_probs_combined, f)
    print(f"Saved {total_saved} probability matrices to {probs_path}")
    return all_probs_combined


def save_intermediate(
    collected_graphs: list[list[np.ndarray]],
    collected_probs: list[list[np.ndarray]],
    config: UnifiedConfig,
) -> tuple[list[nx.Graph], list[np.ndarray] | str]:
    """Persist the training snapshots under ``config.output_dir``.

    Each ensemble contributes at most ``config.graphs_per_ensemble`` entries.

    Standard mode writes ``diverse_graphs_{n}.pkl`` (a list of networkx graphs)
    and, when any probabilities exist, ``diverse_probs_{n}.pkl`` (one stacked
    array).

    Memory-efficient mode writes only probabilities, streamed through
    ``probs_chunk_{i}.pkl`` files. If the estimated total is below 10 GiB the
    chunks are merged into ``diverse_probs_{n}.pkl`` and deleted; otherwise
    they are kept and described by ``diverse_probs_{n}_chunks.pkl``. In this
    mode the per-ensemble entries of both input lists are emptied in place.

    Args:
        collected_graphs: Per-ensemble lists of sampled adjacency batches.
        collected_probs: Per-ensemble lists of probability batches.
        config: Pipeline configuration.

    Returns:
        ``(graphs, probs)``. ``graphs`` is empty in memory-efficient mode.
        ``probs`` is a list of probability matrices, or -- when the chunks are
        kept on disk -- the path of the chunk-metadata file as a string.
    """
    output_dir = Path(config.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    if config.memory_efficient:
        return [], _save_probs_chunked(collected_graphs, collected_probs, config, output_dir)

    all_graphs = []
    all_probs = []
    for ens_idx in range(config.num_ensembles):
        batches = collected_graphs[ens_idx]
        prob_batches = collected_probs[ens_idx]
        if not batches:
            print(f"WARNING: Ensemble {ens_idx} has no collected graphs")
            continue
        full_adj = np.concatenate(batches, axis=0)
        full_probs = np.concatenate(prob_batches, axis=0)
        K = config.graphs_per_ensemble
        if len(full_adj) > K:
            full_adj = full_adj[:K]
            full_probs = full_probs[:K]
        all_graphs.extend(nx.from_numpy_array(adj) for adj in full_adj)
        all_probs.append(full_probs)

    graph_path = output_dir / f"diverse_graphs_{config.n_vertices}.pkl"
    with open(graph_path, "wb") as f:
        pickle.dump(all_graphs, f)
    print(f"Saved {len(all_graphs)} graphs to {graph_path}")

    if all_probs:
        combined_probs = np.concatenate(all_probs, axis=0)
        probs_path = output_dir / f"diverse_probs_{config.n_vertices}.pkl"
        with open(probs_path, "wb") as f:
            pickle.dump(combined_probs, f)
        print(f"Saved {len(combined_probs)} probability matrices to {probs_path}")
    else:
        combined_probs = []

    return all_graphs, list(combined_probs)


def run_all_selections(
    prob_matrices: list[np.ndarray] | str,
    config: UnifiedConfig,
) -> dict[str, SelectionResult]:
    """Run the selection stage once per requested metric and save each result.

    If ``"all"`` appears in ``config.selection_metrics`` the four built-in
    metrics are used; otherwise the list is taken as given. For every metric
    the selected graphs are pickled to
    ``<output_dir>/selected/unified_<metric>_<objective>.pkl``.

    Args:
        prob_matrices: Probability matrices, or a string (as returned by
            ``save_intermediate`` when chunks are kept on disk).
        config: Pipeline configuration.

    Returns:
        Mapping from metric name to its ``SelectionResult``.
    """
    requested = config.selection_metrics
    if "all" not in requested:
        metrics = requested
    else:
        metrics = ["gcd", "netlsd_heat", "netlsd_wave", "portrait_div"]

    banner = "=" * 60
    results = {}
    for metric in metrics:
        print(f"\n{banner}")
        print(f"RUNNING SELECTION FOR: {metric}")
        print(banner)

        selection_config = SelectionConfig(
            k_select=config.k_select,
            metric=metric,
            objective=config.selection_objective,
            strategy=config.postprocess_strategy,
            total_budget=config.selection_budget,
            temperatures=config.temperatures,
            include_rigid=config.include_rigid,
            start_mode=config.start_mode,
            sampling_budget=config.sampling_budget,
            n_top_matrices=config.n_top_matrices,
            use_continuous_temp=config.use_continuous_temp,
            temp_range=config.temp_range,
            rigid_prob=config.rigid_prob,
            orca_path=config.orca_path,
            n_jobs=config.n_jobs,
            random_seed=config.seed,
            verbose=True,
        )

        if not config.memory_efficient:
            result = run_selection(prob_matrices, selection_config)
        else:
            # chunk_size is attached after construction rather than passed in.
            selection_config.chunk_size = config.chunk_size
            print(f"Using MEMORY-EFFICIENT mode (chunk_size={config.chunk_size})")
            # Imported lazily: only needed on the memory-efficient path.
            from postprocess_large import run_selection_large
            result = run_selection_large(prob_matrices, selection_config)

        results[metric] = result

        selected_dir = Path(config.output_dir) / "selected"
        selected_dir.mkdir(parents=True, exist_ok=True)
        output_path = selected_dir / f"unified_{metric}_{config.selection_objective}.pkl"
        with open(output_path, "wb") as f:
            pickle.dump(result.graphs, f)
        print(f"Saved {len(result.graphs)} graphs to {output_path}")

    return results


def _load_input_probs(config: UnifiedConfig) -> list[np.ndarray] | str:
    """Load and validate the probabilities named by ``config.input_probs_path``."""
    print("\n=== Loading Probabilities ===")
    print(f"Input path: {config.input_probs_path}")
    # SECURITY: unpickling runs arbitrary code embedded in the file; only load
    # files produced by a trusted run of this pipeline.
    with open(config.input_probs_path, "rb") as f:
        loaded = pickle.load(f)

    if isinstance(loaded, dict) and "chunk_paths" in loaded:
        # Chunk metadata as written by save_intermediate when the matrices were
        # kept on disk. The selector is handed the metadata path, exactly as in
        # the combined ("both") mode.
        if not config.memory_efficient:
            raise ValueError(
                f"{config.input_probs_path} contains chunk metadata; "
                f"set memory_efficient=True to post-process chunked probabilities"
            )
        meta_vertices = loaded.get("n_vertices")
        if meta_vertices is not None and meta_vertices != config.n_vertices:
            raise ValueError(
                f"Chunk metadata has n_vertices={meta_vertices}, "
                f"but config.n_vertices={config.n_vertices}"
            )
        print(
            f"Loaded chunk metadata: {loaded.get('total_count')} probability matrices "
            f"in {len(loaded['chunk_paths'])} chunks"
        )
        return str(config.input_probs_path)

    all_probs = list(loaded) if isinstance(loaded, np.ndarray) else loaded
    if len(all_probs) == 0:
        raise ValueError(f"No probability matrices found in {config.input_probs_path}")
    print(f"Loaded {len(all_probs)} probability matrices")
    print(f"Shape: {all_probs[0].shape}")
    if all_probs[0].shape != (config.n_vertices, config.n_vertices):
        raise ValueError(
            f"Loaded probs have shape {all_probs[0].shape}, "
            f"but config.n_vertices={config.n_vertices}"
        )
    return all_probs


def run_unified_pipeline(config: UnifiedConfig | None = None) -> dict[str, SelectionResult] | None:
    """Run training, intermediate saving and selection according to ``config.mode``.

    * ``"both"``: train, save, then select.
    * ``"train_only"``: train and save, then return ``None``.
    * ``"postprocess_only"``: load probabilities from ``config.input_probs_path``
      and select. The file may hold an array or list of matrices or, with
      ``memory_efficient=True``, the chunk metadata written by
      ``save_intermediate``.

    Args:
        config: Pipeline configuration; ``DEFAULT_CONFIG`` is used when omitted.

    Returns:
        Mapping from metric name to ``SelectionResult``, or ``None`` in
        ``"train_only"`` mode.

    Raises:
        ValueError: In ``"postprocess_only"`` mode when the input path is
            unset or missing, the file is empty, the matrices do not match
            ``config.n_vertices``, or chunk metadata is supplied without
            ``memory_efficient=True``.
    """
    if config is None:
        config = DEFAULT_CONFIG

    if config.mode == "postprocess_only":
        if config.input_probs_path is None:
            raise ValueError("postprocess_only mode requires input_probs_path to be set")
        if not Path(config.input_probs_path).exists():
            raise ValueError(f"input_probs_path does not exist: {config.input_probs_path}")

    # NOTE(review): a seed of 0 (the default) leaves every RNG unseeded, so
    # default runs are not reproducible -- confirm that 0 is meant as "no seed".
    if config.seed != 0:
        torch.manual_seed(config.seed)
        random.seed(config.seed)
        np.random.seed(config.seed)

    feature_config = FeatureConfig(config.feature_flags, direction_seed=config.direction_seed)

    print("=" * 70)
    print("UNIFIED GRAPH GENERATION + SELECTION PIPELINE")
    print("=" * 70)
    print(f"Mode: {config.mode}")
    if config.memory_efficient:
        print(f"Memory-efficient: YES (chunk_size={config.chunk_size})")
    else:
        print("Memory-efficient: NO")

    if config.mode in ["both", "train_only"]:
        print("\n=== Generation Config ===")
        print(f"Ensembles: {config.num_ensembles}")
        print(f"Vertices: {config.n_vertices}")
        print(f"Graphs per ensemble: {config.graphs_per_ensemble}")
        print(f"Total graphs (budget): {TOTAL_GRAPHS}")
        print(f"Iterations: {config.num_iterations}")
        print(f"Device: {config.device}")
        print("\n=== Feature Config ===")
        feature_config.print_info()

        print("\n" + "=" * 70)
        print("PHASE 1: TRAINING ENSEMBLE")
        print("=" * 70)
        collected_graphs, collected_probs = train_ensemble(config, feature_config)

        print("\n" + "=" * 70)
        print("PHASE 2: SAVING INTERMEDIATE RESULTS")
        print("=" * 70)
        all_graphs, all_probs = save_intermediate(collected_graphs, collected_probs, config)

        if config.mode == "train_only":
            print("\n" + "=" * 70)
            print("TRAIN ONLY MODE - COMPLETE")
            print("=" * 70)
            # A string here is the chunk-metadata path, not a list of matrices.
            if isinstance(all_probs, str):
                print("Saved graphs and probability matrices (chunked)")
            else:
                print(f"Saved {len(all_graphs)} graphs and {len(all_probs)} probability matrices")
            print(f"Output directory: {config.output_dir}")
            print("\nDone!")
            return None

    if config.mode == "postprocess_only":
        all_probs = _load_input_probs(config)

    print("\n=== Selection Config ===")
    print(f"Metrics: {config.selection_metrics}")
    print(f"K select: {config.k_select}")
    print(f"Objective: {config.selection_objective}")
    print(f"Temperatures: {len(config.temperatures)} values from {min(config.temperatures)} to {max(config.temperatures)}")
    print(f"Include rigid: {config.include_rigid}")
    print(f"Start mode: {config.start_mode}")
    if config.memory_efficient:
        print(f"Memory-efficient: YES (chunk_size={config.chunk_size})")
    print()

    print("=" * 70)
    print("PHASE 3: SELECTION")
    print("=" * 70)
    results = run_all_selections(all_probs, config)

    print("\n" + "=" * 70)
    print("FINAL SUMMARY")
    print("=" * 70)
    for metric, result in results.items():
        print(f"\n{metric}:")
        print(f"  Diversity: {result.diversity:.6f}")
        print(f"  Energy: {result.energy:.6f}")
        print(f"  Pool size: {result.pool_size}")
        print(f"  Edge stats: mean={result.edge_mean:.1f}, std={result.edge_std:.1f}, "
              f"range=[{result.edge_min}, {result.edge_max}]")
    print("\nDone!")
    return results


if __name__ == "__main__":
    # Full run: train the ensemble, then select diverse subsets in one go.
    config = UnifiedConfig(
        mode="both",
        output_dir="data",
        n_vertices=16,
        num_ensembles=80,
        num_iterations=5000,
        k_select=100,
        selection_metrics=["all"],
        selection_objective="average",
    )

    # Alternative configurations -- uncomment one to replace the config above.
    #
    # Selection only, starting from previously saved probability matrices.
    # Note that selection results are written under "<output_dir>/selected".
    # config = UnifiedConfig(
    #     mode="postprocess_only",
    #     input_probs_path="data/diverse_probs_16.pkl",
    #     n_vertices=16,
    #     selection_metrics=["all"],
    #     k_select=100,
    #     output_dir="data/selected",
    # )
    #
    # Training only; the probabilities are saved for later post-processing.
    # config = UnifiedConfig(
    #     mode="train_only",
    #     num_ensembles=80,
    #     n_vertices=16,
    #     num_iterations=5000,
    #     output_dir="data",
    # )

    results = run_unified_pipeline(config)
