import torch
import os
from pennylane import numpy as np
from tqdm.autonotebook import tqdm
import time
import matplotlib.pyplot as plt
from skimage.metrics import structural_similarity as ssim_metric
from skimage.metrics import peak_signal_noise_ratio as compute_psnr
from PIL import Image




# Output roots (relative paths). Each run writes into
# <root>/<model_name>/seed<seed>/ underneath them.
# Rendered test images produced during evaluation.
TEST_PATH = "output/renders/"
# Model, optimizer and scheduler checkpoints.
MODEL_PATH = "output/models/"
# Training statistics (.npz files).
STATS_PATH = "output/stats/"
def train(nerf_model,
          optimizer,
          scheduler,
          data_loader, 
          testing_dataset,
          number_in_test_set,
          device='cpu',
          hn=0,
          hf=1,
          nb_epochs=50,
          nb_bins=192,
          H=400,
          W=400,
          DEBUG=False,
          evaluate_each=10,
          save_each=10,
          USE_CLUSTER=True,
          seed=0,
          ):
    """Train a NeRF model, resuming from the latest checkpoint when one exists.

    Checkpoints are looked up in ``MODEL_PATH/<model_name>/seed<seed>``; if any
    ``nerf_model_<k>.pth`` file is found, the model, optimizer, scheduler and
    the accumulated statistics of the highest ``k`` are reloaded and training
    continues at epoch ``k + 1``.

    Args:
        nerf_model: Model exposing ``model_name``; it is passed to
            ``render_rays`` to produce pixel values.
        optimizer: Optimizer stepping the model parameters.
        scheduler: Learning-rate scheduler, stepped once per epoch.
        data_loader: Iterable of 2-D batches whose columns are
            ``[0:3]`` ray origins, ``[3:6]`` ray directions and ``[6:]``
            ground-truth pixel values.
        testing_dataset: Dataset handed to ``test`` for evaluation.
        number_in_test_set: Number of test images rendered per evaluation.
        device: Device the batches are moved to.
        hn: Near bound forwarded to ``render_rays``.
        hf: Far bound forwarded to ``render_rays``.
        nb_epochs: Total number of epochs (including already trained ones).
        nb_bins: Number of samples per ray.
        H: Test image height.
        W: Test image width.
        DEBUG: When true, prints extra information and processes only the
            first batch of every epoch.
        evaluate_each: Evaluate on the test set every this many epochs.
        save_each: Save the training state every this many epochs.
        USE_CLUSTER: When true, no progress bar is shown and an estimate of
            the remaining time is printed after every epoch instead.
        seed: Seed identifier used only to build the output directories.

    Returns:
        tuple: ``(training_loss, avgs_psnr)`` where ``training_loss`` holds one
        loss value per processed batch and ``avgs_psnr`` one mean PSNR per
        evaluation.
    """
    model_name = nerf_model.model_name
    model_path = os.path.join(MODEL_PATH, model_name, f'seed{seed}')
    stats_path = os.path.join(STATS_PATH, model_name, f'seed{seed}')
    os.makedirs(model_path, exist_ok=True)
    os.makedirs(stats_path, exist_ok=True)

    # ---------- Check for existing training state ----------
    existing_checkpoints = [f for f in os.listdir(model_path) if f.startswith("nerf_model_") and f.endswith(".pth")]
    if existing_checkpoints:
        # File names look like nerf_model_<iteration>.pth; resume from the highest one.
        latest_iter = max(int(f.split("_")[-1].split(".")[0]) for f in existing_checkpoints)

        print(f"Resuming training from epoch {latest_iter + 1}")

        # Load model, optimizer, scheduler
        nerf_model.load_state_dict(torch.load(os.path.join(model_path, f"nerf_model_{latest_iter}.pth"), map_location=device))
        optimizer.load_state_dict(torch.load(os.path.join(model_path, f"optimizer_{latest_iter}.pth"), map_location=device))
        scheduler.load_state_dict(torch.load(os.path.join(model_path, f"scheduler_{latest_iter}.pth"), map_location=device))

        # Load training stats
        stats = np.load(os.path.join(stats_path, f"training_stats_{latest_iter}.npz"))
        epochs = stats["epochs"].tolist()
        training_loss = stats["training_loss"].tolist()
        avgs_psnr = stats["avgs_psnr"].tolist()
        times_for_training = stats["times_for_training"].tolist()
        times_for_rendering = stats["times_for_rendering"].tolist()

        start_epoch = latest_iter + 1
    else:
        print("No existing checkpoint found. Starting fresh training.")
        epochs = []
        training_loss = []
        avgs_psnr = []
        times_for_training = []
        times_for_rendering = []
        start_epoch = 0

    if DEBUG:
        print(f'H: {H}, W: {W}')

    epoch_range = range(start_epoch, nb_epochs)
    for it in (epoch_range if USE_CLUSTER else tqdm(epoch_range)):
        current_time_for_training = time.time()
        for batch in data_loader:
            ray_origins = batch[:, :3].to(device)
            ray_directions = batch[:, 3:6].to(device)
            ground_truth_px_values = batch[:, 6:].to(device)

            regenerated_px_values = render_rays(nerf_model, ray_origins, ray_directions, hn=hn, hf=hf, nb_bins=nb_bins)
            loss = ((ground_truth_px_values - regenerated_px_values) ** 2).sum()

            if not torch.isfinite(loss):
                print("Non-finite loss, that's bad")

            optimizer.zero_grad()
            loss.backward()

            # Inspect the gradients of THIS batch (i.e. after backward()).
            gradients_are_finite = True
            for name, param in nerf_model.named_parameters():
                if param.grad is None:
                    print(f"{name}: grad is None")
                elif not torch.all(torch.isfinite(param.grad)):
                    print(f"Non-finite gradient in {name}")
                    gradients_are_finite = False

            # Applying NaN/Inf gradients would corrupt the weights for good,
            # so the update is skipped for such a batch.
            if gradients_are_finite:
                optimizer.step()
            training_loss.append(loss.item())

            if DEBUG:
                print(f'Epoch {it}, loss: {loss.item()}')
                break
        scheduler.step()
        current_time_for_training = time.time() - current_time_for_training
        times_for_training.append(current_time_for_training)

        if (it + 1) % evaluate_each == 0:
            test_path = os.path.join(TEST_PATH, model_name, f'seed{seed}', f'epoch_{it}')
            os.makedirs(test_path, exist_ok=True)
            if DEBUG:
                print("\nTesting:")
            current_time_for_rendering = time.time()
            metrics = [
                test(nerf_model,
                     hn,
                     hf,
                     testing_dataset,
                     chunk_size=1,
                     img_index=img_index,
                     nb_bins=nb_bins,
                     H=H, W=W,
                     save_path=test_path,
                     device=device,
                     )
                for img_index in range(number_in_test_set)
            ]
            epochs.append(it + 1)
            current_time_for_rendering = time.time() - current_time_for_rendering
            times_for_rendering.append(current_time_for_rendering)
            avg_psnr = np.mean(metrics)
            avgs_psnr.append(avg_psnr)

        if (it + 1) % save_each == 0:
            save_training_state(nerf_model, optimizer, scheduler, 
                training_loss, avgs_psnr, times_for_training, times_for_rendering, epochs,
                iteration=it, seed=seed)

        if USE_CLUSTER:
            # Epoch `it` has just finished, so it + 1 epochs are already done.
            remaining_epochs = nb_epochs - (it + 1)
            print(f"Estimated time remaining: {np.mean(times_for_training) * remaining_epochs / 60:.2f} minutes")

    return training_loss, avgs_psnr





def save_training_state(nerf_model,
                        optimizer,
                        scheduler,
                        training_loss,
                        avgs_psnr,
                        times_for_training,
                        times_for_rendering,
                        epochs,
                        iteration,
                        seed=0):
    """
    Save the model, optimizer state, scheduler state, and training statistics.

    Checkpoints go to ``MODEL_PATH/<model_name>/seed<seed>`` and statistics to
    ``STATS_PATH/<model_name>/seed<seed>``; every file name carries
    ``iteration`` as a suffix.

    Args:
        nerf_model (torch.nn.Module): The NeRF model; must expose ``model_name``.
        optimizer (torch.optim.Optimizer): The optimizer used during training.
        scheduler (torch.optim.lr_scheduler): The learning rate scheduler.
        training_loss (list): Recorded training losses.
        avgs_psnr (list): Average PSNR value of each evaluation.
        times_for_training (list): Training time of each epoch.
        times_for_rendering (list): Rendering time of each evaluation.
        epochs (list): Epoch numbers at which an evaluation took place.
        iteration (int): Iteration index used in the saved file names.
        seed (int): Seed identifier used to build the output directories.
    """
    model_path = os.path.join(MODEL_PATH, nerf_model.model_name, f'seed{seed}')
    stats_path = os.path.join(STATS_PATH, nerf_model.model_name, f'seed{seed}')

    # Create folders if they do not exist
    os.makedirs(model_path, exist_ok=True)
    os.makedirs(stats_path, exist_ok=True)

    # Save statistics as a numpy .npz file
    np.savez(os.path.join(stats_path, f'training_stats_{iteration}.npz'),
             training_loss=np.array(training_loss),
             epochs=np.array(epochs),
             avgs_psnr=np.array(avgs_psnr),
             times_for_training=np.array(times_for_training),
             times_for_rendering=np.array(times_for_rendering))

    # Save optimizer state
    torch.save(optimizer.state_dict(), os.path.join(model_path, f'optimizer_{iteration}.pth'))

    # Save scheduler state
    torch.save(scheduler.state_dict(), os.path.join(model_path, f'scheduler_{iteration}.pth'))

    # Save model state LAST: the resume logic in train() treats the presence of
    # nerf_model_<iteration>.pth as "checkpoint available", so an interrupted
    # save must not leave that file behind without its companion files.
    torch.save(nerf_model.state_dict(), os.path.join(model_path, f'nerf_model_{iteration}.pth'))

    print(f"Training state saved in {model_path}, {stats_path}")



@torch.no_grad()
def test(nerf_model,
         hn,
         hf,
         dataset,
         chunk_size=10,
         img_index=0,
         nb_bins=192,
         H=400,
         W=400,
         save_path=None,
         device='cpu',
         DEBUG=False):
    """Render one test image, optionally save it as PDF, and return its PSNR.

    Args:
        nerf_model: Model passed to ``render_rays``.
        hn: Near bound forwarded to ``render_rays``.
        hf: Far bound forwarded to ``render_rays``.
        dataset: 2-D tensor in which image ``k`` occupies rows
            ``[k * H * W, (k + 1) * H * W)``; columns ``[0:3]`` are ray
            origins, ``[3:6]`` ray directions and ``[6:]`` reference colors.
        chunk_size: Number of image rows rendered per ``render_rays`` call.
        img_index: Index of the image to render.
        nb_bins: Number of samples per ray.
        H: Image height.
        W: Image width.
        save_path: Directory for ``img_<img_index>.pdf``. When empty or
            ``None`` the image is not written to disk.
        device: Device the rays are moved to for rendering.
        DEBUG: Currently unused.

    Returns:
        PSNR between the reference image and the rendered image (clipped to
        [0, 1]), computed with ``data_range=1.0``.
    """
    first_ray = img_index * H * W
    last_ray = (img_index + 1) * H * W

    # Extract reference image and convert to NumPy
    reference_img = dataset[first_ray:last_ray, 6:].reshape(H, W, 3)
    reference_img = reference_img.cpu().numpy()  # Convert from tensor to NumPy array

    # Extract ray origins and directions
    ray_origins = dataset[first_ray:last_ray, :3]
    ray_directions = dataset[first_ray:last_ray, 3:6]

    # Render chunk_size image rows at a time to bound memory usage.
    rays_per_chunk = W * chunk_size
    data = []
    for i in range(int(np.ceil(H / chunk_size))):
        ray_origins_ = ray_origins[i * rays_per_chunk: (i + 1) * rays_per_chunk].to(device)
        ray_directions_ = ray_directions[i * rays_per_chunk: (i + 1) * rays_per_chunk].to(device)

        regenerated_px_values = render_rays(nerf_model, ray_origins_, ray_directions_, hn=hn, hf=hf, nb_bins=nb_bins)
        data.append(regenerated_px_values.cpu())

    # Combine chunks and convert to NumPy array
    img = torch.cat(data).numpy().reshape(H, W, 3)

    # Clip image values to [0, 1]
    img = np.clip(img, 0, 1)

    # Compute PSNR
    psnr = compute_psnr(reference_img.astype(np.float32), img.astype(np.float32), data_range=1.0)

    # Save the image only when a target directory was given; previously the
    # default save_path=None produced the literal path "None/img_<k>.pdf".
    if save_path:
        os.makedirs(save_path, exist_ok=True)
        plt.figure()
        plt.imshow(img)
        plt.axis('off')
        plt.savefig(f'{save_path}/img_{img_index}.pdf', bbox_inches='tight')
        plt.close()

    return psnr


@torch.no_grad()
def smart_test(nerf_model,
               hn,
               hf,
               dataset,
               chunk_size=10,
               img_index=0,
               nb_bins=192,
               H=400,
               W=400,
               save_path=None,
               device='cpu',
               DEBUG=False):
    """Render one test image, optionally save it as PNG, and return metrics.

    Args:
        nerf_model: Model passed to ``render_rays``.
        hn: Near bound forwarded to ``render_rays``.
        hf: Far bound forwarded to ``render_rays``.
        dataset: 2-D tensor in which image ``k`` occupies rows
            ``[k * H * W, (k + 1) * H * W)``; columns ``[0:3]`` are ray
            origins, ``[3:6]`` ray directions and ``[6:]`` reference colors.
        chunk_size: Number of image rows rendered per ``render_rays`` call.
        img_index: Index of the image to render.
        nb_bins: Number of samples per ray.
        H: Image height.
        W: Image width.
        save_path: Directory for ``img_<img_index>.png``. When empty or
            ``None`` the image is not written to disk.
        device: Device the rays are moved to for rendering.
        DEBUG: Currently unused.

    Returns:
        dict: ``{'psnr': ..., 'ssim': ..., 'lpips': 0}``. SSIM is computed on
        the channel-averaged (grayscale) images. LPIPS is not computed; the
        entry is a constant placeholder.
    """
    first_ray = img_index * H * W
    last_ray = (img_index + 1) * H * W

    # Extract reference image and convert to NumPy
    reference_img = dataset[first_ray:last_ray, 6:].reshape(H, W, 3)
    reference_img = reference_img.cpu().numpy()
    reference_img = np.clip(reference_img, 0, 1)

    # Extract ray origins and directions
    ray_origins = dataset[first_ray:last_ray, :3]
    ray_directions = dataset[first_ray:last_ray, 3:6]

    # Render chunk_size image rows at a time to bound memory usage.
    rays_per_chunk = W * chunk_size
    data = []
    for i in range(int(np.ceil(H / chunk_size))):
        ray_origins_ = ray_origins[i * rays_per_chunk: (i + 1) * rays_per_chunk].to(device)
        ray_directions_ = ray_directions[i * rays_per_chunk: (i + 1) * rays_per_chunk].to(device)
        regenerated_px_values = render_rays(nerf_model, ray_origins_, ray_directions_, hn=hn, hf=hf, nb_bins=nb_bins)
        data.append(regenerated_px_values.cpu())

    # Combine chunks into an (H, W, 3) image with values in [0, 1]
    img = torch.cat(data).reshape(H, W, 3).clamp(0, 1)
    img_array = img.numpy()

    # Compute PSNR
    psnr = compute_psnr(reference_img.astype(np.float32), img_array.astype(np.float32), data_range=1.0)

    # Compute SSIM on grayscale versions of both images
    img_np_gray = np.mean(img_array, axis=2)
    ref_np_gray = np.mean(reference_img, axis=2)
    ssim_val = ssim_metric(ref_np_gray, img_np_gray, data_range=1.0)

    # Save the image as PNG (no plt) only when a target directory was given;
    # previously the save ran unconditionally and failed for save_path=None.
    if save_path:
        os.makedirs(save_path, exist_ok=True)
        # Scale to [0, 255] and save using PIL
        img_np = (img_array * 255).astype(np.uint8)
        Image.fromarray(img_np).save(f'{save_path}/img_{img_index}.png')

    return {'psnr': psnr, 'ssim': ssim_val, 'lpips': 0}





def render_rays(nerf_model, ray_origins, ray_directions, hn=0, hf=0.5, nb_bins=192):
    """Volume-render a batch of rays through ``nerf_model``.

    Each ray is sampled at ``nb_bins`` depths between ``hn`` and ``hf``; every
    sample is drawn uniformly at random inside its own bin, so two calls with
    the same inputs generally return slightly different values.

    Args:
        nerf_model: Callable taking sample positions ``(N, 3)`` and viewing
            directions ``(N, 3)`` and returning ``(colors, sigma)`` that can be
            reshaped to ``(batch_size, nb_bins, 3)`` and
            ``(batch_size, nb_bins)``.
        ray_origins: Tensor of shape ``(batch_size, 3)``.
        ray_directions: Tensor of shape ``(batch_size, 3)``.
        hn: Near bound of the sampling interval.
        hf: Far bound of the sampling interval.
        nb_bins: Number of samples per ray.

    Returns:
        Tensor of shape ``(batch_size, 3)``: the accumulated color of each ray
        plus ``1 - sum(weights)``, i.e. composited over a white background.

    Raises:
        AssertionError: If a NaN appears in the sample positions, directions,
            model outputs or any intermediate rendering quantity. The checks
            are explicit ``raise`` statements so that they stay active under
            ``python -O`` (the exception type matches the former ``assert``).
    """
    device = ray_origins.device
    nb_rays = ray_origins.shape[0]

    t = torch.linspace(hn, hf, nb_bins, device=device).expand(nb_rays, nb_bins)
    # Perturb sampling along each ray.
    mid = (t[:, :-1] + t[:, 1:]) / 2.
    lower = torch.cat((t[:, :1], mid), -1)
    upper = torch.cat((mid, t[:, -1:]), -1)
    u = torch.rand(t.shape, device=device)
    t = lower + (upper - lower) * u  # [batch_size, nb_bins]
    # Distance between consecutive samples; the last sample gets a very large
    # width (1e5).
    delta = torch.cat((t[:, 1:] - t[:, :-1], torch.tensor([1e5], device=device).expand(nb_rays, 1)), -1)

    x = ray_origins.unsqueeze(1) + t.unsqueeze(2) * ray_directions.unsqueeze(1)   # [batch_size, nb_bins, 3]
    ray_directions = ray_directions.expand(nb_bins, ray_directions.shape[0], 3).transpose(0, 1)

    # Validate the model inputs.
    if torch.isnan(x).any():
        print('x.reshape(-1, 3) is nan')
        raise AssertionError('x.reshape(-1, 3) is nan')
    if torch.isnan(ray_directions).any():
        print('ray_directions.reshape(-1, 3) is nan')
        raise AssertionError('ray_directions.reshape(-1, 3) is nan')

    colors, sigma = nerf_model(x.reshape(-1, 3), ray_directions.reshape(-1, 3))
    colors = colors.reshape(x.shape)
    sigma = sigma.reshape(x.shape[:-1])

    # Validate the model outputs, dumping the related tensors for debugging.
    if torch.isnan(colors).any():
        print('colors is nan')
        print("However, the values of sigma and delta are:")
        print(sigma)
        print(delta)
        print("The values of x are:")
        print(x)
        raise AssertionError('colors is nan')
    if torch.isnan(sigma).any():
        print('sigma is nan')
        raise AssertionError('sigma is nan')

    alpha = 1 - torch.exp(-sigma * delta)  # [batch_size, nb_bins]
    if torch.isnan(alpha).any():
        print('alpha is nan')
        print("However, the values of sigma and delta are:")
        print(sigma)
        print(delta)
        raise AssertionError('alpha is nan')

    weights = compute_accumulated_transmittance(1 - alpha).unsqueeze(2) * alpha.unsqueeze(2)
    c = (weights * colors).sum(dim=1)  # Pixel values
    if torch.isnan(c).any():
        print('c is nan')
        raise AssertionError('c is nan')

    weight_sum = weights.sum(-1).sum(-1)  # Regularization for white background
    if torch.isnan(weight_sum).any():
        print('weight_sum is nan')
        raise AssertionError('weight_sum is nan')
    return c + 1 - weight_sum.unsqueeze(-1)


def compute_accumulated_transmittance(alphas):
    """Return the exclusive cumulative product of ``alphas`` along dim 1.

    Column ``j`` of the result holds the product of columns ``0 .. j-1`` of
    the 2-D input, so the first column is always 1 and the output has the
    same number of columns as the input.
    """
    running_product = torch.cumprod(alphas, 1)
    nb_rows = running_product.shape[0]
    # Nothing precedes the first sample, hence a leading column of ones.
    leading_ones = torch.ones((nb_rows, 1), device=alphas.device)
    shifted = running_product[:, :-1]
    return torch.cat((leading_ones, shifted), dim=-1)