import torch
from models.modules.patching import patchify, unpatchify
from datasets.pretraining_dataset import PretrainingDataset
from datasets.hdf5_dataset import HDF5Loader
import hydra
import matplotlib.pyplot as plt
import numpy as np
import os

def get_figure():
    """Return a blank 10x6 inch, 300 dpi figure registered as figure number 0.

    The same figure number is requested on every call and the figure is
    cleared before it is handed back, so callers always start drawing on an
    empty canvas.
    """
    figure = plt.figure(num=0, figsize=(10, 6), dpi=300)
    # Drop whatever a previous call drew on this figure number.
    figure.clear()
    return figure

def fig2rgb_array(fig, expand=False):
    """Render *fig* and return its pixels as a channel-first RGB array.

    Args:
        fig: Figure-like object whose ``canvas`` offers ``draw()`` and either
            ``buffer_rgba()`` or ``tostring_rgb()`` + ``get_width_height()``.
        expand: Accepted for backward compatibility; currently ignored (the
            batch-dimension variant was never enabled).

    Returns:
        A writable ``uint8`` array of shape ``(3, height, width)``.
    """
    canvas = fig.canvas
    canvas.draw()

    if hasattr(canvas, 'buffer_rgba'):
        # Preferred path: `tostring_rgb` is deprecated in recent matplotlib.
        # The buffer is (height, width, 4); drop alpha and copy so the result
        # owns its memory and stays writable.
        rgba = np.asarray(canvas.buffer_rgba())
        rgb = np.array(rgba[..., :3], dtype=np.uint8)
    else:
        # Fallback for canvases that only expose the raw RGB byte string.
        ncols, nrows = canvas.get_width_height()
        buf = canvas.tostring_rgb()
        # `np.frombuffer` replaces the deprecated binary mode of
        # `np.fromstring`; it yields a read-only view, hence the copy.
        rgb = np.frombuffer(buf, dtype=np.uint8).reshape((nrows, ncols, 3)).copy()

    # (height, width, channel) -> (channel, height, width)
    return np.transpose(rgb, (2, 0, 1))



# checkpoint = torch.load('/cluster/work/cvl/eeg_foundation/experiments/pretraining/checkpoints/simmim_minmax_data_p20/simmim_minmax_data_p20-mask0.7-epoch=07-val_loss=0.0198.ckpt')
# checkpoint = torch.load('/cluster/work/cvl/eeg_foundation/experiments/pretraining/checkpoints/simmim_minmax_data_p20/simmim_minmax_data_p20-mask0.7-epoch=25-val_loss=0.0141.ckpt')
# checkpoint = torch.load('/cluster/work/cvl/eeg_foundation/experiments/pretraining/checkpoints/simmim_scaled_data/simmim_scaled_data-mask0.7-epoch=39-val_loss=0.0255.ckpt')
# checkpoint = torch.load('/cluster/work/cvl/eeg_foundation/experiments/pretraining/checkpoints/vitmae_minmax_data_p64_loss_factor_4090_m50/chkp-epoch=45-val_loss=0.02.ckpt')
# checkpoint = torch.load('/cluster/work/cvl/eeg_foundation/experiments/pretraining/checkpoints/simmim_minmax_data_p20_loss_factor_4090_m50/simmim_minmax_data_p20_loss_factor_4090_m50-mask0.7-epoch=67-val_loss=0.0126.ckpt')
# checkpoint = torch.load('/cluster/work/cvl/eeg_foundation/experiments/pretraining/checkpoints/simmim_minmax_data_p64_loss_factor_4090_m50/simmim_minmax_data_p64_loss_factor_4090_m50-mask0.5-epoch=61-val_loss=0.0158.ckpt')
# checkpoint = torch.load('/cluster/work/cvl/eeg_foundation/experiments/pretraining/checkpoints/simmim_minmax_data_p20_loss_factor_4090_m50_correct/simmim_minmax_data_p20_loss_factor_4090_m50_correct-mask0.5-epoch=51-val_loss=0.0091.ckpt')
# checkpoint = torch.load('/cluster/work/cvl/eeg_foundation/experiments/pretraining/checkpoints/vitmae_minmax_data_p64_loss_factor_4090_m50/chkp-epoch=101-val_loss=0.02.ckpt')
# checkpoint = torch.load('/cluster/work/cvl/eeg_foundation/experiments/pretraining/checkpoints/vitmae_minmax_data_p20_loss_factor_4090_m50/chkp-epoch=45-val_loss=0.01.ckpt')
# checkpoint = torch.load('/cluster/work/cvl/eeg_foundation/experiments/pretraining/checkpoints/vitmae_minmax_data_p64_loss_factor_4090_smoothl1/chkp-epoch=51-val_loss=0.01.ckpt')
# checkpoint = torch.load('/cluster/work/cvl/eeg_foundation/experiments/pretraining/checkpoints/simmim_alternating_attn_p64_m50/simmim_alternating_attn_p64_m50-mask0.5-epoch=101-val_loss=0.0077.ckpt')
# checkpoint = torch.load('/cluster/work/cvl/eeg_foundation/experiments/pretraining/checkpoints/simmim_alternating_attn_p64_m50_4gpus/simmim_alternating_attn_p64_m50_4gpus-mask0.5-epoch=67-val_loss=0.0161.ckpt')
# checkpoint = torch.load('/cluster/work/cvl/eeg_foundation/experiments/pretraining/checkpoints/simmim_bottleneck_attn_p64_m50/simmim_bottleneck_attn_p64_m50-mask0.5-epoch=67-val_loss=0.0173.ckpt')
# Checkpoint whose reconstructions are visualised. `jobname` and `epoch` are
# only used to name the output directory and must be kept in sync with the
# checkpoint path by hand.
# map_location='cpu': nothing in this script moves the model or the inputs to
# an accelerator, so load the weights on the CPU; without it a checkpoint that
# was saved from a GPU cannot be opened on a machine without CUDA.
# NOTE(review): recent PyTorch releases default `weights_only=True`, which may
# reject the pickled 'hyper_parameters' entry - pass `weights_only=False` if
# loading fails on such a version (confirm against the installed torch).
checkpoint = torch.load(
    '/cluster/work/cvl/eeg_foundation/experiments/pretraining/checkpoints/simmim_alternating_attn_86M/simmim_alternating_attn_86M-mask0.5-epoch=73-val_loss=0.0000.ckpt',
    map_location='cpu',
)
jobname = 'simmim_alternating_attn_86M'
epoch = 73
plots_path = f'/cluster/work/cvl/eeg_foundation/experiments/pretraining/reconstructions/{jobname}/epoch{epoch}'
# One folder per job and epoch; re-running the script reuses the folder.
os.makedirs(plots_path, exist_ok=True)

# Hydra configs of the backbone and the reconstruction head, as stored in the
# checkpoint's hyper-parameters.
model_hparams = checkpoint['hyper_parameters']['model']
model_head_hparams = checkpoint['hyper_parameters']['model_head']

# The checkpoint keeps both sub-modules in one flat state dict, told apart by
# the 'model.' / 'model_head.' key prefixes. Only the LEADING prefix is cut
# off: str.replace would also delete later occurrences of the same substring
# (e.g. 'model.submodel.w' -> 'subw') and corrupt nested parameter names.
state_dict = checkpoint['state_dict']
model_prefix = 'model.'
model_head_prefix = 'model_head.'
model_state_dict = {
    key[len(model_prefix):]: value
    for key, value in state_dict.items()
    if key.startswith(model_prefix)
}
model_head_state_dict = {
    key[len(model_head_prefix):]: value
    for key, value in state_dict.items()
    if key.startswith(model_head_prefix)
}

model = hydra.utils.instantiate(model_hparams)
model.load_state_dict(model_state_dict)

model_head = hydra.utils.instantiate(model_head_hparams)
model_head.load_state_dict(model_head_state_dict)

# Evaluation mode for both modules before running inference.
model.eval()
model_head.eval()


# TUSZ = PretrainingDataset('/cluster/work/cvl/eeg_foundation/TUH_sliced_datasets/TUSZ/')
# TUEP = PretrainingDataset('/cluster/work/cvl/eeg_foundation/TUH_sliced_datasets/TUEP/')
# TUSL = PretrainingDataset('/cluster/work/cvl/eeg_foundation/datasets/TUH_sliced_datasets/TUSL/')
# NOTE(review): the variable is called TUSL but the file path points at the
# TUSZ HDF5 file - confirm which dataset is intended.
TUSL = HDF5Loader(hdf5_file='/cluster/work/cvl/eeg_foundation/datasets/TUH_h5/TUSZ/le_files.h5', finetune=False)

# Original author's note on the expected input shape: B, C, T = 1, 23, 1280
num_samples = 5
# Channels for which a target/prediction/mask overlay is saved per sample.
channels_to_plot = (0, 5)

# Pure inference: disable autograd so no graph is built for the forward pass.
with torch.no_grad():
    for idx in range(num_samples):
        # Add a leading batch dimension of size 1.
        x = TUSL[idx]['input'].unsqueeze(0)
        latent, token_mask, ids_restore = model(x, mask_tokens=True)
        pred = model_head(latent, ids_restore)

        B, C, T = x.shape
        pred_unpatchified = unpatchify(pred, patch_size=model_hparams.patch_size, length=T, num_channels=C, keep_chans=model_hparams.keep_chans, using_spectrogram=False)
        pred_unpatchified = pred_unpatchified.detach().cpu().numpy()

        target_unpatchified = x.detach().cpu().numpy()

        # Repeat each token's mask flag patch_size times so the mask can go
        # through the same unpatchify call as the prediction.
        token_mask = token_mask.unsqueeze(-1).repeat(1, 1, model_hparams.patch_size)
        mask_unpatchified = unpatchify(token_mask, patch_size=model_hparams.patch_size, length=T, num_channels=C, keep_chans=model_hparams.keep_chans, using_spectrogram=False)
        mask_unpatchified = mask_unpatchified.detach().cpu().numpy()

        for chan in channels_to_plot:
            chan_pred = pred_unpatchified[0, chan]
            chan_target = target_unpatchified[0, chan]
            chan_mask = mask_unpatchified[0, chan]

            # get_figure() hands back a cleared figure and makes it current,
            # so the plt.* calls below draw on a blank canvas.
            get_figure()
            plt.plot(chan_target, label='target', alpha=0.3, color='red')
            plt.plot(chan_pred, label='pred', alpha=0.5, color='blue')
            # Shade the positions where the mask is set, spanning the
            # target's value range.
            plt.fill_between(np.arange(len(chan_pred)), np.min(chan_target), np.max(chan_target), where=chan_mask > 0,
                             color='gray', alpha=0.08, label='mask')
            plt.legend(bbox_to_anchor=(1, 1), loc=1, borderaxespad=0)
            plt.savefig(f'{plots_path}/plot{chan}_{idx}.png')
    
