# Standard imports
import argparse
import gc
import os
from pathlib import Path
from tqdm import tqdm
from sys import exit
import wandb
from matplotlib import pyplot as plt
import numpy as np
import torch
from torch import nn, optim
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn.utils
from torch.nn.utils import weight_norm
from functools import partial

# =======
# LOGGING
# =======

def log(key, val, use_wandb=False):
    """Print ``key: val`` to stdout and, if requested, send it to wandb.

    Args:
        key: Metric name; also used as the wandb key.
        val: Metric value.
        use_wandb: When True, additionally call ``wandb.log({key: val})``.
    """
    message = f"{key}: {val}"
    print(message)
    if not use_wandb:
        return
    wandb.log({key: val})


# ============
# VISUALIZATION
# ============

def plot_output(data, output, label, data_is_onehot=False, step=0, vocab_size=10, log=True):
    """
    Plot input, model output and label of one sequence as three stacked heatmaps.

    data: Tensor of shape (T, N) if data_is_onehot, else (T, 1)
    output: Tensor of shape (T, N) if data_is_onehot, else (T, 1). In the
        one-hot case a softmax over the last dim is applied before plotting.
    label: Tensor of shape (T, 1) (or (T,)). In the one-hot case it is cast to
        int64 and one-hot encoded with `vocab_size` classes.
    data_is_onehot: bool, indicating if data is one-hot (True) or a scalar (False)
    step: int, the wandb step the image is logged at (only used when log=True)
    vocab_size: int, range of possible values in copy task
    log: bool, if True the figure is logged to wandb under the key "Output"
        and then closed; otherwise it is displayed with plt.show().

    Sequences longer than 50 steps are clipped: the FIRST 50 steps of `data`
    and the LAST 50 steps of `output` and `label` are shown.
    """
    if data_is_onehot:
        data = data.detach().T.cpu().numpy()
        output = F.softmax(output, dim=-1).detach().T.cpu().numpy()
        # reshape(-1, vocab_size) rather than squeeze(): squeeze() would
        # collapse a length-1 sequence to a 1-D array, which imshow rejects.
        label = (
            F.one_hot(label.type(torch.int64), vocab_size)
            .float()
            .reshape(-1, vocab_size)
            .detach()
            .T.cpu().numpy()
        )
        vmin, vmax = 0, 1
    else:
        data = data.detach().T.cpu().numpy()
        output = output.detach().T.cpu().numpy()
        label = label.detach().T.cpu().numpy()
        vmin, vmax = 0, vocab_size - 1

    # Clip long sequences so the heatmaps stay readable. Input is taken from
    # the start and output/label from the end -- presumably so the input
    # pattern lines up with the recall window of a copy task (TODO confirm).
    if data.shape[1] > 50:
        data = data[:, :50]
        output = output[:, -50:]
        label = label[:, -50:]
        title_postfix = '_50-sample'
    else:
        title_postfix = '_all'

    fig, axs = plt.subplots(3)
    fig.set_size_inches(10, 6)

    panels = (('Input', data), ('Output', output), ('Label', label))
    for ax, (panel_name, image) in zip(axs, panels):
        ax.imshow(image, cmap='PuBu_r', vmin=vmin, vmax=vmax)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        ax.set_title(panel_name + title_postfix)

    # Optionally log to Weights and Biases:
    if log:
        wandb.log({"Output": wandb.Image(plt)}, step=step, commit=True)
        plt.close('all')
    else:
        plt.show()

def plot_3d_trajectory(trajectory, ax=None, title="3D Trajectory", xlabel="X", ylabel="Y", zlabel="Z"):
    """Draw a 3-D line through the points of ``trajectory``.

    Args:
        trajectory: torch.Tensor or array indexed as ``trajectory[:, k]`` for
            k in 0..2 (one row per point, columns are x, y, z).
        ax: Existing 3-D axes to draw on; a new 10x8 figure with a '3d'
            subplot is created when None.
        title, xlabel, ylabel, zlabel: Text applied to the axes.

    Returns:
        The axes that were drawn on.
    """
    points = (
        trajectory.detach().cpu().numpy()
        if isinstance(trajectory, torch.Tensor)
        else trajectory
    )

    if ax is None:
        # Imported lazily: only needed when the caller supplied no axes.
        import matplotlib.pyplot as plt
        from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection
        figure = plt.figure(figsize=(10, 8))
        ax = figure.add_subplot(111, projection='3d')

    xs, ys, zs = points[:, 0], points[:, 1], points[:, 2]
    ax.plot(xs, ys, zs, lw=2)

    decorations = (
        (ax.set_title, title),
        (ax.set_xlabel, xlabel),
        (ax.set_ylabel, ylabel),
        (ax.set_zlabel, zlabel),
    )
    for setter, text in decorations:
        setter(text)

    return ax

def plot_one_hot(one_hot, ax=None, title="One-Hot Encoding", xlabel="Time Step", ylabel="Value"):
    """
    Plots a one-hot encoded sequence as a heatmap (time on x, class on y).

    Args:
        one_hot (torch.Tensor or np.ndarray): The one-hot encoded sequence. Expected shape: (N, C).
        ax (matplotlib.axes.Axes, optional): Axes to plot on. If None, creates a new figure.
        title (str): Title of the plot.
        xlabel (str): Label for the x-axis.
        ylabel (str): Label for the y-axis.

    Returns:
        matplotlib.axes.Axes: The axes that were drawn on.
    """
    if isinstance(one_hot, torch.Tensor):
        one_hot = one_hot.detach().cpu().numpy()

    if ax is None:
        import matplotlib.pyplot as plt
        _, ax = plt.subplots(figsize=(10, 4))

    # Transpose so rows are classes and columns are time steps.
    # 'gray' is the canonical colormap name; the 'grey' spelling may be
    # rejected by older matplotlib releases.
    ax.imshow(one_hot.T, aspect='auto', cmap='gray', interpolation='nearest')
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.grid(False)

    return ax

####



def plot_prediction_vs_label(prediction, label, title="Prediction vs Label", xlabel="Time Step", ylabel="Value"):
    """
    Plots the prediction and label time series on one set of axes and shows the figure.

    The label is drawn as a solid line and the prediction as a dashed line;
    both are named so the legend identifies them.

    Args:
        prediction (torch.Tensor or np.ndarray): The predicted values. Expected shape: (N, 1) or (N,).
        label (torch.Tensor or np.ndarray): The true values. Expected shape: (N, 1) or (N,).
        title (str): Title of the plot.
        xlabel (str): Label for the x-axis.
        ylabel (str): Label for the y-axis.
    """

    # Convert to numpy arrays if they are torch tensors
    if isinstance(prediction, torch.Tensor):
        prediction = prediction.detach().cpu().numpy()
    if isinstance(label, torch.Tensor):
        label = label.detach().cpu().numpy()

    # Flatten in case they have shape (N, 1)
    prediction = np.squeeze(prediction)
    label = np.squeeze(label)

    plt.figure(figsize=(10, 4))
    # Each line needs a label= or plt.legend() has nothing to show.
    plt.plot(label, linewidth=2, label="Label")
    plt.plot(prediction, linewidth=2, linestyle="--", label="Prediction")
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.legend()
    plt.grid(True)
    plt.show()

def normalize_int(x):
    """Linearly rescale a tensor onto the integer range [0, 1000].

    The minimum maps to 0 and the maximum to 1000; intermediate values are
    truncated by the final ``.int()`` cast. The input tensor is NOT modified.

    Args:
        x: torch.Tensor of float or integer dtype.

    Returns:
        torch.Tensor of dtype int32 with the same shape as ``x``. A constant
        or empty input yields all zeros / an empty tensor.
    """
    if x.numel() == 0:
        return x.int()
    # Out-of-place so the caller's tensor is left untouched.
    shifted = x - x.min()
    peak = shifted.max()
    if peak == 0:
        # Constant input: the scale factor would be 1000/0, giving NaN/inf.
        return torch.zeros_like(shifted, dtype=torch.int32)
    if not shifted.is_floating_point():
        # Integer tensors cannot hold a fractional scale factor.
        shifted = shifted.float()
    return (shifted * (1000 / peak)).int()

def Plot_Weight(weight, max_c=16):
    """Render a 2-D weight matrix as a heatmap and log it to wandb.

    The matrix is transposed before plotting and logged under the key
    'weight_matrix'; all figures are closed afterwards.

    Args:
        weight: 2-D torch.Tensor or array. Tensors are detached and moved to
            the CPU first, so parameters that require grad can be plotted.
        max_c: Unused; kept for signature compatibility.

    Raises:
        ValueError: If ``weight`` is not 2-D.
    """
    if isinstance(weight, torch.Tensor):
        weight = weight.detach().cpu().numpy()
    if len(weight.shape) != 2:
        raise ValueError(
            f"Plot_Weight expects a 2-D weight matrix, got shape {tuple(weight.shape)}"
        )

    plt.figure(figsize=(10, 4))
    plt.imshow(weight.T)
    plt.colorbar(fraction=0.039, pad=0.04)
    wandb.log({'weight_matrix': wandb.Image(plt)})
    plt.close('all')

def Plot_Seq(seq, name='Latents_Static', max_c=16, log=True, step=0):
    """Plot up to ``max_c`` channels of a 3-D sequence as separate heatmaps.

    ``seq`` is indexed as ``seq[:, channel]``; unpacking its shape requires it
    to be exactly 3-D (presumably time, channel, hidden -- confirm with
    callers). Each channel slice is squeezed and transposed so that time runs
    along the x-axis.

    Args:
        seq: 3-D array or tensor.
        name: Prefix of the wandb key; the full key is name + 'Channel {c}'.
        max_c: Maximum number of channels to plot.
        log: If True, log each image to wandb at ``step``; otherwise plt.show().
        step: wandb step for the logged images.
    """
    _, channels, _ = seq.shape

    for channel in range(min(channels, max_c)):
        image = seq[:, channel:channel + 1].squeeze().T
        plt.imshow(image, origin='upper')
        plt.colorbar(fraction=0.039, pad=0.04)
        plt.ylabel('Space (Hidden Units)')
        plt.xlabel('Time (sequence steps)')
        plt.title("Latent State")
        if not log:
            plt.show()
        else:
            wandb.log({name + f'Channel {channel}': wandb.Image(plt)}, step=step)
        plt.close('all')

def plot_prediction(data, out, label, mem_len=1, i=0, log=True, one_hot=False):
    """Plot data, prediction and label windows for batch element ``i``.

    Args:
        data: Indexed as data[time, batch, feature]; the first ``mem_len``
            steps are shown.
        out: Indexed as out[time, batch, feature]; the last ``mem_len`` steps
            are shown, with the color range fixed to [0, data.shape[-1]].
        label: Either 3-D (indexed like ``out``) or 2-D (time, batch). The 2-D
            case is drawn as a single row via ``.unsqueeze(0)``, which assumes
            a torch tensor.
        mem_len: Number of time steps shown per panel.
        i: Batch index to plot.
        log: If True, log the figure to wandb as 'Predictions'; otherwise
            plt.show().
        one_hot: Unused.
    """
    def _decorate(axis, heading):
        axis.set_title(heading)
        axis.set_yticks([])
        axis.set_xticks(range(0, mem_len, 2))

    n_features = data.shape[-1]
    # NOTE(review): the figure height grows with mem_len although time runs
    # along the x-axis -- confirm this is intended.
    _, (data_ax, pred_ax, label_ax) = plt.subplots(3, 1, figsize=(16, 9 * mem_len))

    data_ax.imshow(data[:mem_len, i, :].T, vmax=1.0)
    _decorate(data_ax, 'Data')

    pred_ax.imshow(out[-mem_len:, i, :].T, vmin=0, vmax=n_features)
    _decorate(pred_ax, 'Prediction')

    if len(label.shape) != 3:
        label_ax.imshow(label[-mem_len:, i].unsqueeze(0), vmin=0, vmax=n_features)
    else:
        label_ax.imshow(label[-mem_len:, i, :].T, vmax=1.0)
    _decorate(label_ax, 'Label')

    if not log:
        plt.show()
    else:
        wandb.log({'Predictions': wandb.Image(plt)})
    plt.close('all')


import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # This import registers the 3D projection, even if not used directly
import torch
import numpy as np

def plot_3d_prediction_vs_label(prediction, label, title="3D Lorenz Attractor: Prediction vs Label"):
    """
    Plots the 3D trajectories for prediction and label (true states) and shows the figure.

    The label is drawn as a solid blue line and the prediction as a dashed red
    line, with a legend and the given title.

    Args:
        prediction (torch.Tensor or np.ndarray): Predicted states, expected shape (N, 3).
        label (torch.Tensor or np.ndarray): Ground truth states, expected shape (N, 3).
        title (str): Title of the plot.
    """

    # Convert to numpy arrays if necessary
    if isinstance(prediction, torch.Tensor):
        prediction = prediction.detach().cpu().numpy()
    if isinstance(label, torch.Tensor):
        label = label.detach().cpu().numpy()

    # Create a 3D plot
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')

    # Plot the label trajectory (ground truth)
    ax.plot(label[:, 0], label[:, 1], label[:, 2],
            label="True", color='blue', lw=2)

    # Plot the predicted trajectory
    ax.plot(prediction[:, 0], prediction[:, 1], prediction[:, 2],
            label="Prediction", color='red', lw=2, linestyle='--')

    # Apply the caller-supplied title (previously accepted but never used).
    ax.set_title(title)
    ax.legend()
    plt.tight_layout()
    plt.show()

def visualize_system(system_type, seq_length=5000, data_fn_dict=None):
    """Generate a trajectory for ``system_type``, plot it in 3D, and print data statistics.

    Args:
        system_type: Key into ``data_fn_dict`` selecting the generator; also
            used (capitalized) in the plot title.
        seq_length: Sequence length passed to the generator.
        data_fn_dict: Mapping from system name to a callable invoked as
            ``fn(batch_size=10, seq_length=seq_length)`` and returning
            ``(data, labels)``. When omitted, a module-level ``data_fn_dict``
            is used if one exists.

    Raises:
        ValueError: If no generator mapping is available.
        KeyError: If ``system_type`` is not in the mapping.
    """
    if data_fn_dict is None:
        # Fall back to a module-level mapping, if the module defines one.
        data_fn_dict = globals().get('data_fn_dict')
    if data_fn_dict is None:
        raise ValueError(
            "visualize_system needs a data_fn_dict mapping system names to "
            "data generators; pass one explicitly"
        )

    # Generate data
    data, labels = data_fn_dict[system_type](batch_size=10, seq_length=seq_length)
    print(labels.shape)

    # Convert to numpy for plotting
    if isinstance(data, torch.Tensor):
        data = data.detach().cpu().numpy()
    if isinstance(labels, torch.Tensor):
        labels = labels.detach().cpu().numpy()

    # Create 3D plot
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')

    # Plot the first batch element; labels presumably (time, batch, 3) -- confirm.
    ax.plot(labels[:, 0, 0], labels[:, 0, 1], labels[:, 0, 2])

    # Add title and labels
    ax.set_title(f"{system_type.capitalize()} System")
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')

    plt.show()

    # Print statistics
    print(f"Data shape: {data.shape}")
    print(f"Data range: [{data.min():.2f}, {data.max():.2f}]")
    print(f"Data mean: {data.mean():.2f}")
    print(f"Data std: {data.std():.2f}")

def str_to_bool(s):
    """Parse a yes/no style string into a bool (presumably used as an argparse ``type=``).

    'true'/'yes' map to True and 'false'/'no' map to False, case-insensitively
    and ignoring surrounding whitespace. A bool argument is returned unchanged.

    Raises:
        ValueError: For any other value.
    """
    if isinstance(s, bool):
        return s
    normalized = s.strip().lower()
    if normalized in ('true', 'yes'):
        return True
    if normalized in ('false', 'no'):
        return False
    raise ValueError(f"Cannot interpret {s!r} as a boolean; expected true/yes/false/no")
