#!/usr/bin/env python3

from __future__ import annotations

import json
import time
from pathlib import Path
from collections import defaultdict
from typing import Dict, List, Tuple
from custom_optimizers import DFA, RandomNN
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# Device shared by the whole module: CUDA when available, otherwise CPU.
# Models and data batches are moved here before use.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---------------------------------------------------------------------------
# helpers.py – only import symbols that actually exist there
# ---------------------------------------------------------------------------
from helpers import (
    DEFAULT_CFG,
    pretty_cfg_diff,
    ensure_dirs,
    choose_optimizer,
    cosine_similarity,
    add_param_noise,
    add_input_noise,
    hebbian_update,
    pre_synaptic_inputs,
    get_dataloaders_model_crit_model,  # model + loaders + criterion in one call
    get_frozen_mask,
    SmallCNN,
)




# ---------------------------------------------------------------------------
# Alignment helpers
# ---------------------------------------------------------------------------

def layer_alignments(
    hebbian_updates: List[torch.Tensor],
    model: nn.Module,
    cfg: Dict,
) -> Tuple[Dict[str, List[float]], Dict[str, List[float]]]:
    """Compare Hebbian weight updates with the (negated) gradient, per layer.

    Aligned layers are ``model.layers`` for models that expose it
    (MLP/Transformer) and ``fc1``/``fc2``/``fc3`` for ``SmallCNN``; the i-th
    aligned layer is keyed ``"L<i>"`` (1-based) and is paired with
    ``hebbian_updates[i - 1]``.  For each layer two entries are produced:

    * ``"L<i>"``    - cosine similarity between the Hebbian update and
      ``-weight.grad``;
    * ``"L<i>_wd"`` - the same, after adding the weight-decay gradient
      ``2 * weight_decay * weight`` to ``weight.grad``.

    With ``cfg["alignment_rule"] == "anti_hebb"`` the gradient sign is flipped,
    i.e. the Hebbian update is compared with ``+grad`` instead of ``-grad``.

    Args:
        hebbian_updates: one update tensor per aligned layer, same order as
            the layers above.
        model: model whose ``weight.grad`` tensors are expected to be populated.
        cfg: reads ``weight_decay`` (default ``0.0``) and ``alignment_rule``
            (default ``"hebb"``).

    Returns:
        ``(aligns, gnorms)``: both map a layer key to a one-element list.
        NOTE: despite the name, ``gnorms`` holds the norm of the layer *weight*,
        not of its gradient (kept as-is for compatibility with saved results).
        Models of neither kind yield two empty dicts.
    """
    aligns, gnorms = defaultdict(list), defaultdict(list)

    if hasattr(model, "layers"):
        # MLP/Transformer models
        layers = list(model.layers)
    elif isinstance(model, SmallCNN):
        # CNN models - only the linear layers are aligned
        layers = [model.fc1, model.fc2, model.fc3]
    else:
        return aligns, gnorms

    weight_decay = cfg.get("weight_decay", 0.0)
    anti_hebb = cfg.get("alignment_rule", "hebb") == "anti_hebb"

    # Pure measurement: nothing here should be recorded by autograd.
    with torch.no_grad():
        for idx, layer in enumerate(layers):
            name = f"L{idx + 1}"
            pred = hebbian_updates[idx]
            grad = layer.weight.grad
            weight_norm = float(layer.weight.norm())
            l2_term = weight_decay * 2 * layer.weight

            for suffix, extra in (("", 0), ("_wd", l2_term)):
                # Direction a plain gradient step would move the weight.
                descent = -(grad + extra)
                target = -descent if anti_hebb else descent
                aligns[name + suffix].append(float(cosine_similarity(pred, target)))
                gnorms[name + suffix].append(weight_norm)

    return aligns, gnorms


# ---------------------------------------------------------------------------
# Public: train_one_epoch
# ---------------------------------------------------------------------------

def train_one_epoch(
    model: nn.Module,
    train_loader: torch.utils.data.DataLoader,
    metrics_loader: torch.utils.data.DataLoader,
    optimiser: optim.Optimizer,
    criterion: nn.Module,
    cfg: Dict,
    epoch: int,
    learning_network: "DFA | RandomNN | None" = None,
) -> Tuple[Dict, List[Dict], List[Dict], List[Dict]]:
    """Train *model* for one pass over *train_loader* and collect diagnostics.

    Per batch: optional input noise, forward pass, loss, backprop (skipped for
    ``RandomNN``), a Hebbian update via ``hebbian_update``, a
    ``learning_network.step`` for ``DFA``/``RandomNN``, optional alignment and
    per-neuron tracking diagnostics, optional explicit L1/L2 weight-decay
    gradients, frozen-weight masking, and ``optimiser.step()`` (skipped when
    ``cfg["optimizer"] == "Hebb"``).

    Args:
        model: network to train. Diagnostics use ``model.layers`` when present,
            otherwise ``fc1``/``fc2``/``fc3`` of a ``SmallCNN``.
        train_loader: yields ``(data, target)`` batches.
        metrics_loader: currently unused; kept for interface compatibility.
        optimiser: optimiser whose ``step`` applies the accumulated grads.
        criterion: loss function; an ``nn.MSELoss`` marks a regression task.
        cfg: experiment configuration (``optimizer``, ``track_updates``, ...).
        epoch: epoch number stored in the diagnostic records.
        learning_network: ``DFA``/``RandomNN`` helper; required when
            ``cfg["optimizer"]`` names one of them, ignored otherwise.

    Returns:
        ``(metrics, frac_pos_align, cache_deltas, saved_updates)``; ``metrics``
        holds ``avg_loss`` (mean of per-batch losses) and ``train_accuracy``
        (percent, or ``-1`` for regression).
    """
    model.train()

    optimizer_name = cfg["optimizer"]
    uses_backprop = optimizer_name not in ("RandomNN",)
    needs_error_signal = optimizer_name in ("DFA", "RandomNN")
    is_regression = isinstance(criterion, torch.nn.MSELoss)
    mask_frozen = cfg.get("frozen_ratio", 0.0) > 0.0 or cfg.get("sparsity", 0.0) > 0.0

    # Layers used for diagnostics; index i pairs with h_updates[i], the same
    # pairing layer_alignments uses.
    if hasattr(model, "layers"):
        linear_layers = list(model.layers)
    elif isinstance(model, SmallCNN):
        linear_layers = [model.fc1, model.fc2, model.fc3]
    else:
        linear_layers = []

    track_idx = 1   # layer (and matching Hebbian update) tracked per neuron
    neuron_dim = 0  # reduction dim of the weight ("input neuron" per original note)
    k = cfg["track_updates"]
    track_layer = linear_layers[track_idx] if len(linear_layers) > track_idx else None
    tracking = k > 0 and track_layer is not None

    tot_loss, tot_correct, n_samples = 0.0, 0, 0
    frac_pos_align, cache_deltas, saved_updates = [], [], []
    # best = highest L2 alignment, worst = L2 alignment closest to zero
    best_sim, worst_sim = -float("inf"), float("inf")
    best_entry = worst_entry = None

    def _snapshot(aligns: Dict, h_updates: List[torch.Tensor]) -> Dict:
        """JSON-serialisable record of this batch's alignments, updates and grads."""
        return {
            "epoch": epoch,
            "alignments": aligns,
            "hebbian_update": [h.cpu().tolist() for h in h_updates],
            "gradient": [layer.weight.grad.cpu().tolist() for layer in linear_layers],
        }

    for data, target in train_loader:
        data, target = data.to(device), target.to(device)

        # Apply input noise if configured
        if cfg.get("input_noise_fraction", 0.0):
            data = add_input_noise(data, cfg["input_noise_fraction"])

        optimiser.zero_grad()

        if uses_backprop:
            logits = model(data)
        else:
            # No graph through the model for RandomNN, but dL/dlogits is still
            # needed: make the detached output a leaf that requires grad,
            # otherwise torch.autograd.grad below raises.
            with torch.no_grad():
                logits = model(data)
            logits = logits.detach().requires_grad_(True)
        loss = criterion(logits, target)

        error_signal = None
        if needs_error_signal:
            (error_signal,) = torch.autograd.grad(loss, logits, retain_graph=True)
        if uses_backprop:
            loss.backward(retain_graph=True)

        h_updates = hebbian_update(model, data, target, cfg)
        if needs_error_signal:
            learning_network.step(data, error_signal)

        aligns: Dict = {}  # stays empty unless phase alignment is enabled
        if cfg.get("phase_alignment", False):
            aligns, gnorms = layer_alignments(h_updates, model, cfg)
            frac_pos_align.append({"alignments": aligns, "grad_norms": gnorms})
            sim = float(np.mean(aligns.get("L2", [0])))
            if sim > best_sim:
                best_sim, best_entry = sim, _snapshot(aligns, h_updates)
            if abs(sim) < abs(worst_sim):
                worst_sim, worst_entry = sim, _snapshot(aligns, h_updates)

        if cfg.get("gradient_noise_fraction", 0.0):
            add_param_noise(model, cfg["gradient_noise_fraction"])

        if tracking:
            # Snapshot before the weight-decay backward accumulates into .grad.
            raw_grad = track_layer.weight.grad.detach().clone()
            per_neuron_grad_mag = raw_grad.abs().sum(dim=neuron_dim)

        weight_decay = cfg.get("weight_decay", 0.0)
        if weight_decay and cfg["regularization_mode"] in ["L2_weight_decay", "L2_weight_decay_bn"]:
            (weight_decay * sum(p.pow(2).sum() for p in model.parameters())).backward(retain_graph=True)
        if weight_decay and cfg["regularization_mode"] == "L1_weight_decay":
            (weight_decay * sum(p.abs().sum() for p in model.parameters())).backward(retain_graph=True)

        if tracking:
            update = track_layer.weight.grad.detach()
            hebb = h_updates[track_idx]
            per_neuron_update_mag = update.abs().sum(dim=neuron_dim)
            # cosine_similarity is invariant to per-column scaling, so no
            # explicit normalisation is needed (it produced NaN for all-zero
            # columns, e.g. dead units).
            neuron_update_cos_sim = F.cosine_similarity(hebb, -update, dim=neuron_dim)
            neuron_grad_cos_sim = F.cosine_similarity(hebb, -raw_grad, dim=neuron_dim)
            # NOTE: these are weight norms; the "grad_norm_total" key is historical.
            weight_norms = [layer.weight.data.norm() for layer in linear_layers]

            saved_updates.append(
                {
                    "epoch": epoch,
                    "alignments": aligns,
                    "grad_norm_total": torch.prod(torch.stack(weight_norms)).detach().cpu().tolist(),
                    "neuron_update_cos_sim": neuron_update_cos_sim[:k].detach().cpu().tolist(),
                    "neuron_grad_cos_sim": neuron_grad_cos_sim[:k].detach().cpu().tolist(),
                    "neuron_update_mag": per_neuron_update_mag[:k].detach().cpu().tolist(),
                    "per_neuron_grad_mag": per_neuron_grad_mag[:k].detach().cpu().tolist(),
                }
            )

        if optimizer_name not in ("Hebb",):
            # Zero out gradients for frozen weights before optimizer step
            if mask_frozen:
                for param in model.parameters():
                    frozen_mask = get_frozen_mask(param, model)
                    if frozen_mask is not None and param.grad is not None:
                        param.grad[frozen_mask] = 0.0
            optimiser.step()

        tot_loss += loss.item()
        if not is_regression:
            tot_correct += logits.argmax(1).eq(target).sum().item()
        n_samples += data.size(0)

    if epoch in cfg.get("cache_delta", []) and best_entry and worst_entry:
        cache_deltas.extend([best_entry, worst_entry])

    # Guard against an empty loader (would otherwise divide by zero).
    metrics = {"avg_loss": tot_loss / max(len(train_loader), 1)}
    if not is_regression:
        metrics["train_accuracy"] = 100.0 * tot_correct / max(n_samples, 1)
    else:
        metrics["train_accuracy"] = -1

    return metrics, frac_pos_align, cache_deltas, saved_updates

# ---------------------------------------------------------------------------
# Utility: quick evaluation
# ---------------------------------------------------------------------------

def _evaluate(model: nn.Module, loader, criterion, device) -> Tuple[float, float]:
    """Score *model* on *loader* without recording gradients.

    Puts the model in eval mode. The reported loss is the mean of the
    per-batch criterion values. When *criterion* is an ``nn.MSELoss`` the
    task is treated as regression and ``-1`` is returned in place of the
    accuracy; otherwise accuracy is the percentage of argmax hits.

    Returns:
        ``(mean_batch_loss, accuracy)``.
    """
    model.eval()
    regression = isinstance(criterion, torch.nn.MSELoss)
    loss_sum = 0.0
    hits = 0
    seen = 0

    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)

            loss_sum += criterion(outputs, labels).item()
            if not regression:
                hits += outputs.argmax(1).eq(labels).sum().item()
            seen += inputs.size(0)

    accuracy = -1 if regression else 100.0 * hits / seen
    return loss_sum / len(loader), accuracy

# ---------------------------------------------------------------------------
# Public: full experiment loop
# ---------------------------------------------------------------------------
 


def evaluate_expirement(settings: Dict) -> List[Dict]:  # original typo kept
    """Run one full experiment described by *settings* and save its artefacts.

    *settings* is merged over ``DEFAULT_CFG``. Results go to
    ``results/<experiment_name>/<pretty_cfg_diff(cfg, DEFAULT_CFG)>``. If that
    directory's ``metrics`` folder already holds files, nothing is run.

    The model is evaluated once before training (epoch 0). It is then trained
    for ``cfg["epochs"]`` epochs with ``train_one_epoch``, with a validation
    pass after each epoch. Autograd anomaly detection is enabled during
    training unless ``cfg["detect_anomaly"]`` is set to a falsy value; it
    slows every backward pass considerably.

    Returns:
        The per-epoch metrics history, or ``[]`` when the run was skipped.

    Raises:
        ValueError: if the merged config has no ``experiment_name``.
    """
    cfg: Dict = {**DEFAULT_CFG, **settings}
    if "experiment_name" not in cfg:
        raise ValueError("settings must include 'experiment_name'.")

    root = Path("results") / cfg["experiment_name"] / pretty_cfg_diff(cfg, DEFAULT_CFG)
    dirs = ensure_dirs(root)
    metrics_dir = root / "metrics"
    if metrics_dir.exists() and any(metrics_dir.iterdir()):
        print("already run--aborted")
        return []

    model, (tl, vl, ml), criterion = get_dataloaders_model_crit_model(cfg)
    model = model.to(device)
    optimiser = choose_optimizer(model.parameters(), cfg)

    history: List[Dict] = []
    frac_align_all, cache_deltas_all, saved_updates_all = [], [], []

    start = time.time()

    learning_network = None
    if cfg["optimizer"] == "DFA":
        learning_network = DFA(model)
    elif cfg["optimizer"] == "RandomNN":
        learning_network = RandomNN(model)

    # Evaluate initial conditions (epoch 0); _evaluate switches to eval mode.
    train_loss, train_acc = _evaluate(model, tl, criterion, device)
    val_loss, val_acc = _evaluate(model, vl, criterion, device)
    history.append(
        {
            "epoch": 0,
            "avg_loss": train_loss,
            "train_accuracy": train_acc,
            "val_loss": val_loss,
            "val_acc": val_acc,
        }
    )
    print(f"[Epoch 0] loss={train_loss:.3f} acc={train_acc:.2f}% | val_loss={val_loss:.3f} val_acc={val_acc:.2f}%")

    # Anomaly detection stays on by default (previous behaviour) but can be
    # switched off via cfg["detect_anomaly"] for much faster training.
    with torch.autograd.set_detect_anomaly(bool(cfg.get("detect_anomaly", True))):
        for epoch in range(1, cfg["epochs"] + 1):
            metrics, frac_align, cache_deltas, saved_updates = train_one_epoch(
                model, tl, ml, optimiser, criterion, cfg, epoch, learning_network
            )
            val_loss, val_acc = _evaluate(model, vl, criterion, device)
            metrics.update({"epoch": epoch, "val_loss": val_loss, "val_acc": val_acc})
            history.append(metrics)

            frac_align_all.extend(frac_align)
            cache_deltas_all.extend(cache_deltas)
            saved_updates_all.append(saved_updates)  # one list per epoch

            print(f"[Epoch {epoch}/{cfg['epochs']}] loss={metrics['avg_loss']:.3f} acc={metrics['train_accuracy']:.2f}% | "
                f"val_loss={val_loss:.3f} val_acc={val_acc:.2f}%")

    elapsed = (time.time() - start) / 60
    cfg["elapsed_time"] = elapsed

    # ---- save artefacts ----
    artefacts = {
        "metrics.json": history,
        "config.json": cfg,
        "cached_deltas.json": cache_deltas_all,
        "saved_updates.json": saved_updates_all,
        "frac_pos_alignment.json": frac_align_all,
    }
    for filename, payload in artefacts.items():
        (dirs["metrics"] / filename).write_text(json.dumps(payload, indent=2))
    print(f"✔ Experiment finished in {elapsed:.1f} min – artefacts in {root}")

    return history

# ---------------------------------------------------------------------------
# CLI entry point
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import argparse
    import sys

    # Single positional argument: the JSON file holding the experiment settings.
    parser = argparse.ArgumentParser(description="Run an experiment from JSON config")
    parser.add_argument("config", help="Path to JSON config file")
    cfg_path = Path(parser.parse_args().config)
    if not cfg_path.exists():
        sys.exit(f"Config file '{cfg_path}' not found.")
    settings = json.loads(cfg_path.read_text())
    evaluate_expirement(settings)
