"""
Minimal LAMP: Linear Additive Model Prediction for 3D Shape Generation
Simplified implementation for controlled 3D car generation
"""

import os
import sys
import json
import argparse
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from scipy.optimize import minimize
from typing import Dict, List, Tuple, Optional
import traceback

# Local imports
from models.mlp_models import MLP3D
from mesh_gen.sdf_meshing import create_mesh


def detect_experiment_column(df: pd.DataFrame) -> str:
    """Pick the column that names each experiment.

    Preference order:
      1. the first column whose name contains "experiment" (case-insensitive);
      2. otherwise the first column whose dtype is not numeric;
      3. otherwise the very first column.
    """
    not_found = object()
    choice = next((c for c in df.columns if 'experiment' in c.lower()), not_found)
    if choice is not_found:
        choice = next(
            (c for c in df.columns if not pd.api.types.is_numeric_dtype(df[c])),
            not_found,
        )
    return df.columns[0] if choice is not_found else choice


def pick_param_columns(df: pd.DataFrame, experiment_col: str) -> List[str]:
    """Return the numeric columns to use as shape parameters, in CSV order.

    Excluded: the experiment column itself, and any column whose lower-cased
    name contains "std" or "average" (presumably summary statistics rather
    than inputs).
    """
    def _is_param(name) -> bool:
        if name == experiment_col:
            return False
        lowered = name.lower()
        if any(tag in lowered for tag in ("std", "average")):
            return False
        return pd.api.types.is_numeric_dtype(df[name])

    return [name for name in df.columns if _is_param(name)]


def flatten_weights(path: str) -> Tuple[np.ndarray, dict]:
    """Load a saved checkpoint and flatten it into a single 1-D array.

    Args:
        path: File readable by ``torch.load``. If it holds a dict (e.g. a
            state dict), the tensor values are flattened and concatenated in
            the dict's iteration order; non-tensor values are skipped.
            Otherwise the loaded object's ``.flatten()`` is used.

    Returns:
        ``(flat, data)``: ``flat`` is a 1-D NumPy array and ``data`` is the
        raw object returned by ``torch.load``.

    Raises:
        RuntimeError: if loading/flattening fails or the dict holds no
            tensors. The underlying exception is chained as ``__cause__``.
    """
    try:
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        data = torch.load(path, map_location="cpu")
        if isinstance(data, dict):
            tensors = [v.flatten() for v in data.values() if isinstance(v, torch.Tensor)]
            if not tensors:
                # torch.cat([]) would fail with an unhelpful message.
                raise ValueError("checkpoint dict contains no tensors")
            flat = torch.cat(tensors)
        else:
            # Already a (flattenable) tensor
            flat = data.flatten()
        # detach() so tensors saved with requires_grad=True can still become NumPy.
        return flat.detach().cpu().numpy(), data
    except Exception as e:
        raise RuntimeError(f"Failed to load weights from {path}: {e}") from e


def assign_flat_params(model: torch.nn.Module, flat_tensor: torch.Tensor):
    """Copy a flat tensor into ``model.parameters()`` in iteration order.

    Consecutive slices of ``flat_tensor`` are reshaped to each parameter's
    shape and copied in place (``copy_`` converts dtype/device as needed).

    Args:
        model: Module whose parameters are overwritten.
        flat_tensor: Tensor holding at least as many elements as the model
            has parameters (it is treated as 1-D).

    Raises:
        ValueError: if ``flat_tensor`` has fewer elements than the model
            requires. Surplus trailing elements are reported and ignored.
    """
    params = list(model.parameters())
    expected = sum(p.numel() for p in params)
    available = flat_tensor.numel()
    if available < expected:
        raise ValueError(
            f"Flat weight vector has {available} elements but the model "
            f"requires {expected}; check that the architecture matches."
        )
    if available > expected:
        # Left unused; this may indicate the flattened source included extra
        # tensors (e.g. buffers) or came from a different architecture.
        print(f"⚠️  Flat weight vector has {available - expected} unused trailing elements")

    flat = flat_tensor.reshape(-1)
    idx = 0
    with torch.no_grad():
        for param in params:
            numel = param.numel()
            param.copy_(flat[idx:idx + numel].reshape(param.shape))
            idx += numel


def parse_targets(s: str) -> Dict[str, float]:
    """Turn ``"name1=v1,name2=v2"`` into ``{"name1": v1, "name2": v2}``.

    Entries are split on commas and stripped. Blank entries, entries without
    ``=``, and entries with an empty name are ignored. Only the first ``=``
    separates name from value. A value ``float()`` rejects is reported and
    skipped. A repeated name keeps its last value.
    """
    parsed: Dict[str, float] = {}
    if not s:
        return parsed
    for chunk in map(str.strip, s.split(",")):
        if "=" not in chunk:
            continue
        name, _, raw = chunk.partition("=")
        name = name.strip()
        raw = raw.strip()
        if not name:
            continue
        try:
            parsed[name] = float(raw)
        except ValueError:
            print(f"Warning: Could not parse {raw} as float for {name}")
    return parsed


def load_experiments(df: pd.DataFrame, logs_root: str, weight_file: str, 
                    experiment_col: str, param_cols: List[str]) -> Tuple[np.ndarray, np.ndarray, List[str]]:
    """Collect parameter rows and flattened weights for every loadable experiment.

    For each CSV row, weights are read from
    ``<logs_root>/<str(row[experiment_col])>/<weight_file>`` via
    ``flatten_weights``. Rows whose file is missing, or whose weights or
    parameter values fail to load/convert, are skipped. The skips are
    reported instead of being dropped silently.

    Returns:
        ``(X, W, valid_exps)``:
        - ``X`` is ``[N, P]``, with columns ordered like ``param_cols``.
        - ``W`` is ``[N, D]`` flattened weights.
        - ``valid_exps`` lists the N experiment names in row order.

    Raises:
        RuntimeError: if no experiment could be loaded.
        ValueError: if loaded weight vectors have differing lengths.
    """
    weights_list, param_rows, valid_exps = [], [], []
    missing, failed = [], []

    for _, row in df.iterrows():
        exp = str(row[experiment_col])
        weight_path = os.path.join(logs_root, exp, weight_file)

        if not os.path.exists(weight_path):
            missing.append(exp)
            continue

        try:
            flat_weights, _ = flatten_weights(weight_path)
            param_values = [float(row[col]) for col in param_cols]
        except Exception as e:
            # Best-effort: one bad checkpoint/row must not abort the run,
            # but the reason is kept so it can be reported below.
            failed.append((exp, str(e)))
            continue

        weights_list.append(flat_weights)
        param_rows.append(param_values)
        valid_exps.append(exp)

    if missing:
        print(f"⚠️  Skipped {len(missing)} experiment(s) with no '{weight_file}' under {logs_root}")
    for exp, reason in failed:
        print(f"⚠️  Skipped experiment {exp}: {reason}")

    if not weights_list:
        raise RuntimeError("No valid experiments with loadable weights found.")

    sizes = {w.size for w in weights_list}
    if len(sizes) > 1:
        raise ValueError(
            f"Weight vectors have inconsistent lengths {sorted(sizes)}; "
            "all experiments must share one architecture."
        )

    X = np.vstack(param_rows)  # [N, P] parameters
    W = np.vstack(weights_list)  # [N, D] weights

    n_nonfinite = int((~np.isfinite(X)).any(axis=1).sum())
    if n_nonfinite:
        print(f"⚠️  {n_nonfinite} experiment(s) have non-finite (NaN/inf) parameter values")

    print(f"✅ Loaded {len(valid_exps)} experiments")
    print(f"   Parameter matrix: {X.shape}")
    print(f"   Weight matrix: {W.shape}")

    return X, W, valid_exps


def solve_mixing_weights(X: np.ndarray, targets: Dict[str, float], param_cols: List[str]) -> np.ndarray:
    """Find mixing coefficients ``alpha`` whose weighted parameters hit ``targets``.

    Minimizes ``||X_sub.T @ alpha - y||^2`` subject to ``sum(alpha) == 1``
    with SLSQP. ``X_sub`` holds only the targeted columns of ``X`` and ``y``
    the requested values. Alpha is not bounded, so negative or >1
    coefficients (extrapolation) are possible.

    Target names not in ``param_cols`` are reported and ignored. An
    experiment with a non-finite value in any targeted column would make the
    objective NaN, so such rows are fixed at ``alpha = 0`` and the
    optimization runs over the remaining rows.

    Args:
        X: ``[N, P]`` parameter matrix, columns ordered like ``param_cols``.
        targets: mapping from parameter name to desired value.
        param_cols: names of the columns of ``X``.

    Returns:
        ``alpha`` of shape ``[N]`` (zeros for excluded rows).

    Raises:
        ValueError: if no target matches a column, a matched target value is
            not finite, or no row has finite values for all targeted columns.
    """
    # Find constrained parameters
    constrained_params = [p for p in param_cols if p in targets]
    if not constrained_params:
        raise ValueError("No target parameters match CSV columns")

    unknown = sorted(set(targets) - set(param_cols))
    if unknown:
        print(f"⚠️  Ignoring targets not found among parameter columns: {unknown}")

    # Build constraint matrix and target vector
    indices = [param_cols.index(p) for p in constrained_params]
    X_all = np.asarray(X, dtype=float)[:, indices]  # [N, K]
    y = np.array([targets[p] for p in constrained_params], dtype=float)  # [K]
    if not np.all(np.isfinite(y)):
        raise ValueError(f"Target values must be finite, got {dict(zip(constrained_params, y.tolist()))}")

    N = X_all.shape[0]
    valid = np.all(np.isfinite(X_all), axis=1)
    if not valid.any():
        raise ValueError("No experiment has finite values for all targeted parameters")
    if not valid.all():
        print(f"⚠️  Excluding {int((~valid).sum())} experiment(s) with non-finite target-parameter values")
    X_sub = X_all[valid]  # [M, K]
    M = X_sub.shape[0]

    def objective(alpha):
        residual = X_sub.T @ alpha - y
        return float(residual @ residual)

    def gradient(alpha):
        # d/d(alpha) ||X_sub.T alpha - y||^2 = 2 X_sub (X_sub.T alpha - y)
        return 2.0 * (X_sub @ (X_sub.T @ alpha - y))

    # Constraint: sum of alphas = 1 (with its constant Jacobian)
    constraint = {
        'type': 'eq',
        'fun': lambda a: np.sum(a) - 1.0,
        'jac': lambda a: np.ones_like(a),
    }
    alpha_init = np.full(M, 1.0 / M)

    result = minimize(
        objective, alpha_init,
        jac=gradient,
        method='SLSQP',
        constraints=constraint,
        options={'maxiter': 10000, 'ftol': 1e-10}
    )

    if not result.success:
        print(f"⚠️  Optimization warning: {result.message}")

    # Scatter back to one coefficient per input row.
    alpha = np.zeros(N)
    alpha[valid] = result.x

    print("✅ Optimization completed:")
    print(f"   Iterations: {result.nit}")
    print(f"   Final error: {result.fun:.6f}")
    print(f"   ||alpha||₂: {np.linalg.norm(alpha):.6f}")

    return alpha


def mix_weights(W: np.ndarray, alpha: np.ndarray) -> np.ndarray:
    """Return the alpha-weighted combination of weight vectors, ``alpha @ W``.

    Computed in float32 with PyTorch, on CUDA when available. ``W`` holds
    every experiment's full weight vector, so the GPU can run out of memory.
    In that case the product is recomputed on the CPU instead of aborting.

    Args:
        W: ``[N, D]`` matrix of flattened weights.
        alpha: ``[N]`` mixing coefficients.

    Returns:
        ``[D]`` float32 NumPy array.
    """
    def _mix_on(device: torch.device) -> np.ndarray:
        W_t = a_t = None
        try:
            W_t = torch.as_tensor(W, dtype=torch.float32, device=device)
            a_t = torch.as_tensor(alpha, dtype=torch.float32, device=device)
            return (a_t @ W_t).detach().cpu().numpy()
        finally:
            # Release GPU buffers even when the matmul raises.
            del W_t, a_t
            if device.type == "cuda":
                torch.cuda.empty_cache()

    if torch.cuda.is_available():
        # torch.cuda.OutOfMemoryError (a RuntimeError subclass) exists in newer torch.
        oom_error = getattr(torch.cuda, "OutOfMemoryError", RuntimeError)
        try:
            return _mix_on(torch.device("cuda"))
        except oom_error:
            print("⚠️  CUDA out of memory while mixing weights; retrying on CPU")
    return _mix_on(torch.device("cpu"))


def generate_mesh(mixed_weights: np.ndarray, output_path: str, resolution: int = 256):
    """Load ``mixed_weights`` into a fresh MLP3D and extract a mesh from it.

    The network is built with one fixed configuration: ``out_size=1``,
    hidden layers ``[1024, 1024, 1024]``, ``use_leaky_relu=False`` and
    ``multires=4``. ``mixed_weights`` therefore has to fit that model's
    parameters. Mesh extraction is delegated to ``create_mesh`` with
    ``filename=output_path``, ``N=resolution`` and ``level=0.0``.

    Returns:
        ``(vertices, faces)``, the first two values returned by
        ``create_mesh``. Its third return value is discarded.
    """
    print(f"🧱 Generating mesh at resolution {resolution}^3...")

    net = MLP3D(
        out_size=1,
        hidden_neurons=[1024, 1024, 1024],
        use_leaky_relu=False,
        multires=4,
    )
    assign_flat_params(net, torch.tensor(mixed_weights, dtype=torch.float32))
    net.eval()

    mesh_out = create_mesh(net, filename=output_path, N=resolution, level=0.0)
    vertices, faces, _unused_sdf = mesh_out
    return vertices, faces


def print_parameter_stats(X: np.ndarray, param_cols: List[str]):
    """Print range, mean ± std and median for each parameter column.

    Args:
        X: ``[N, P]`` parameter matrix with columns ordered like ``param_cols``.
        param_cols: column names used as labels.

    The std is NumPy's default population std (``ddof=0``).
    """
    print(f"\n📊 Parameter Statistics ({X.shape[0]} experiments, {len(param_cols)} parameters):")
    print("=" * 80)

    for i, param in enumerate(param_cols):
        values = X[:, i]
        print(f"{param}:")
        print(f"  Range: [{values.min():.4f}, {values.max():.4f}]")
        print(f"  Mean ± Std: {values.mean():.4f} ± {values.std():.4f}")
        # ndarray has no .median() method; np.median computes it.
        print(f"  Median: {np.median(values):.4f}")
        print()


def main():
    """Command-line entry point: blend trained networks to hit target parameters.

    Steps:
      1. Read the CSV and choose the experiment and parameter columns.
      2. Parse ``--targets`` and load each experiment's flattened weights.
      3. Solve for mixing coefficients and blend the weights.
      4. Build a mesh from the blend.
      5. Write ``<output_name>_results.json`` into ``--output_dir``.

    Any exception inside the pipeline is printed and the process exits with
    status 1. ``--verbose`` also prints the traceback and parameter
    statistics.
    """
    parser = argparse.ArgumentParser(description="Minimal LAMP: Controlled 3D Shape Generation")

    # Required arguments
    parser.add_argument("--csv", required=True, help="CSV file with experiments and parameters")
    parser.add_argument("--logs_root", default="../HyperDiffusion-main/logs", help="Root directory with trained models")
    parser.add_argument("--targets", required=True, help="Target parameters as 'param1=val1,param2=val2'")

    # Optional arguments
    parser.add_argument("--weight_file", default="occ_model_jitter_0_model_final.pt", help="Weight file name")
    parser.add_argument("--output_dir", default="output", help="Output directory")
    parser.add_argument("--output_name", default="generated_car", help="Output file name")
    parser.add_argument("--resolution", type=int, default=256, help="Mesh resolution")
    parser.add_argument("--verbose", action="store_true", help="Verbose output")

    args = parser.parse_args()

    # Ensure output directory exists
    os.makedirs(args.output_dir, exist_ok=True)

    try:
        print("🎯 Minimal LAMP: Neural Weight Mixing for 3D Generation")
        print("=" * 60)

        # Load data
        print(f"📂 Loading data from {args.csv}...")
        df = pd.read_csv(args.csv)
        exp_col = detect_experiment_column(df)
        param_cols = pick_param_columns(df, exp_col)

        print(f"   Experiment column: {exp_col}")
        print(f"   Parameter columns: {len(param_cols)}")

        # Parse targets
        targets = parse_targets(args.targets)
        print(f"   Target parameters: {targets}")

        # Load experiments
        X, W, valid_experiments = load_experiments(
            df, args.logs_root, args.weight_file, exp_col, param_cols
        )

        if args.verbose:
            print_parameter_stats(X, param_cols)

        # Solve for mixing weights
        print("⚡ Solving for optimal mixing weights...")
        alpha = solve_mixing_weights(X, targets, param_cols)

        # Mix weights
        print("🔄 Mixing neural network weights...")
        mixed_weights = mix_weights(W, alpha)
        # W is the full [N, D] weight matrix; free it before meshing.
        del W

        # Generate mesh
        output_base = os.path.join(args.output_dir, args.output_name)
        vertices, faces = generate_mesh(mixed_weights, output_base, args.resolution)
        num_vertices = len(vertices) if vertices is not None else 0
        num_faces = len(faces) if faces is not None else 0

        # Save results
        results = {
            "targets": targets,
            "num_experiments": len(valid_experiments),
            "parameter_dims": len(param_cols),
            "weight_dims": len(mixed_weights),
            "mesh_vertices": num_vertices,
            "mesh_faces": num_faces,
            "resolution": args.resolution,
            "output_files": [f"{args.output_name}.ply"],
            # Provenance: which experiments were blended and with what weights.
            "experiments": valid_experiments,
            "mixing_weights": [float(a) for a in alpha],
        }

        results_path = os.path.join(args.output_dir, f"{args.output_name}_results.json")
        with open(results_path, 'w') as f:
            json.dump(results, f, indent=2)

        print("\n✅ Generation completed successfully!")
        print(f"   Mesh: {output_base}.ply")
        print(f"   Results: {results_path}")
        print(f"   Vertices: {num_vertices}")
        print(f"   Faces: {num_faces}")

    except Exception as e:
        print(f"❌ Error: {e}")
        if args.verbose:
            traceback.print_exc()
        sys.exit(1)


if __name__ == '__main__':
    # Run the CLI only when this file is executed directly, not when imported.
    main()
