import torch
import torch.nn.functional as F
import torchcde
import time
import copy
import os
from datetime import datetime
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from main.analysis.diagnostics import ExperimentDiagnostics

def get_tensorboard_path(cfg, experiment_name, interpolation_type, kernel_params, model):
    """Build the TensorBoard log directory for one experiment run.

    The resulting layout is::

        <cfg['runs_dir']>/<cfg['dataset_name']>/<category>/tol-<tol>/<experiment_name>_<YYYYmmdd_HHMMSS>

    The tolerance is resolved from, in order: ``kernel_params['tol']``,
    ``model.tol``, and a ``tol-<value>`` token in the underscore-separated
    ``experiment_name``. If none of these yields a value the directory is
    ``tol-unknown``.

    Args:
        cfg: Mapping providing ``'runs_dir'`` and ``'dataset_name'``.
        experiment_name: Run name; also used as the run-directory prefix.
        interpolation_type: Selects the category sub-directory. Unrecognised
            values fall into ``other``.
        kernel_params: Mapping of extra parameters (``'tol'``, ``'step_size'``,
            ``'depth'``, ``'kernel'`` are read here). ``None`` is treated as
            an empty mapping.
        model: Object whose optional ``tol`` attribute is used as a fallback.

    Returns:
        The run directory path as a string. It embeds the current local time,
        so two calls generally return different paths.
    """
    kernel_params = kernel_params or {}
    base_dir = os.path.join(cfg['runs_dir'], cfg['dataset_name'])

    tol = kernel_params.get('tol')
    if tol is None:
        tol = getattr(model, 'tol', None)
    if tol is None:
        # Take everything after the prefix so that tolerances written with a
        # negative exponent (e.g. "tol-1e-05") are not cut at the second '-'.
        prefix = 'tol-'
        for token in str(experiment_name).split('_'):
            if token.startswith(prefix) and len(token) > len(prefix):
                tol = token[len(prefix):]
                break
    if tol is None:
        tol = 'unknown'
    tol_dir = f"tol-{tol}"

    # Interpolation types whose category directory needs no extra parameters.
    fixed_categories = {
        'linear': "baseline",
        'cubic': "baseline",
        'gp': "GP",
        'qformer': "qformer",
        'conv': "conv",
        'odernn': "ODE-RNN",
        'grud': "GRU-D",
    }

    if interpolation_type == "log_ncde":
        step = kernel_params.get('step_size', 'N/A')
        depth = kernel_params.get('depth', 'N/A')
        category_path = os.path.join("Log-NCDE", f"depth{depth}_step{step}", tol_dir)
    elif interpolation_type == "kernel":
        k_name = kernel_params.get('kernel', 'unknown_kernel')
        category_path = os.path.join("kernel", k_name, tol_dir)
    else:
        category_path = os.path.join(fixed_categories.get(interpolation_type, "other"), tol_dir)

    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    run_name = f"{experiment_name}_{timestamp}"
    return os.path.join(base_dir, category_path, run_name)


def _find_vector_field(model):
    """Return ``model.func``, else ``model.cde_func``, else ``None``."""
    if hasattr(model, 'func'):
        return model.func
    if hasattr(model, 'cde_func'):
        return model.cde_func
    return None


def _count_correct(logits, labels):
    """Return how many row-wise argmax predictions in ``logits`` equal ``labels``."""
    return (logits.detach().argmax(dim=1) == labels).sum().item()


def _evaluate_noise_robustness(model, test_X, test_y_dev, noise_levels,
                               interpolation_type, device, logger):
    """Evaluate ``model`` on Gaussian-perturbed copies of ``test_X``.

    For every entry of ``noise_levels`` a fresh copy of ``test_X`` is perturbed
    with ``randn * noise_std`` (left untouched when ``noise_std <= 0``), cubic
    coefficients are recomputed when ``interpolation_type == "cubic"``, and a
    single full-batch forward pass is timed.

    Returns:
        Dict keyed by ``str(noise_std)`` with ``noise_std``, ``accuracy``
        (percent), ``nfe`` and ``inference_time`` (seconds) per level.
    """
    results = {}
    vec_field = _find_vector_field(model)

    for noise_std in noise_levels:
        noisy_test_X = test_X.clone()
        if noise_std > 0.0:
            noisy_test_X = noisy_test_X + torch.randn_like(noisy_test_X) * noise_std

        # The timed span deliberately includes coefficient fitting and the
        # host-to-device copy, not just the forward pass.
        inference_start = time.perf_counter()

        if interpolation_type == "cubic":
            noisy_coeffs = torchcde.hermite_cubic_coefficients_with_backward_differences(noisy_test_X)
        else:
            noisy_coeffs = noisy_test_X
        noisy_coeffs_dev = noisy_coeffs.to(device)

        # Zero the evaluation counters so the value read below covers only
        # this forward pass.
        for attr_name in ('func', 'cde_func'):
            field = getattr(model, attr_name, None)
            if field is not None:
                field.nfe = 0

        with torch.no_grad():
            pred_y_noise = model(noisy_coeffs_dev)
            acc_noise = _count_correct(pred_y_noise, test_y_dev) / test_y_dev.size(0)

        inference_end = time.perf_counter()

        nfe_noise = vec_field.nfe if vec_field is not None else 0

        logger.info(f"[Noise Std: {noise_std}] Acc: {acc_noise*100:.2f}% | NFE: {nfe_noise}")

        results[str(noise_std)] = {
            "noise_std": noise_std,
            "accuracy": acc_noise * 100,
            "nfe": nfe_noise,
            "inference_time": inference_end - inference_start
        }

    return results


def _run_experiment_with_writer(model, data, cfg, logger, experiment_name,
                                interpolation_type, kernel_params, writer, writer_path):
    """Train, validate and test ``model``, logging to an already-open ``writer``.

    The caller owns ``writer`` and is responsible for closing it.
    """
    logger.info(f"TensorBoard logs will be saved to: {writer_path}")

    # The trailing time grid is part of the tuple but is not used here.
    train_X, val_X, test_X, train_y, val_y, test_y, _t_grid = data

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    logger.info(f"--- Starting Experiment: {experiment_name} ---")
    logger.info(f"Device: {device}, Epochs: {cfg['num_epochs']}, Batch Size: {cfg['batch_size']}")
    logger.info(f"Interpolation: {interpolation_type}, Params: {kernel_params if kernel_params else 'N/A'}")

    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.get("lr", 1e-3), weight_decay=cfg["weight_decay"])

    # "Static" fit time: one-off coefficient computation done before training.
    static_fit_time = 0.0
    if interpolation_type == "cubic":
        start_time = time.perf_counter()
        train_coeffs = torchcde.hermite_cubic_coefficients_with_backward_differences(train_X)
        val_coeffs = torchcde.hermite_cubic_coefficients_with_backward_differences(val_X)
        test_coeffs_clean = torchcde.hermite_cubic_coefficients_with_backward_differences(test_X)
        static_fit_time = time.perf_counter() - start_time
        logger.info(f"Static Cubic Spline Fit Time: {static_fit_time:.4f}s")
    else:
        # Every other interpolation type receives the raw tensors.
        train_coeffs, val_coeffs, test_coeffs_clean = train_X, val_X, test_X

    train_dataloader = DataLoader(TensorDataset(train_coeffs, train_y), batch_size=cfg['batch_size'], shuffle=True)
    # Validation and test sets are moved to the device once and evaluated in
    # a single forward pass each.
    val_coeffs_dev, val_y_dev = val_coeffs.to(device), val_y.to(device)
    test_coeffs_dev_clean, test_y_dev = test_coeffs_clean.to(device), test_y.to(device)

    if cfg.get('plot_data', False):
        plot_save_dir = os.path.join(cfg['plots_dir'], cfg['dataset_name'], experiment_name)
        add_time_flag = (str(cfg.get('add_time', 'yes')).lower() == 'yes')
        diagnostics = ExperimentDiagnostics(cfg, plot_save_dir, add_time_flag)
        diagnostics.run_suite(model, train_dataloader, device)

    best_val_accuracy = 0.0
    train_acc_at_best_val = 0.0
    best_model_state = None

    # Pre-initialised so the summary below is well defined with zero epochs.
    avg_train_acc = 0.0
    avg_nfe_final = 0.0
    vec_field = None

    if hasattr(model, 'reset_fit_timer'):
        model.reset_fit_timer()

    training_loop_start = time.perf_counter()

    logger.info("Starting training and validation loop...")
    for epoch in range(1, cfg['num_epochs'] + 1):
        model.train()

        epoch_train_loss = 0.0
        epoch_train_correct = 0
        epoch_train_total = 0

        total_nfe = 0
        total_batches = 0

        for batch_idx, (batch_coeffs, batch_y) in enumerate(train_dataloader):
            batch_coeffs, batch_y = batch_coeffs.to(device), batch_y.to(device)
            optimizer.zero_grad()
            pred_y = model(batch_coeffs)
            loss = F.cross_entropy(pred_y, batch_y)
            loss.backward()
            optimizer.step()

            # Read the loss once: each .item() forces a device sync.
            batch_loss = loss.item()

            vec_field = _find_vector_field(model)
            global_step = (epoch - 1) * len(train_dataloader) + batch_idx

            current_nfe = 0
            if vec_field is not None:
                # NOTE(review): the noise-robustness pass zeroes `nfe` before
                # each forward but this loop does not. If the vector field
                # never resets its own counter, the values summed here are
                # cumulative rather than per-batch — confirm against the model.
                current_nfe = getattr(vec_field, 'nfe', 0)
                tanh_sat = getattr(vec_field, 'last_tanh_saturation', 0.0)
                tanh_mean = getattr(vec_field, 'last_tanh_mean', 0.0)

                if global_step % 10 == 0:
                    writer.add_scalar('Debug/tanh_saturation_rate', tanh_sat, global_step)
                    writer.add_scalar('Debug/tanh_mean_abs', tanh_mean, global_step)
                    writer.add_scalar('Debug/NFE_batch', current_nfe, global_step)

            total_nfe += current_nfe
            total_batches += 1

            writer.add_scalar('Loss/train_batch', batch_loss, global_step)

            epoch_train_loss += batch_loss
            epoch_train_total += batch_y.size(0)
            epoch_train_correct += _count_correct(pred_y, batch_y)

        # Guard every average the same way so an empty loader yields zeros
        # instead of ZeroDivisionError.
        avg_train_loss = epoch_train_loss / total_batches if total_batches > 0 else 0.0
        avg_train_acc = epoch_train_correct / epoch_train_total if epoch_train_total > 0 else 0.0
        avg_nfe_epoch = total_nfe / total_batches if total_batches > 0 else 0
        avg_nfe_final = avg_nfe_epoch

        writer.add_scalar('Loss/train_epoch', avg_train_loss, epoch)
        writer.add_scalar('Accuracy/train_epoch_percent', avg_train_acc * 100, epoch)
        writer.add_scalar('Debug/NFE_epoch_avg', avg_nfe_epoch, epoch)

        model.eval()
        with torch.no_grad():
            val_pred_y = model(val_coeffs_dev)
            val_loss = F.cross_entropy(val_pred_y, val_y_dev).item()
            val_accuracy = _count_correct(val_pred_y, val_y_dev) / val_y_dev.size(0)

            writer.add_scalar('Loss/val_epoch', val_loss, epoch)
            writer.add_scalar('Accuracy/val_epoch_percent', val_accuracy * 100, epoch)

        # Strict '>' keeps the earliest epoch among ties; a run whose
        # validation accuracy never exceeds 0 saves no snapshot at all.
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            train_acc_at_best_val = avg_train_acc
            best_model_state = copy.deepcopy(model.state_dict())

        if epoch % 5 == 0 or epoch == cfg['num_epochs']:
            log_msg = (f"Epoch: {epoch}/{cfg['num_epochs']} | "
                       f"Train Acc: {avg_train_acc*100:.2f}% | Val Acc: {val_accuracy*100:.2f}% | "
                       f"Avg NFE: {avg_nfe_epoch:.1f}")
            if vec_field is not None and hasattr(vec_field, 'last_tanh_saturation'):
                log_msg += f" | Saturation: {vec_field.last_tanh_saturation:.2f}"
            logger.info(log_msg)

    total_loop_time = time.perf_counter() - training_loop_start

    # "Dynamic" fit time is whatever the model accumulated in `fit_time_accum`
    # during the loop; it is subtracted to obtain the training-only time
    # (which still includes the per-epoch validation pass).
    dynamic_fit_time = getattr(model, 'fit_time_accum', 0.0)
    pure_training_time = total_loop_time - dynamic_fit_time
    total_fit_time = static_fit_time + dynamic_fit_time

    logger.info("Finished training.")
    logger.info(f"Total Loop Time: {total_loop_time:.2f}s")
    logger.info(f"  -> Dynamic Fit Overhead: {dynamic_fit_time:.2f}s")
    logger.info(f"  -> Pure Training Time: {pure_training_time:.2f}s")
    logger.info(f"Total Fit Time (Static + Dynamic): {total_fit_time:.2f}s")

    if best_model_state:
        model.load_state_dict(best_model_state)
    else:
        logger.warning("No best model was saved; evaluating the model from the final epoch.")

    model.eval()
    eval_start = time.perf_counter()
    with torch.no_grad():
        test_pred_y = model(test_coeffs_dev_clean)
        test_loss = F.cross_entropy(test_pred_y, test_y_dev).item()
        test_accuracy = (_count_correct(test_pred_y, test_y_dev) / test_y_dev.size(0)) * 100

    evaluation_time = time.perf_counter() - eval_start
    logger.info(f"Evaluation Time: {evaluation_time:.4f}s")

    logger.info(f"--- Finished Experiment: {experiment_name} ---")
    logger.info(f"Final Test Accuracy: {test_accuracy:.2f}% | Avg Solver Steps (NFE): {avg_nfe_final:.1f}")

    writer.add_scalar('Accuracy/test_final', test_accuracy, cfg['num_epochs'])

    noise_results = {}
    if cfg.get('test_noise_robustness', False):
        logger.info("\n" + "="*50)
        logger.info(">>> STARTING NOISE ROBUSTNESS EVALUATION <<<")
        logger.info("="*50)

        noise_results = _evaluate_noise_robustness(
            model, test_X, test_y_dev, cfg.get('test_noise_levels', [0.0]),
            interpolation_type, device, logger,
        )

    summary_text = (
        f"**Experiment:** {experiment_name}  \n"
        f"**Test Accuracy:** {test_accuracy:.2f}%  \n"
        f"**Avg NFE:** {avg_nfe_final:.1f}  \n"
        f"**Training Time:** {pure_training_time:.2f}s"
    )
    writer.add_text('Experiment/Summary', summary_text, 0)

    hparams = {
        "interpolation": interpolation_type,
        "batch_size": cfg['batch_size'],
        "epochs": cfg['num_epochs'],
        "weight_decay": cfg['weight_decay'],
        "time_scaling": cfg['time_scaling_factor'],
        "tolerance": str(kernel_params.get('tol', getattr(model, 'tol', 'N/A')))
    }
    # Values that are not plain scalars/strings are stringified before logging.
    for key, value in kernel_params.items():
        hparams[f"param_{key}"] = value if isinstance(value, (int, float, str, bool)) else str(value)

    final_metrics = {
        "hparam/final_test_accuracy": test_accuracy,
        "hparam/avg_nfe": avg_nfe_final,
        "hparam/training_time_seconds": pure_training_time,
    }
    writer.add_hparams(hparams, final_metrics)

    return {
        "test_accuracy": test_accuracy,
        "test_loss": test_loss,
        "final_train_accuracy": avg_train_acc * 100,
        "best_val_accuracy": best_val_accuracy * 100,
        "train_accuracy_at_best_val": train_acc_at_best_val * 100,

        "trajectories_fit_time": total_fit_time,
        "training_time": pure_training_time,
        "evaluation_time": evaluation_time,
        "total_time": total_loop_time + static_fit_time,

        "avg_nfe": avg_nfe_final,
        "noise_robustness_results": noise_results if noise_results else None
    }


def run_experiment(model, data, cfg, logger, experiment_name, interpolation_type, kernel_params=None):
    """Train ``model``, keep the best-validation weights, and evaluate on the test set.

    Scalars, a text summary and hyper-parameters are written to a TensorBoard
    run directory obtained from ``get_tensorboard_path``. The writer is closed
    even when training or evaluation raises.

    Args:
        model: Module called as ``model(batch)`` and trained with cross-entropy.
            Optional attributes read here: ``func`` / ``cde_func`` (and their
            ``nfe``, ``last_tanh_saturation``, ``last_tanh_mean``),
            ``reset_fit_timer()``, ``fit_time_accum`` and ``tol``.
        data: 7-tuple ``(train_X, val_X, test_X, train_y, val_y, test_y, t_grid)``;
            the last element is not used by this function.
        cfg: Mapping. Required keys: ``runs_dir``, ``dataset_name``,
            ``num_epochs``, ``batch_size``, ``weight_decay``,
            ``time_scaling_factor``. Optional: ``lr`` (default ``1e-3``),
            ``plot_data`` (needs ``plots_dir``), ``add_time``,
            ``test_noise_robustness``, ``test_noise_levels``.
        logger: Logger providing ``info`` and ``warning``.
        experiment_name: Name used in log messages and in the run directory.
        interpolation_type: ``"cubic"`` triggers Hermite cubic coefficient
            computation via torchcde; any other value feeds the raw tensors.
        kernel_params: Optional mapping of extra parameters, logged as
            hyper-parameters. ``None`` (the default) means no parameters.

    Returns:
        Dict with ``test_accuracy``, ``final_train_accuracy``,
        ``best_val_accuracy`` and ``train_accuracy_at_best_val`` (percent),
        ``test_loss``, ``trajectories_fit_time``, ``training_time``,
        ``evaluation_time`` and ``total_time`` (seconds), ``avg_nfe`` (last
        epoch's average), and ``noise_robustness_results`` (dict or ``None``).
    """
    # Use None as the default to avoid sharing one dict object across calls.
    kernel_params = {} if kernel_params is None else kernel_params

    writer_path = get_tensorboard_path(cfg, experiment_name, interpolation_type, kernel_params, model)
    writer = SummaryWriter(log_dir=writer_path)
    try:
        return _run_experiment_with_writer(
            model, data, cfg, logger, experiment_name,
            interpolation_type, kernel_params, writer, writer_path,
        )
    finally:
        # Flush and release the event file on both success and failure.
        writer.close()