import os
import torch
import matplotlib.pyplot as plt
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter
import argparse
from utils.data_utils import get_sequential_data_loaders
from models.splr_model import SPLRModel

def _load_scalar_series(log_dir, tags):
    """
    Read scalar series from the TensorBoard event files in a directory.

    Args:
        log_dir (str): Directory containing TensorBoard event files.
        tags (Iterable[str]): Scalar tags to extract (e.g. "Train/Loss").

    Returns:
        dict: Maps each requested tag to a ``(steps, values)`` pair of lists.
        Tags that are not present in the logs map to two empty lists.
    """
    # Imported lazily: ``tensorboard`` is the package that
    # ``torch.utils.tensorboard`` itself requires, so it is available wherever
    # this module's top-level imports succeed.
    from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

    # A size guidance of 0 keeps every scalar event rather than a reservoir
    # sample, so long runs are plotted in full.
    accumulator = EventAccumulator(log_dir, size_guidance={"scalars": 0})
    accumulator.Reload()
    available = set(accumulator.Tags().get("scalars", []))

    series = {}
    for tag in tags:
        if tag in available:
            events = accumulator.Scalars(tag)
            series[tag] = ([event.step for event in events], [event.value for event in events])
        else:
            series[tag] = ([], [])
    return series


def plot_loss_accuracy(log_dir, save_path=None):
    """
    Plot training and validation loss and accuracy from TensorBoard logs.

    The scalars "Train/Loss", "Test/Loss", "Train/Accuracy" and "Test/Accuracy"
    are read from the event files in ``log_dir``. Each series is plotted
    against its own logged step, so train and test curves of different
    lengths still line up correctly.

    Args:
        log_dir (str): Directory where TensorBoard logs are stored.
        save_path (str): Path to save the plotted figure. If None, the figure
            is displayed instead.

    Raises:
        ValueError: If none of the four expected scalar tags has any data.
    """
    tags = ["Train/Loss", "Test/Loss", "Train/Accuracy", "Test/Accuracy"]
    series = _load_scalar_series(log_dir, tags)
    if not any(steps for steps, _ in series.values()):
        raise ValueError(
            f"No scalars for {', '.join(tags)} found in TensorBoard logs at {log_dir!r}."
        )

    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))
    panels = (
        (loss_ax, "Loss", "Loss", "Training and Validation Loss"),
        (acc_ax, "Accuracy", "Accuracy (%)", "Training and Validation Accuracy"),
    )
    curves = (("Train", "Training", "blue"), ("Test", "Validation", "orange"))

    for ax, metric, ylabel, title in panels:
        for split, label, color in curves:
            steps, values = series[f"{split}/{metric}"]
            ax.plot(steps, values, label=f"{label} {metric}", color=color)
        # The x-axis is the logged global step. It is labelled "Epoch" on the
        # assumption that training logs once per epoch; verify against the
        # training script.
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path)
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
    else:
        plt.show()

def visualize_predictions(model_path, data_loader, device, num_images=16, save_path=None,
                          num_classes=10,
                          mean=(0.4914, 0.4822, 0.4465),
                          std=(0.2023, 0.1994, 0.2010)):
    """
    Visualize predictions of a trained SPLRModel on one batch of a dataset.

    Only the first batch yielded by ``data_loader`` is used. At most
    ``num_images`` samples from that batch are shown, each titled with its
    true and predicted class index.

    Args:
        model_path (str): Path to a checkpoint whose ``'model'`` entry holds
            the model's state dict.
        data_loader: DataLoader yielding ``(inputs, labels)`` batches.
        device (str): Device to run the model on (e.g. "cpu" or "cuda:0").
        num_images (int): Maximum number of images to visualize. It is capped
            at the size of the first batch.
        save_path (str): Path to save the visualization. If None, the
            visualization is displayed instead.
        num_classes (int): Number of output classes the checkpoint was trained
            with. The default of 10 matches CIFAR-10.
        mean (Sequence[float]): Per-channel mean used when the inputs were
            normalized. It is used to undo that normalization for display.
        std (Sequence[float]): Per-channel std used when the inputs were
            normalized.

    Raises:
        ValueError: If ``num_images`` is not positive or the batch is empty.
    """
    # NOTE(review): these architecture hyperparameters must match the ones the
    # checkpoint was trained with; confirm against the training script.
    model = SPLRModel(
        input_channels=3,
        num_classes=num_classes,
        decay_rate=0.1,
        tau_d_list=[2.0, 5.0, 10.0],
        tau_s=5.0
    )
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model'])
    model.to(device)
    model.eval()

    inputs, labels = next(iter(data_loader))
    inputs, labels = inputs.to(device), labels.to(device)

    with torch.no_grad():
        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)

    # Move everything back to the CPU for plotting.
    inputs, labels, predicted = inputs.cpu(), labels.cpu(), predicted.cpu()

    # Never index past the end of the batch.
    num_shown = min(num_images, inputs.size(0))
    if num_shown < 1:
        raise ValueError(
            f"Nothing to visualize: num_images={num_images}, batch size={inputs.size(0)}."
        )

    # Undo Normalize(mean, std): x * std + mean == (x - (-mean/std)) / (1/std).
    mean_t = torch.tensor(mean)
    std_t = torch.tensor(std)
    unnormalize = transforms.Normalize((-mean_t / std_t).tolist(), (1.0 / std_t).tolist())

    # squeeze=False keeps `axes` 2-D even when only one image is shown.
    fig, axes = plt.subplots(1, num_shown, figsize=(15, 5), squeeze=False)
    for i in range(num_shown):
        img = inputs[i]
        if img.dim() == 4:
            # Sequential loaders yield a clip per sample, presumably laid out
            # as (num_steps, C, H, W) (the layout visualize_images assumes for
            # the same loader). Show its final frame; verify against the loader.
            img = img[-1]
        # Clamp to the valid display range so imshow does not warn about
        # values that fall slightly outside [0, 1] after un-normalizing.
        img = unnormalize(img).clamp(0.0, 1.0)
        ax = axes[0, i]
        ax.imshow(img.permute(1, 2, 0).numpy())  # (C, H, W) -> (H, W, C)
        ax.set_title(f'True: {labels[i].item()}, Pred: {predicted[i].item()}')
        ax.axis('off')

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path)
        plt.close(fig)
    else:
        plt.show()

def visualize_images(dataset_name, data_dir, class_num, batch_size=16, num_steps=10, save_path=None):
    """
    Visualize a batch of sequential CIFAR images as a grid of frames.

    The grid has one row per sequence step and one column per sample in the
    first training batch.

    Args:
        dataset_name (str): The name of the dataset ('cifar10' or 'cifar100').
            It is forwarded to the data loader as ``dataset_name``.
        data_dir (str): Directory where CIFAR datasets are stored.
        class_num (int): Number of classes in the dataset (10 for CIFAR10, 100 for CIFAR100).
        batch_size (int): Number of images in the batch.
        num_steps (int): Number of steps (frames) in the sequence for each sample.
        save_path (str): Path to save the visualization. If None, the
            visualization is displayed instead.
    """
    # get_sequential_data_loaders takes an argparse-style namespace. ``b`` is
    # the batch size and ``j`` the worker count. ``dataset_name`` is passed
    # through so the parameter is no longer silently dropped; whether the
    # loader reads it is not visible here. TODO confirm.
    loader_args = argparse.Namespace(
        data_dir=data_dir,
        class_num=class_num,
        dataset_name=dataset_name,
        b=batch_size,
        j=4,
        device="cpu"
    )
    train_loader, _ = get_sequential_data_loaders(loader_args, num_steps=num_steps)

    inputs, _ = next(iter(train_loader))
    # The loader presumably yields (batch, steps, C, H, W); swap the first two
    # axes to (steps, batch, C, H, W) so each grid row is one step.
    frames = inputs.permute(1, 0, 2, 3, 4)
    # Size the grid from the actual tensor. A short batch then cannot cause
    # an IndexError.
    n_steps, n_samples = frames.size(0), frames.size(1)

    # squeeze=False keeps `axes` 2-D even when a dimension has size 1.
    fig, axes = plt.subplots(n_steps, n_samples, figsize=(20, 10), squeeze=False)
    for step in range(n_steps):
        for sample in range(n_samples):
            ax = axes[step, sample]
            ax.imshow(frames[step, sample].permute(1, 2, 0).numpy())  # (C, H, W) -> (H, W, C)
            ax.axis('off')
            if step == 0:
                ax.set_title(f'Sample {sample + 1}')

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path)
        plt.close(fig)
    else:
        plt.show()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Visualize training, testing metrics, and model predictions')
    parser.add_argument('--mode', type=str, required=True, choices=['loss_accuracy', 'predictions', 'images'],
                        help='Mode of visualization: "loss_accuracy", "predictions", or "images"')
    parser.add_argument('--log-dir', type=str, help='Directory where TensorBoard logs are stored')
    parser.add_argument('--model-path', type=str, help='Path to the saved model checkpoint for prediction visualization')
    parser.add_argument('--data-dir', type=str, help='Directory where CIFAR datasets are stored')
    parser.add_argument('--output-dir', type=str, help='Directory to save the visualization', default='./visualizations')
    parser.add_argument('--dataset-name', type=str, help='Dataset name ("cifar10" or "cifar100")')
    parser.add_argument('--class-num', type=int, help='Number of classes (10 or 100)', default=10)
    parser.add_argument('--num-images', type=int, default=16, help='Number of images to visualize for predictions')
    parser.add_argument('--batch-size', type=int, default=16, help='Batch size for data loading')
    parser.add_argument('--num-steps', type=int, default=10, help='Number of steps in sequence for sequential visualization')
    # Predictions mode needs a device; without this option args.device was undefined.
    parser.add_argument('--device', type=str,
                        default='cuda:0' if torch.cuda.is_available() else 'cpu',
                        help='Device for running the model (e.g. "cpu" or "cuda:0")')

    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    if args.mode == 'loss_accuracy':
        if args.log_dir is None:
            parser.error("Please provide --log-dir for TensorBoard logs.")
        save_path = os.path.join(args.output_dir, 'loss_accuracy.png')
        plot_loss_accuracy(args.log_dir, save_path)

    elif args.mode == 'predictions':
        if args.model_path is None or args.data_dir is None:
            parser.error("Please provide --model-path and --data-dir for prediction visualization.")
        # The loader reads the short names ``b`` (batch size) and ``j``
        # (workers), as in visualize_images. Add them alongside every CLI
        # option rather than passing the raw CLI namespace, which lacks them.
        loader_args = argparse.Namespace(**vars(args), b=args.batch_size, j=4)
        _, test_loader = get_sequential_data_loaders(loader_args, num_steps=args.num_steps)
        save_path = os.path.join(args.output_dir, 'predictions.png')
        visualize_predictions(args.model_path, test_loader, args.device, args.num_images, save_path)

    elif args.mode == 'images':
        if args.data_dir is None:
            parser.error("Please provide --data-dir to visualize images.")
        save_path = os.path.join(args.output_dir, 'sequential_images.png')
        visualize_images(args.dataset_name, args.data_dir, args.class_num, args.batch_size, args.num_steps, save_path)
