import av
import torch
import numpy as np

from transformers import AutoProcessor, XCLIPVisionModel, XCLIPModel
from huggingface_hub import hf_hub_download
# Imports
import os
import glob
import pickle

import numpy as np
import matplotlib.pyplot as plt
import nibabel as nib

from nilearn import datasets
from nilearn import plotting
import tqdm
from modeling_git import GitForCausalLM, GitModel, GitForCausalLMClipEmb
from sklearn.model_selection import train_test_split

from sklearn.linear_model import Ridge, RidgeCV

from IPython.display import HTML
from base64 import b64encode
from joblib import Parallel, delayed
import h5py
import numpy as np
import matplotlib.pyplot as plt
import os
from os.path import join as opj
# import textgrids
import json
import tqdm
from sklearn.model_selection import train_test_split
import torch
import wandb
from sentence_transformers import SentenceTransformer
from joblib import Parallel, delayed
from sklearn.linear_model import RidgeCV
import numpy as np


# fsaverage = datasets.fetch_surf_fsaverage()
# Seed NumPy's global RNG so the random clip offsets drawn in
# sample_frame_indices (via np.random.randint) are reproducible across runs.
np.random.seed(0)


def read_video_pyav(container, indices):
    '''
    Decode the requested frames of a video with the PyAV decoder.

    Args:
        container (`av.container.input.InputContainer`): PyAV container. It is
            rewound with ``seek(0)`` before decoding.
        indices (`List[int]`): Frame indices to decode. Order does not matter;
            an index that appears several times yields that frame once per
            occurrence.
    Returns:
        result (np.ndarray): decoded RGB frames stacked in ascending
        frame-index order, shape (num_frames, height, width, 3). Indices past
        the end of the stream are skipped, so num_frames can be smaller than
        len(indices).
    Raises:
        ValueError: if ``indices`` is empty or none of them could be decoded.
    '''
    from collections import Counter

    # Occurrence count per wanted index: O(1) membership tests instead of a
    # linear scan of `indices` for every decoded frame, and duplicates are kept.
    wanted = Counter(int(i) for i in indices)
    if not wanted:
        raise ValueError("indices must contain at least one frame index")
    # max() rather than indices[-1], so unsorted input still decodes far enough.
    last_index = max(wanted)

    frames = []
    container.seek(0)
    for i, frame in enumerate(container.decode(video=0)):
        if i > last_index:
            break
        repeats = wanted.get(i, 0)
        if repeats:
            rgb = frame.to_ndarray(format="rgb24")
            frames.extend([rgb] * repeats)

    if not frames:
        raise ValueError("none of the requested frame indices could be decoded")
    return np.stack(frames)


def sample_frame_indices(clip_len, frame_sample_rate, seg_len):
    '''
    Sample a window of evenly spaced frame indices from a video.

    Args:
        clip_len (`int`): Total number of frames to sample.
        frame_sample_rate (`int`): Approximate spacing, in frames, between
            consecutive samples.
        seg_len (`int`): Number of frames in the video; every returned index
            is strictly smaller than this.
    Returns:
        indices (`np.ndarray` of int64): ``clip_len`` non-decreasing frame
        indices. A window of ``int(clip_len * frame_sample_rate)`` frames is
        placed at a random offset drawn from NumPy's global RNG. Videos that
        are not longer than that window are sampled across their whole length
        instead (indices may then repeat).
    Raises:
        ValueError: if ``seg_len`` is not positive (e.g. a stream whose frame
            count is unknown and reported as 0).
    '''
    if seg_len <= 0:
        raise ValueError(f"seg_len must be positive, got {seg_len}")
    converted_len = int(clip_len * frame_sample_rate)
    if seg_len > converted_len:
        end_idx = np.random.randint(converted_len, seg_len)
        start_idx = end_idx - converted_len
    else:
        # np.random.randint(low, high) requires low < high, so short videos
        # used to crash here; span the whole video instead.
        end_idx = seg_len
        start_idx = 0
    indices = np.linspace(start_idx, end_idx, num=clip_len)
    indices = np.clip(indices, start_idx, end_idx - 1).astype(np.int64)
    return indices


def get_fmri(fmri_dir, ROI):
    """Load the repetition-averaged training responses for one ROI.

    Parameters
    ----------
    fmri_dir : str
        Directory containing the ``<ROI>.pkl`` file.
    ROI : str
        Name of the ROI; ``"WB"`` selects whole-brain data.

    Returns
    -------
    np.ndarray or tuple
        ``ROI_data["train"]`` averaged over axis 1 (the repetition axis).
        Assuming the stored array is #train_vids x #repetitions x #voxels,
        the result is #train_vids x #voxels. For ``ROI == "WB"`` a tuple
        ``(responses, voxel_mask)`` is returned instead, where ``voxel_mask``
        is ``ROI_data["voxel_mask"]``.
    """
    roi_file = os.path.join(fmri_dir, f"{ROI}.pkl")
    roi_data = load_dict(roi_file)
    # Average across repeated presentations of each video (axis 1).
    roi_data_train = np.mean(roi_data["train"], axis=1)
    if ROI == "WB":
        return roi_data_train, roi_data["voxel_mask"]
    return roi_data_train


def load_dict(filename_):
    """Load a pickled object from ``filename_``.

    Python-2 ``str`` payloads are decoded as latin-1 (``encoding="latin1"``),
    the setting the pickle documentation requires for NumPy arrays pickled
    under Python 2.

    Parameters
    ----------
    filename_ : str
        Path to the pickle file.

    Returns
    -------
    object
        The unpickled object (a dict for the fMRI files this script reads).

    Notes
    -----
    Unpickling can execute arbitrary code: only call this on trusted files.
    """
    with open(filename_, "rb") as f:
        # Public equivalent of building pickle._Unpickler (a private class)
        # and overriding its ``encoding`` attribute.
        return pickle.load(f, encoding="latin1")


def fit_and_predict(voxel_idx, X_train, z_train, z_test, X_test, *,
                    alphas=(1e-3, 1e-2, 1e-1, 1, 10, 100, 1e3), cv=5):
    """Fit a cross-validated ridge model for one voxel and score it on test data.

    Parameters
    ----------
    voxel_idx : int
        Column of ``X_train`` / ``X_test`` to model.
    X_train, X_test : np.ndarray
        Response matrices indexed as ``[sample, voxel]``.
    z_train, z_test : np.ndarray
        Feature matrices with one row per sample, used as ridge inputs.
    alphas : sequence of float, optional
        Regularisation strengths searched by ``RidgeCV``.
    cv : int, optional
        Number of cross-validation folds used by ``RidgeCV``.

    Returns
    -------
    model : RidgeCV
        The fitted model.
    corr : float
        Pearson correlation between the predictions for ``z_test`` and the
        held-out responses; NaN if either series is constant.
    """
    model = RidgeCV(alphas=list(alphas), cv=cv)
    # NaNs are replaced by zeros (np.nan_to_num) in every array that reaches
    # the model or the score.
    model.fit(np.nan_to_num(z_train), np.nan_to_num(X_train[:, voxel_idx]))

    y_pred = model.predict(np.nan_to_num(z_test))
    # Sanitise the held-out target like the training target, so a single NaN
    # sample does not turn the whole correlation into NaN.
    y_true = np.nan_to_num(X_test[:, voxel_idx])
    corr = np.corrcoef(y_pred, y_true)[0, 1]
    return model, corr




## Select the subject and the brain region(s) whose responses are modelled.

sub = 'sub01'
ROI = 'WB'  # @param ["WB", "V1", "V2","V3", "V4", "LOC", "EBA", "FFA","STS", "PPA"]
# NOTE: any choice other than "WB" loads and concatenates every ROI in
# VISUAL_ROIS ("STS" is not among them); the selected ROI is not loaded alone.
VISUAL_ROIS = ["V1", "V2", "V3", "V4", "LOC", "EBA", "FFA", "PPA"]


## LOAD DATA

######## fMRI data loader wrapper code ###################################
fmri_dir = './participants_data_v2021'

if ROI == "WB":
    # Whole-brain data is stored under full_track; for "WB" get_fmri returns
    # (responses, voxel_mask) and only the responses are kept.
    X, _ = get_fmri(os.path.join(fmri_dir, "full_track", sub), ROI)
else:
    # Per-ROI data is stored under mini_track. A dedicated loop variable keeps
    # ROI holding the user's choice afterwards. The ROI matrices are joined
    # along axis 1 (voxels), one block per ROI.
    mini_track_sub_dir = os.path.join(fmri_dir, "mini_track", sub)
    X = np.concatenate(
        [get_fmri(mini_track_sub_dir, roi_name) for roi_name in VISUAL_ROIS], 1
    )
    
vid_id = 1  # @param {type: "integer"}
video_dir = './AlgonautsVideos268_All_30fpsmax'

########### Video display code #################################################
video_list = sorted(glob.glob(opj(video_dir, '*.mp4')))
if not video_list:
    raise FileNotFoundError(f"no .mp4 files found in {video_dir}")

# Prefer the second GPU, as before, but fall back instead of crashing on
# machines that do not have one.
if torch.cuda.device_count() > 1:
    device = "cuda:1"
elif torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Sanity check: decode and embed a single video before the full pass.
num_frames = 8
# `with` closes the container once the frames are decoded.
with av.open(video_list[vid_id]) as container:
    indices = sample_frame_indices(
        clip_len=num_frames, frame_sample_rate=4, seg_len=container.streams.video[0].frames
    )
    frames = read_video_pyav(container, indices)

print(frames.shape)
processor = AutoProcessor.from_pretrained("microsoft/xclip-base-patch32")
clip_model = XCLIPModel.from_pretrained("microsoft/xclip-base-patch32").to(device)

# Inference only: no_grad avoids building an autograd graph on the GPU.
with torch.no_grad():
    pixel_values = processor(videos=list(frames), return_tensors="pt").pixel_values.to(device)

    batch_size, num_frames, num_channels, height, width = pixel_values.shape
    pixel_values = pixel_values.reshape(-1, num_channels, height, width)

    outputs = clip_model.get_video_features(pixel_values.unsqueeze(0))

print("[INFO] Extracting XCLIP video embeddings..")

# Embed one video per fMRI training row. Video i of the sorted file list is
# paired with row i of X. NOTE(review): this assumes the files sort in the
# same order as the fMRI "train" rows; confirm against the dataset.
n_videos = len(X)
video_embs = []
bad_indices = []
with torch.no_grad():
    for video_id in tqdm.trange(n_videos):
        try:
            # `with` closes each container; otherwise one file handle leaks per video.
            with av.open(video_list[video_id]) as container:
                num_frames = 8
                indices = sample_frame_indices(
                    clip_len=num_frames, frame_sample_rate=4, seg_len=container.streams.video[0].frames
                )
                frames = read_video_pyav(container, indices)

            pixel_values = processor(videos=list(frames), return_tensors="pt").pixel_values.to(device)

            batch_size, num_frames, num_channels, height, width = pixel_values.shape
            pixel_values = pixel_values.reshape(-1, num_channels, height, width)

            outputs = clip_model.get_video_features(pixel_values.unsqueeze(0))
            video_embs.append(outputs.cpu())
        except Exception as err:
            # Not a bare `except:`; Ctrl-C / SystemExit still stop the run,
            # and the cause of each failure is reported.
            print(f"bad video id {video_id}: {err!r}")
            bad_indices.append(video_id)

if not video_embs:
    raise RuntimeError("no video could be embedded; check video_dir and the decoder")
video_embs = torch.stack(video_embs)

# Drop failed videos from the fMRI rows and the file list so X, video_embs
# and video_list stay aligned.
X = np.delete(X, bad_indices, 0)
bad_set = set(bad_indices)
video_list = [path for i, path in enumerate(video_list) if i not in bad_set]

torch.save(video_embs, "video_embs.pt")

# check shapes
print(X.shape)
print(video_embs.shape)


print("[INFO] splitting train/test")

# Rows of X and entries of video_embs must describe the same videos.
if len(X) != len(video_embs):
    raise ValueError(
        f"X has {len(X)} rows but video_embs has {len(video_embs)} entries"
    )

indices = np.arange(len(X))
train_indices, test_indices = train_test_split(indices, test_size=.10, random_state=42)

X_train = X[train_indices]
X_test = X[test_indices]

# Flatten everything but the sample axis. Unlike a bare .squeeze(), this
# keeps the leading dimension even when a split holds a single sample.
z_train = video_embs[train_indices].reshape(len(train_indices), -1).numpy()
z_test = video_embs[test_indices].reshape(len(test_indices), -1).numpy()

train_videos = np.array(video_list)[train_indices]
test_videos = np.array(video_list)[test_indices]

print(X_train.shape, X_test.shape)
print(z_train.shape, z_test.shape)
print(train_videos.shape, test_videos.shape)


# save the train and test videos, X and z
base_path = "./"
data_path = opj(base_path, "data")
os.makedirs(data_path, exist_ok=True)

for name, arr in (
    ("X_train", X_train),
    ("X_test", X_test),
    ("z_train", z_train),
    ("z_test", z_test),
    ("train_videos", train_videos),
    ("test_videos", test_videos),
):
    np.save(opj(data_path, f"{name}_{sub}.npy"), arr)



print("[INFO] Fitting encoding models..")

# One RidgeCV model per voxel, fitted in parallel. Cap the worker count at
# the machine's CPU count: requesting 256 workers on a smaller machine only
# oversubscribes it.
n_jobs = min(256, os.cpu_count() or 1)

# The tqdm bar advances as tasks are handed to joblib, which can run ahead
# of task completion.
results = Parallel(n_jobs=n_jobs)(
    delayed(fit_and_predict)(voxel_idx, X_train, z_train, z_test, X_test)
    for voxel_idx in tqdm.trange(X_train.shape[1])
)

# Unpack results
voxel_models, voxel_corrs = zip(*results)

## save the models and the correlations
base_path = "./"
models_path = opj(base_path, "models")
os.makedirs(models_path, exist_ok=True)

np.save(opj(models_path, f"voxel_corrs_{sub}.npy"), np.array(voxel_corrs))

# Save the fitted models with pickle (imported at the top of the file).
with open(opj(models_path, f"voxel_models_{sub}.pkl"), "wb") as f:
    pickle.dump(voxel_models, f)

print("[INFO] Encoding models saved..")