from typing import List
import einops
import numpy as np
import torch.distributed as dist
from functools import partial
import random
import lightning.pytorch as pl
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader
import wandb

from src.utils.distributed_utils import convert_and_to, get_rank0_device
from src.utils.permutations import unpermute_frame
from src.utils.plotter_2d import plot_frames_2d_batched


class PlottingCallback(pl.Callback):
    """Periodically run inference on one train and one val sample and log plots to wandb.

    Every ``every_n_steps`` optimizer steps, the global-zero process draws one
    batch (batch size 1) from private train/val data loaders, runs
    ``pl_module.orchestrator.inference`` on it and logs a norm-comparison
    figure under ``"<split>/plot_inference"``.

    Args:
        every_n_steps: Plot whenever ``trainer.global_step`` is a multiple of this.
        n_steps: Forwarded to the dataset factory as ``n_steps``.
        n_samples: Maximum number of frames shown per figure (``show_n`` of
            ``plot_norms``).
        dataset: Called as ``dataset(split=..., n_steps=..., do_cache=False)``
            to build the train and val datasets, i.e. it is used as a factory
            (a dataset class or a partial), not as a dataset instance.
        devices: Stored on the instance; not used by this callback itself.
    """

    def __init__(self,
                 every_n_steps: int,
                 n_steps: int,
                 n_samples: int,
                 dataset: Dataset,
                 devices: int | List[int],
                 ):
        super().__init__()
        self.every_n_steps = every_n_steps
        self.n_steps = n_steps
        self.n_samples = n_samples
        self.devices = devices

        self.train_dataset = dataset(
            split='train',
            n_steps=n_steps,
            do_cache=False
        )
        val_dataset = dataset(
            split='val',
            n_steps=n_steps,
            do_cache=False
        )
        self.train_loader = DataLoader(self.train_dataset, batch_size=1, shuffle=True, num_workers=1)
        self.val_loader = DataLoader(val_dataset, batch_size=1, shuffle=True, num_workers=1)

        # Iterators are created lazily in `load_batch`: calling `iter()` on a
        # DataLoader with workers starts the worker processes immediately,
        # which is wasted on ranks that never plot and makes the callback
        # unpicklable.
        self.train_iter = None
        self.val_iter = None

        # Last global step a plot was produced for. `on_train_batch_end` fires
        # once per batch, but `global_step` only advances per optimizer step,
        # so with gradient accumulation the same step is seen several times.
        self._last_plotted_step = None

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """Plot train/val inference on the global-zero rank every ``every_n_steps`` steps."""
        # Let the trainer decide which process is rank zero instead of
        # comparing CUDA device indices (the index is None on CPU and is not
        # unique across nodes).
        if not trainer.is_global_zero:
            return

        step = trainer.global_step
        if step % self.every_n_steps != 0 or step == self._last_plotted_step:
            return
        self._last_plotted_step = step

        device = batch['field'].device

        # Run inference in eval mode so dropout is disabled and normalisation
        # statistics are not updated by the plotting samples; restore the
        # previous mode afterwards even if plotting fails.
        was_training = pl_module.training
        pl_module.eval()
        try:
            with torch.no_grad():
                self.infer_and_plot('train', device, pl_module)
                self.infer_and_plot('val', device, pl_module)
        finally:
            pl_module.train(was_training)

    def infer_and_plot(self, split, device, pl_module):
        """Run inference on one batch of ``split`` and log the resulting figure to wandb.

        Args:
            split: ``'train'`` selects the train loader; any other value selects
                the val loader. Also used as the wandb key prefix.
            device: Device the loaded batch is moved to via ``convert_and_to``.
            pl_module: Module whose ``orchestrator.inference(batch)`` produces
                the prediction.
        """
        batch = convert_and_to(self.load_batch(split), device)

        sol = pl_module.orchestrator.inference(batch)

        # detach() keeps this usable even when called outside torch.no_grad().
        pred = sol.detach().cpu().numpy()
        true = batch['field'].cpu().numpy()

        fig = plot_norms(pred, true,
                         show_n=self.n_samples,
                         plot_pts=batch['probe_idcs'].squeeze(0).cpu().numpy(),
                         title=f'T: {true.shape[1]}')
        try:
            wandb.log({f"{split}/plot_inference": wandb.Image(fig)})
        finally:
            # Always release the figure, also when logging raises.
            plt.close(fig)

    def load_batch(self, split):
        """Return the next batch for ``split``, restarting the loader when exhausted.

        Args:
            split: ``'train'`` for the train loader; any other value uses the
                val loader.

        Returns:
            The next batch yielded by the selected data loader.
        """
        is_train = split == 'train'
        loader = self.train_loader if is_train else self.val_loader
        loader_iter = self.train_iter if is_train else self.val_iter

        if loader_iter is None:
            loader_iter = iter(loader)

        try:
            batch = next(loader_iter)
        except StopIteration:
            # Loader exhausted: start a fresh (reshuffled) pass.
            loader_iter = iter(loader)
            batch = next(loader_iter)

        if is_train:
            self.train_iter = loader_iter
        else:
            self.val_iter = loader_iter

        return batch



def plot_norms(pred, true, show_n=4, plot_pts=None, title=None):
    """Plot true, predicted and difference norms side by side for the first frames.

    Args:
        pred: Predicted array laid out as ``(b, n, h, w, d)``.
        true: Ground-truth array laid out as ``(b, n, h, w, d)``.
        show_n: Maximum number of frames (rows) to draw.
        plot_pts: Optional 2-D array; columns 0 and 1 are scattered as x/y
            markers on every subplot.
        title: Optional figure title.

    Returns:
        The matplotlib figure with one row per shown frame and three columns
        (true norm, predicted norm, norm of the difference).

    Raises:
        ValueError: If there is no frame to plot (empty input or ``show_n <= 0``).
    """
    # Merge the two leading axes into a single frame axis.
    pred = einops.rearrange(pred, 'b n h w d -> (b n) h w d')
    true = einops.rearrange(true, 'b n h w d -> (b n) h w d')

    # Norm over the last (component) axis.
    true_norms = np.linalg.norm(true, axis=-1)
    pred_norms = np.linalg.norm(pred, axis=-1)
    diff_norms = np.linalg.norm(true - pred, axis=-1)

    # Number of rows actually drawn.
    n_rows = min(show_n, true_norms.shape[0], pred_norms.shape[0])
    if n_rows <= 0:
        raise ValueError("plot_norms needs at least one frame to plot")

    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(n_rows, 3, figsize=(16, 4 * n_rows), squeeze=False)

    # Only iterate over the rows that exist in the figure; looping over every
    # frame would index past `axes` whenever there are more frames than show_n.
    # NOTE: each image is autoscaled independently, so the colorbars below
    # reflect only the last plotted row's true and diff images.
    for row in range(n_rows):
        true_ax, pred_ax, diff_ax = axes[row]

        im_true = true_ax.imshow(true_norms[row], cmap='viridis')
        true_ax.set_title(f'True Norms {row+1}')
        true_ax.axis('off')

        pred_ax.imshow(pred_norms[row], cmap='viridis')
        pred_ax.set_title(f'Pred Norms {row+1}')
        pred_ax.axis('off')

        im_diff = diff_ax.imshow(diff_norms[row], cmap='plasma')
        diff_ax.set_title(f'Diff Norms {row+1}')
        diff_ax.axis('off')

    # Leave room on the right for the colorbars (figure-level call, so it does
    # not depend on which figure pyplot considers current).
    fig.subplots_adjust(right=0.8, wspace=0.3)

    # Colorbars to the right of the grid, spaced apart.
    cbar_ax1 = fig.add_axes([0.83, 0.15, 0.02, 0.7])
    cbar_ax2 = fig.add_axes([0.88, 0.15, 0.02, 0.7])
    fig.colorbar(im_true, cax=cbar_ax1, label='True/Pred Norms')
    fig.colorbar(im_diff, cax=cbar_ax2, label='Diff Norms')

    if plot_pts is not None:
        for ax in axes.flat:
            ax.scatter(plot_pts[:, 0], plot_pts[:, 1], marker='x', color='grey', s=20, linewidths=1)

    if title is not None:
        fig.suptitle(title)

    return fig
