import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from typing import Optional, Tuple
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
import torch.cuda.amp as amp
import gc
from Renderer import NeuralRFOptimizer
import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'codes'))
from dataloader import BLE_dataset, split_dataset


def rssi2amplitude(rssi, floor_dbm=-100):
    """Map RSSI reading(s) to a linear amplitude score.

    Computes ``1 - rssi / floor_dbm``: an RSSI equal to ``floor_dbm`` maps to 0
    and an RSSI of 0 maps to 1. Works element-wise on Python numbers, NumPy
    arrays and torch tensors. With the default floor this is the inverse of
    ``amplitude2rssi``.

    Args:
        rssi: RSSI value(s), presumably in dBm.
        floor_dbm: RSSI value mapped to amplitude 0 (default -100, the value
            that used to be hard-coded).

    Raises:
        ValueError: if ``floor_dbm`` is zero (the mapping would divide by zero).
    """
    if floor_dbm == 0:
        raise ValueError("floor_dbm must be non-zero")
    return 1 - (rssi / floor_dbm)


def amplitude2rssi(amplitude, floor_dbm=-100):
    """Map linear amplitude score(s) back to RSSI.

    Computes ``floor_dbm * (1 - amplitude)``: amplitude 0 maps to ``floor_dbm``
    and amplitude 1 maps to 0. Works element-wise on Python numbers, NumPy
    arrays and torch tensors. With the default floor this is the inverse of
    ``rssi2amplitude``.

    Args:
        amplitude: amplitude value(s).
        floor_dbm: RSSI value that amplitude 0 maps to (default -100, the
            value that used to be hard-coded).

    Raises:
        ValueError: if ``floor_dbm`` is zero (every amplitude would collapse
            to 0, so the mapping has no inverse).
    """
    if floor_dbm == 0:
        raise ValueError("floor_dbm must be non-zero")
    return floor_dbm * (1 - amplitude)


def _collect_unique_positions(dataloader, device):
    """Return the de-duplicated (Tx, Rx) positions found in ``dataloader``.

    Every batch is read once. Columns ``[:3]`` of each ``positions`` row are
    taken as the Rx position and ``[3:6]`` as the Tx position. Duplicates are
    detected after rounding to 3 decimals. For each unique rounded position the
    first original (unrounded) row is kept, in the sorted order produced by
    ``np.unique``.

    Returns:
        Tuple ``(unique_tx, unique_rx)`` of tensors on ``device``.

    Raises:
        ValueError: if the loader yields no batches.
    """
    print("Extracting all Tx and Rx positions from dataset...")
    tx_chunks = []
    rx_chunks = []
    for positions, _labels in dataloader:
        rx_chunks.append(positions[:, :3])
        tx_chunks.append(positions[:, 3:6])
    if not tx_chunks:
        raise ValueError("train_dataloader yielded no batches")

    all_tx = torch.cat(tx_chunks, dim=0).to(device)
    all_rx = torch.cat(rx_chunks, dim=0).to(device)

    print("Removing duplicate positions...")
    _, tx_idx = np.unique(np.round(all_tx.cpu().numpy(), 3), axis=0, return_index=True)
    _, rx_idx = np.unique(np.round(all_rx.cpu().numpy(), 3), axis=0, return_index=True)
    unique_tx = all_tx[tx_idx]
    unique_rx = all_rx[rx_idx]

    print(f"Original Tx positions: {all_tx.shape[0]}")
    print(f"Unique Tx positions: {unique_tx.shape[0]}")
    print(f"Original Rx positions: {all_rx.shape[0]}")
    print(f"Unique Rx positions: {unique_rx.shape[0]}")
    return unique_tx, unique_rx


def _warmup_then_cosine(optimizer, warmup_epochs, cosine_epochs):
    """Build a (LinearLR warm-up, CosineAnnealingLR) scheduler pair for ``optimizer``.

    The caller steps the warm-up scheduler for ``warmup_epochs`` steps and the
    cosine scheduler afterwards. ``cosine_epochs`` should equal the number of
    cosine steps actually taken so the LR anneals all the way down. It is
    clamped to at least 1 to avoid a zero period.
    """
    warmup = optim.lr_scheduler.LinearLR(
        optimizer, start_factor=0.1, end_factor=1.0, total_iters=warmup_epochs
    )
    cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max(1, cosine_epochs))
    return warmup, cosine


def train_neural_rf(
    num_epochs: int = 2000,
    learning_rate: float = 5e-3,
    train_dataloader: Optional[DataLoader] = None,
    num_rays_azimuth: int = 360,
    num_rays_elevation: int = 180,
    Rx_radius: float = 0.05,
    sample_spacing: float = 0.02,
    space_boundaries: torch.Tensor = torch.tensor([20.0, 20.0, 3.0], requires_grad=False, dtype=torch.float32).reshape(1, 3),
    device: str = 'cuda',
    use_soft_rendering: bool = False,
    warmup_epochs: int = 500,
    freeze_epochs: int = 1000,
    eikonal_weight: float = 1.0,
    laplacian_weight_fmean: float = 0.01,
    free_space_weight: float = 1.0
):
    """Train a ``NeuralRFOptimizer`` on RSSI measurements from ``train_dataloader``.

    Each "epoch" consumes a single batch, and the loader is restarted when it
    is exhausted. So ``num_epochs`` is the number of optimisation steps. Every
    row of the batch is rendered individually through
    ``model.compute_total_loss``. The per-sample total losses are averaged and
    back-propagated once per step.

    Three Adam optimisers are used:
      * structure model: lr ``0.01 * learning_rate``. It is not stepped for the
        first ``freeze_epochs`` epochs, then follows its own warm-up + cosine
        schedule.
      * material model and weights model: lr ``learning_rate``, with a linear
        warm-up over ``warmup_epochs`` followed by cosine annealing.

    Files written:
      * ``init/neural_rf_model.pth``: the freshly initialised weights.
      * ``results/neural_rf_model_epoch_{N}.pth``: every 100 epochs while
        N < 10000, then every 1000 epochs.
      * ``results/neural_rf_model.pth``: the final weights.

    Args:
        num_epochs: number of optimisation steps (one batch each).
        learning_rate: base LR for the material/weights models.
        train_dataloader: yields ``(positions, labels)``. Columns ``[:3]`` of
            ``positions`` are the Rx position, ``[3:6]`` the Tx position and
            ``[6:7]`` the Tx signal. ``labels`` holds the target RSSI per row.
        num_rays_azimuth, num_rays_elevation, Rx_radius, sample_spacing:
            forwarded to ``model.compute_total_loss``.
        space_boundaries: scene extent, (1, 3) by default. The SDF sanity check
            maps coordinates from [0, boundary] to [-1, 1] with it.
        device: requested device; falls back to CPU when CUDA is unavailable.
        use_soft_rendering: if True, always render softly. Otherwise soft
            rendering is used only for the first ``warmup_epochs // 2`` epochs.
        warmup_epochs: length of the linear LR warm-up.
        freeze_epochs: number of initial epochs during which the structure
            model is not updated.
        eikonal_weight, laplacian_weight_fmean, free_space_weight: loss weights
            forwarded to ``model.compute_total_loss``.

    Returns:
        The trained model.

    Raises:
        ValueError: if ``train_dataloader`` is None or yields no batches.
    """
    if train_dataloader is None:
        raise ValueError("train_dataloader must be provided")

    device = torch.device(device if torch.cuda.is_available() else 'cpu')
    print(f"Training on device: {device}")

    model = NeuralRFOptimizer(
        structure_hidden_dim=64,
        material_hidden_dim=128,
        num_layers=8,
        max_steps=128,
        loss_type='huber',
        use_sphere_init=False,
        use_box_init=True,
        box_size=torch.tensor([17.0, 13.0, 3.0], device=device, dtype=torch.float32, requires_grad=False),
        box_center=torch.tensor([8.5, 8.0, 1.5], device=device, dtype=torch.float32, requires_grad=False),
        num_spheres=6,
        sphere_radius_range=(0.8, 1.0),
        space_boundaries=space_boundaries
    ).to(device)

    os.makedirs('init', exist_ok=True)
    torch.save(model.state_dict(), 'init/neural_rf_model.pth')
    # Periodic checkpoints go to results/ from epoch 100 on, so the directory
    # must exist before training starts, not only when it ends.
    os.makedirs('results', exist_ok=True)

    space_boundaries = space_boundaries.to(device)
    fmean = model.renderer.fmean_network

    # Sanity check: print the structure model's output at a few fixed points,
    # after mapping coordinates from [0, boundary] to [-1, 1].
    test_points = torch.tensor([
        [8.5, 8.0, 1.5],
        [0.0, 1.5, 0.0],
        [17.0, 14.5, 3.0],
    ], device=device, dtype=torch.float32)
    with torch.no_grad():
        sdf_values, _ = fmean.structure_model(test_points / space_boundaries * 2 - 1)
        print("SDF values:", sdf_values.squeeze())

    # The full set of unique Tx/Rx positions is passed to compute_total_loss on every step.
    all_tx_positions, all_rx_positions = _collect_unique_positions(train_dataloader, device)

    structural_optimizer = optim.Adam(fmean.structure_model.parameters(), lr=0.01 * learning_rate, weight_decay=1e-5)
    material_optimizer = optim.Adam(fmean.material_model.parameters(), lr=learning_rate, weight_decay=1e-5)
    weights_optimizer = optim.Adam(model.renderer.weights_model.parameters(), lr=learning_rate, weight_decay=1e-5)

    material_warmup_scheduler, material_main_scheduler = _warmup_then_cosine(
        material_optimizer, warmup_epochs, num_epochs - warmup_epochs
    )
    weights_warmup_scheduler, weights_main_scheduler = _warmup_then_cosine(
        weights_optimizer, warmup_epochs, num_epochs - warmup_epochs
    )
    # The structure schedule only starts after freeze_epochs, so its cosine
    # phase is freeze_epochs shorter than the others. Using the same T_max as
    # the other models would leave its LR far from fully annealed at the end.
    structural_warmup_scheduler, structural_main_scheduler = _warmup_then_cosine(
        structural_optimizer, warmup_epochs, num_epochs - warmup_epochs - freeze_epochs
    )

    losses = []
    RSSI_losses = []
    eikonal_losses = []
    free_space_losses = []

    pbar = tqdm(range(num_epochs), desc="Training")
    data_iter = iter(train_dataloader)

    for epoch in pbar:
        try:
            batch = next(data_iter)
        except StopIteration:
            # Restart the loader so the step count is independent of dataset size.
            data_iter = iter(train_dataloader)
            batch = next(data_iter)

        positions, labels = batch
        batch_size = positions.shape[0]
        use_soft = use_soft_rendering or (epoch < warmup_epochs // 2)

        # Leaf tensor with requires_grad=True so backward() stays valid even if
        # no per-sample loss carries a graph.
        batch_loss = torch.tensor(0.0, device=device, dtype=torch.float32, requires_grad=True)
        batch_RSSI_loss = torch.tensor(0.0, device=device, dtype=torch.float32)
        batch_eikonal_loss = torch.tensor(0.0, device=device, dtype=torch.float32)
        batch_free_space_loss = torch.tensor(0.0, device=device, dtype=torch.float32)

        structural_optimizer.zero_grad()
        material_optimizer.zero_grad()
        weights_optimizer.zero_grad()

        for i in range(batch_size):
            Rx_positions = positions[i, :3].to(device).unsqueeze(0)
            Tx_positions = positions[i, 3:6].to(device).unsqueeze(0)
            Tx_signals = positions[i, 6:7].to(device).unsqueeze(0)
            RSSI_values = labels[i].to(device).unsqueeze(0)

            loss_dict = model.compute_total_loss(
                Tx_positions,
                Rx_positions,
                Tx_signals,
                Rx_radius,
                num_rays_azimuth,
                num_rays_elevation,
                sample_spacing,
                RSSI_values,
                space_boundaries,
                use_soft,
                eikonal_weight,
                laplacian_weight_fmean,
                free_space_weight,
                all_tx_positions,
                all_rx_positions,
            )

            batch_loss = batch_loss + loss_dict['total'] / batch_size
            batch_RSSI_loss = batch_RSSI_loss + loss_dict['RSSI'].detach() / batch_size
            batch_eikonal_loss = batch_eikonal_loss + loss_dict['eikonal'].detach() / batch_size
            batch_free_space_loss = batch_free_space_loss + loss_dict['free_space'].detach() / batch_size

            print(f'gt_RSSI: {RSSI_values.item()}, pred_RSSI: {loss_dict["pred_RSSI"].item()}')
            del loss_dict

        batch_loss.backward()

        # NOTE(review): weights_model gradients are never clipped, unlike the
        # structure/material models. Confirm this is intentional.
        if epoch < freeze_epochs:
            # Structure model frozen: its gradients are computed but discarded
            # by the next zero_grad().
            torch.nn.utils.clip_grad_norm_(fmean.material_model.parameters(), max_norm=1.0)
            material_optimizer.step()
            weights_optimizer.step()
        else:
            torch.nn.utils.clip_grad_norm_(fmean.structure_model.parameters(), max_norm=1.0)
            torch.nn.utils.clip_grad_norm_(fmean.material_model.parameters(), max_norm=1.0)
            structural_optimizer.step()
            material_optimizer.step()
            weights_optimizer.step()

        if epoch < warmup_epochs:
            material_warmup_scheduler.step()
            weights_warmup_scheduler.step()
        else:
            material_main_scheduler.step()
            weights_main_scheduler.step()

        # The structure schedule is shifted by freeze_epochs.
        if epoch >= freeze_epochs:
            if epoch < warmup_epochs + freeze_epochs:
                structural_warmup_scheduler.step()
            else:
                structural_main_scheduler.step()

        losses.append(batch_loss.item())
        RSSI_losses.append(batch_RSSI_loss.item())
        eikonal_losses.append(batch_eikonal_loss.item())
        free_space_losses.append(batch_free_space_loss.item())

        epoch_number = epoch + 1
        save_every = 100 if epoch_number < 10000 else 1000
        if epoch_number % save_every == 0:
            checkpoint_path = f'results/neural_rf_model_epoch_{epoch_number}.pth'
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Model saved at epoch {epoch_number}: {checkpoint_path}")

        pbar.set_postfix({
            'RSSI_loss': f"{batch_RSSI_loss.item():.4f}",
            'eikonal': f"{batch_eikonal_loss.item():.4f}",
            'free_space': f"{batch_free_space_loss.item():.4f}",
        })

    torch.save(model.state_dict(), 'results/neural_rf_model.pth')

    return model

def evaluate_neural_rf(
    model: NeuralRFOptimizer,
    test_dataloader: DataLoader,
    num_rays_azimuth: int = 360,
    num_rays_elevation: int = 180,
    Rx_radius: float = 0.5,
    sample_spacing: float = 0.01,
    device: str = 'cuda',
    use_soft_rendering: bool = False,
    load_from_file: bool = True,
    space_boundaries: torch.Tensor = torch.tensor([20.0, 20.0, 3.0], requires_grad=False, dtype=torch.float32).reshape(1, 3),
    checkpoint_path: str = 'results/neural_rf_model.pth',
):
    """Run ``model`` over ``test_dataloader`` and report RSSI error metrics.

    Args:
        model: the model to evaluate; switched to eval mode and moved to ``device``.
        test_dataloader: yields ``(positions, labels)``. Columns ``[:3]`` of
            ``positions`` are the Rx position, ``[3:6]`` the Tx position and
            ``[6:7]`` the Tx signal. ``labels`` holds the target RSSI.
        num_rays_azimuth, num_rays_elevation, Rx_radius, sample_spacing,
        use_soft_rendering: forwarded to ``model.forward``.
        device: requested device; falls back to CPU when CUDA is unavailable.
        load_from_file: if True, first load weights from ``checkpoint_path``.
        space_boundaries: accepted for signature symmetry with training; not
            used by the evaluation itself.
        checkpoint_path: state-dict file loaded when ``load_from_file`` is True.

    Returns:
        Dict with ``mse_loss``, ``mae_loss``, ``rmse_loss`` (floats),
        ``total_samples`` (length of the concatenated labels), and
        ``predictions`` / ``ground_truth`` (concatenated CPU tensors).
        The metrics are computed on flattened copies, so e.g. (B, 1)
        predictions and (B,) labels are compared element-wise rather than
        broadcast to (B, B).

    Raises:
        ValueError: if the loader yields no batches, or predictions and labels
            contain different numbers of elements.
    """
    device = torch.device(device if torch.cuda.is_available() else 'cpu')
    print(f"Evaluating on device: {device}")

    if load_from_file:
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    model = model.to(device)
    model.eval()

    pred_chunks = []
    true_chunks = []

    # NOTE(review): forward runs without torch.no_grad(), so an autograd graph
    # is built per batch. Wrap in no_grad() if the renderer does not rely on
    # autograd internally.
    for positions, labels in test_dataloader:
        Rx_positions = positions[:, :3].to(device)
        Tx_positions = positions[:, 3:6].to(device)
        Tx_signals = positions[:, 6:7].to(device)

        pred_RSSI = model.forward(
            Tx_positions,
            Rx_positions,
            Tx_signals,
            Rx_radius,
            num_rays_azimuth,
            num_rays_elevation,
            sample_spacing,
            use_soft_rendering
        )

        pred_chunks.append(pred_RSSI.detach().cpu())
        true_chunks.append(labels.detach().cpu())

    if not pred_chunks:
        raise ValueError("test_dataloader yielded no batches")

    all_pred_RSSI = torch.cat(pred_chunks, dim=0)
    all_true_RSSI = torch.cat(true_chunks, dim=0)

    pred_flat = all_pred_RSSI.reshape(-1)
    true_flat = all_true_RSSI.reshape(-1)
    if pred_flat.numel() != true_flat.numel():
        raise ValueError(
            f"prediction/label size mismatch: {tuple(all_pred_RSSI.shape)} vs {tuple(all_true_RSSI.shape)}"
        )

    mse_loss = F.mse_loss(pred_flat, true_flat)
    mae_loss = F.l1_loss(pred_flat, true_flat)
    rmse_loss = torch.sqrt(mse_loss)

    print(f"Test Results (All Data):")
    print(f"  Total samples: {len(all_true_RSSI)}")
    print(f"  MSE Loss: {mse_loss.item():.6f}")
    print(f"  MAE Loss: {mae_loss.item():.6f}")
    print(f"  RMSE Loss: {rmse_loss.item():.6f}")

    return {
        'mse_loss': mse_loss.item(),
        'mae_loss': mae_loss.item(),
        'rmse_loss': rmse_loss.item(),
        'total_samples': len(all_true_RSSI),
        'predictions': all_pred_RSSI,
        'ground_truth': all_true_RSSI
    }
    
def check_gradient_nan(model, prefix=""):
    """Scan ``model``'s parameter gradients for NaN or Inf values.

    Parameters whose ``.grad`` is None are skipped. For a NaN gradient, the
    shape, min/max and NaN count are printed. For an Inf-only gradient, a
    single warning line is printed. If nothing is found, a success line is
    printed.

    Args:
        model: any object exposing ``named_parameters()`` (e.g. an ``nn.Module``).
        prefix: text inserted at the start of every message.

    Returns:
        Tuple ``(found, names)``. ``found`` is True if any gradient contained
        NaN or Inf; ``names`` lists the offending parameter names in visit order.
    """
    flagged = []
    for param_name, tensor in model.named_parameters():
        grad = tensor.grad
        if grad is None:
            continue

        nan_mask = torch.isnan(grad)
        if nan_mask.any():
            print(f"❌ {prefix}NaN gradient detected in {param_name}")
            print(f"   Gradient shape: {grad.shape}")
            print(f"   Gradient stats: min={grad.min():.6f}, max={grad.max():.6f}")
            print(f"   NaN count: {nan_mask.sum().item()}")
            flagged.append(param_name)
        elif torch.isinf(grad).any():
            print(f"⚠️  {prefix}Inf gradient detected in {param_name}")
            flagged.append(param_name)

    found = bool(flagged)
    if not found:
        print(f"✅ {prefix}No NaN/Inf gradients found")
    return found, flagged

def main():
    """Train the neural RF model on the BLE RSSI training split.

    The dataset is read from
    ``<project_root>/Dataset/BLE/rssi-dataset-1/rssi-dataset-1``, where
    ``<project_root>`` is the parent of the directory containing this file.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    dataset_dir = os.path.join(
        os.path.dirname(script_dir), "Dataset", "BLE", "rssi-dataset-1", "rssi-dataset-1"
    )
    index_file = os.path.join(dataset_dir, "train_index_1_4.txt")

    ble_train = BLE_dataset(
        datadir=dataset_dir,
        indexdir=index_file,
        max_rows=1565,
        x_range=(5, 25),
        y_range=(15, 35),
        z_range=(-1, 2),
        shuffle_data=True,
    )
    loader = DataLoader(ble_train, batch_size=4, shuffle=True)

    # Scene extent handed to the trainer (shape (1, 3)).
    scene_extent = torch.tensor(
        [20.0, 20.0, 3.0], requires_grad=False, dtype=torch.float32
    ).reshape(1, 3)

    train_neural_rf(
        train_dataloader=loader,
        num_epochs=15000,
        learning_rate=1e-4,
        num_rays_azimuth=36,
        num_rays_elevation=18,
        Rx_radius=0.5,
        sample_spacing=0.02,
        space_boundaries=scene_extent,
        device='cuda:0',
    )

# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main() 