# forward_forward/utils/training_utils.py

import torch
import random
import numpy as np
import os
import shutil
from datetime import datetime
from omegaconf import OmegaConf
import torchvision.transforms as T
from torch.utils.data import random_split

from forward_forward.models.model_factory import build_model_from_config
from forward_forward.models.model_factory import build_bp_model_from_config
from forward_forward.data.dataloader import get_base_dataset
from forward_forward.training.trainer import Trainer
from forward_forward.data.transforms import ApplyTransform
from thop import profile, clever_format
from torchsummary import summary



def set_seed(seed: int):
    """Seed every random number generator used by this module.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices)
    with the same value, then sets ``cudnn.deterministic = True`` and
    ``cudnn.benchmark = False``.

    Args:
        seed: Value passed to each seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def load_dataloaders(cfg: OmegaConf, seed: int):
    """Build the training split and the validation/test loaders for one seed.

    The base training set is split into train/validation subsets using a
    generator seeded with ``seed``. Only the validation and test data are
    wrapped with a ``ToTensor`` transform and batched here; the training
    subset is returned as a plain ``random_split`` subset of the base dataset,
    which is loaded with ``transform=None``.

    Args:
        cfg: Experiment config. Reads ``cfg.data.dataset`` and the optional
            top-level keys ``val_fraction`` (default ``0.1``) and
            ``num_workers`` (default ``2``).
        seed: Seed for the generator that drives the train/validation split.

    Returns:
        A tuple ``(train_subset, val_loader, test_loader, base_train)``.
    """
    # Load base datasets without any transforms
    base_train = get_base_dataset(cfg.data.dataset, train=True, transform=None)
    base_test = get_base_dataset(cfg.data.dataset, train=False, transform=None)

    # Split training data into train/validation
    val_fraction = cfg.get("val_fraction", 0.1)
    n_val = int(len(base_train) * val_fraction)
    n_train = len(base_train) - n_val
    # TODO: check whether to train on all data (n_train = len(base_train),
    # n_val = 0) after finding the best model on the validation set.

    train_subset, val_subset = random_split(
        base_train,
        [n_train, n_val],
        generator=torch.Generator().manual_seed(seed),
    )

    # Common transform for validation and test.
    # Per-channel statistics for an optional T.Normalize step (currently not
    # applied): mean = [0.4914, 0.4822, 0.4465], std = [0.2470, 0.2435, 0.2616]
    val_transform = T.ToTensor()
    val_dataset = ApplyTransform(val_subset, transform=val_transform)
    test_dataset = ApplyTransform(base_test, transform=val_transform)

    # Validation/test loaders share the same fixed batch size and are never
    # shuffled.
    eval_batch_size = 512
    num_workers = cfg.get("num_workers", 2)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=eval_batch_size,
        shuffle=False,
        num_workers=num_workers,
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=eval_batch_size,
        shuffle=False,
        num_workers=num_workers,
    )

    print(f"Train dataset size: {n_train}")
    print(f"Validation dataset size: {n_val}")
    print(f"Test dataset size: {len(base_test)}")

    return train_subset, val_loader, test_loader, base_train

def setup_experiment_group(cfg: OmegaConf) -> str:
    """Create an experiment group directory and save the config and code into it.

    The directory is ``<output_dir>/<run_name>_<YYYYmmdd_HHMMSS>``. It receives
    a ``config.yaml`` dump of ``cfg`` and a ``code/`` folder holding copies of
    the ``forward_forward``, ``scripts`` and ``configs`` directories found
    under the project root (``__pycache__`` and ``*.pyc`` entries excluded).

    Args:
        cfg: Experiment config. Reads ``cfg.run_name`` and the optional
            ``output_dir`` key (default ``"experiments"``).

    Returns:
        Path of the created experiment group directory.
    """
    # Create base output directory if it doesn't exist
    output_root = cfg.get("output_dir", "experiments")
    os.makedirs(output_root, exist_ok=True)

    # Create unique group name with timestamp
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    group_name = f"{cfg.run_name}_{timestamp}"
    group_dir = os.path.join(output_root, group_name)
    os.makedirs(group_dir, exist_ok=True)

    # Save config file. The encoding is pinned so that non-ASCII values in the
    # config do not depend on the platform's default locale encoding.
    config_path = os.path.join(group_dir, "config.yaml")
    with open(config_path, "w", encoding="utf-8") as f:
        OmegaConf.save(cfg, f)

    # Save code snapshot
    code_dir = os.path.join(group_dir, "code")
    os.makedirs(code_dir, exist_ok=True)

    expected_dirs = ["forward_forward", "scripts", "configs"]

    def _contains_expected_dirs(candidate: str) -> bool:
        """Return True if every expected folder exists under ``candidate``."""
        return all(
            os.path.exists(os.path.join(candidate, d)) for d in expected_dirs
        )

    # Walk up from this file's directory until a directory containing all the
    # expected folders is found.
    project_root = os.path.dirname(os.path.abspath(__file__))
    max_levels_up = 3  # Don't go too far up
    for _ in range(max_levels_up):
        if _contains_expected_dirs(project_root):
            break
        project_root = os.path.dirname(project_root)  # Go up one level
    else:
        # The loop never tested the last ancestor it stepped to; test it here
        # so a failed search is reported instead of silently accepted.
        if not _contains_expected_dirs(project_root):
            print(
                "⚠️  Warning: no directory containing all of "
                f"{expected_dirs} found within {max_levels_up} levels; "
                "code snapshot may be incomplete"
            )

    print(f"📁 Project root identified as: {project_root}")

    # Copy critical directories (best effort: missing ones are skipped)
    for folder in expected_dirs:
        src = os.path.join(project_root, folder)
        if os.path.exists(src):
            dst = os.path.join(code_dir, folder)
            shutil.copytree(
                src, dst, ignore=shutil.ignore_patterns("__pycache__", "*.pyc")
            )
            print(f"📋 Copied {folder} to code snapshot")
        else:
            print(f"⚠️  Warning: Directory {src} not found, skipping copy")

    print(f"✅ Experiment group directory created at: {group_dir}")
    return group_dir

def create_seed_directory(group_dir: str, seed: int) -> str:
    """Ensure ``<group_dir>/seed_<seed>`` exists and return its path.

    Args:
        group_dir: Experiment group directory that holds the per-seed folders.
        seed: Seed whose value is embedded in the folder name.

    Returns:
        Path of the seed directory; an already existing one is reused.
    """
    target = os.path.join(group_dir, "seed_{}".format(seed))
    os.makedirs(target, exist_ok=True)
    return target

def run_experiment(cfg: OmegaConf, seed: int, group_dir: str):
    """Run a single experiment with a given seed.

    Seeds the RNGs, creates the seed directory inside ``group_dir``, loads the
    data, builds the model selected by ``cfg.training.regime``, prints
    FLOP/parameter counts and a layer summary, then trains via ``Trainer``.

    Args:
        cfg: Experiment config. Reads ``cfg.training.regime`` (default
            ``"consecutive"``), ``cfg.training.wandb``,
            ``cfg.model.architecture``, ``cfg.data.dataset``, ``cfg.run_name``
            and ``cfg.consecutive_training.epochs`` (default ``50``).
        seed: Seed used for RNG seeding, the data split and the run name.
        group_dir: Experiment group directory that will contain the
            ``seed_<seed>`` output folder.

    Raises:
        ValueError: If ``cfg.training.regime`` is not one of
            ``"consecutive"``, ``"densenet"`` or ``"bp"``.
    """
    set_seed(seed)

    # Validate the regime before doing any work, so a typo in the config does
    # not surface later as an unbound ``model`` variable.
    regime = cfg.training.get("regime", "consecutive")
    supported_regimes = ("consecutive", "densenet", "bp")
    if regime not in supported_regimes:
        raise ValueError(
            f"Unsupported training regime {regime!r}; "
            f"expected one of {supported_regimes}"
        )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if device.type == "cuda":
        print(f"✅ Using GPU: {torch.cuda.get_device_name(device)}")
    else:
        print("❌ Using CPU")

    # Create seed-specific directory
    seed_dir = create_seed_directory(group_dir, seed)

    # Load data and model
    train_dataset, val_loader, test_loader, base_train = load_dataloaders(cfg, seed)
    input_shape = base_train.input_shape
    num_classes = base_train.num_classes

    if regime == "consecutive":
        model = build_model_from_config(cfg.model.architecture, input_shape, num_classes)
    elif regime == "densenet":
        # NOTE(review): `DenseNet` is not imported in this module, so this
        # branch raises NameError unless the name is provided elsewhere.
        # Confirm the intended import.
        model = DenseNet.densenet_cifar_ff()
    else:  # regime == "bp" (validated above)
        model = build_bp_model_from_config(cfg.model.architecture, input_shape, num_classes)

    print(f"Model architecture:\n{model}")

    # Use the input shape to count FLOPs and params with a single random
    # sample; this runs before the model is moved to `device`.
    input_data = torch.randn(1, *input_shape)
    flops, params = profile(model, inputs=(input_data,))
    flops, params = clever_format([flops, params], "%.3f")
    print(f'FLOPs: {flops}, Params: {params}')

    # Count also with torchsummary. The model is moved to `device` first;
    # an earlier note here recorded the failure otherwise: "RuntimeError:
    # Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor)
    # should be the same".
    summary(model.to(device), input_size=input_shape)

    # Create unique run name with seed
    run_name = f"{cfg.run_name}_seed{seed}"

    trainer = Trainer(
        model=model,
        train_dataset=train_dataset,
        num_classes=num_classes,
        input_shape=input_shape,
        wandb_enabled=cfg.training.wandb,
        device=device,
        cfg=cfg,
        run_name=run_name,
        dataset=cfg.data.dataset,
        output_dir=seed_dir,  # Pass seed directory to trainer
    )

    # The epoch count is read from `consecutive_training` for every regime.
    total_epochs = cfg.consecutive_training.get("epochs", 50)

    if regime == "densenet":
        # No training entry point exists for this regime in this function;
        # report it rather than finishing silently.
        print(
            "⚠️  Warning: no training loop is wired up for regime "
            "'densenet'; skipping training"
        )
        return

    print(f"🔄 Starting {regime} training for {total_epochs} epochs")
    if regime == "consecutive":
        trainer.train(
            val_dataloader=val_loader,
            test_dataloader=test_loader,
            total_epochs=total_epochs,
        )
    else:  # regime == "bp"
        trainer.bp_train(
            val_dataloader=val_loader,
            test_dataloader=test_loader,
            total_epochs=total_epochs,
        )