
import torch
import pickle
import ipdb
import h5py 
import numpy as np
from tqdm import tqdm
import torch.nn as nn

def split(x_s, y_s, percent):
    """Cut two parallel sequences at the same position.

    The cut index is ``int(percent * len(x_s))``. It is derived from ``x_s``
    only and then applied unchanged to ``y_s``.

    Args:
        x_s: Sliceable sequence that determines the cut index.
        y_s: Sliceable sequence cut at the same index.
        percent: Fraction of ``len(x_s)`` that goes into the leading parts.

    Returns:
        ``(x_head, x_tail, y_head, y_tail)``.
    """
    cut = int(percent * len(x_s))
    head, tail = slice(None, cut), slice(cut, None)
    return x_s[head], x_s[tail], y_s[head], y_s[tail]

def reshape_param(grad_hat, theta_0):
    """Cut a flat vector into consecutive pieces shaped like ``theta_0``'s entries.

    Args:
        grad_hat: 1-D array/tensor supporting slicing and ``.reshape``.
        theta_0: Iterable of objects exposing ``.shape``; each one defines the
            shape (and therefore the element count) of the next piece.

    Returns:
        A list with one reshaped slice of ``grad_hat`` per entry of
        ``theta_0``, in the same order.
    """
    grad_list = []
    start_idx = 0
    for param in theta_0:
        shape = param.shape
        # np.prod gives the float 1.0 for an empty shape (0-d parameter), and
        # floats are not valid slice bounds, so force a plain int.
        length = int(np.prod(shape))
        grad_list.append(grad_hat[start_idx:start_idx + length].reshape(shape))
        start_idx += length

    return grad_list

def flatten_params_torch(params, num_params):
    """Copy a sequence of tensors into one flat tensor of length ``num_params``.

    The output is created with ``torch.zeros`` (default dtype) on the device
    of ``params[0]``; tensors are written back to back in iteration order, and
    any positions past the last tensor stay zero.

    Args:
        params: Non-empty sequence of tensors.
        num_params: Length of the returned flat tensor.

    Returns:
        1-D tensor of length ``num_params``.
    """
    device = params[0].device
    to_return = torch.zeros(num_params, device=device)

    start_idx = 0
    for grad in params:
        # reshape (unlike view) also accepts non-contiguous tensors, such as
        # a transposed weight gradient.
        grad_flat = grad.reshape(-1)
        len_grad = grad_flat.numel()
        to_return[start_idx:start_idx + len_grad] = grad_flat
        start_idx += len_grad
    return to_return

def flatten_params(params_list):
    """Concatenate tensors into a single 1-D NumPy array.

    Each tensor is flattened, detached, moved to the CPU and converted to
    NumPy; the pieces are joined in iteration order.

    Args:
        params_list: Iterable of torch tensors.

    Returns:
        1-D ``np.ndarray`` holding every element; an empty array when
        ``params_list`` is empty.
    """
    chunks = [param.reshape(-1).detach().cpu().numpy() for param in params_list]
    if not chunks:
        # np.concatenate rejects an empty list; keep the empty-array result.
        return np.array([])
    # One vectorised join instead of growing a Python list element by element.
    return np.concatenate(chunks)

def get_accuracy_score(y_pred, y_true):
    """Return the fraction of predictions whose arg-max class equals the label.

    NOTE(review): this module defines ``get_accuracy_score`` a second time
    further down; that later definition is the one bound to the name after
    import.

    Args:
        y_pred: Tensor of class scores; the class axis is the last one.
        y_true: Tensor of class labels, either shaped like the arg-max of
            ``y_pred`` or 2-D with a trailing singleton axis, e.g. ``(N, 1)``.

    Returns:
        0-d NumPy array holding the mean accuracy.
    """
    class_pred = torch.argmax(y_pred, -1).float()
    # Column labels (N, 1) would broadcast against (N,) predictions into an
    # (N, N) comparison; drop the trailing axis unless shapes already agree.
    if len(y_true.shape) == 2 and y_true.shape != class_pred.shape:
        y_true = y_true.squeeze(-1)
    accuracy = torch.mean((class_pred == y_true.float()) * 1.0)
    return accuracy.cpu().data.numpy()

def create_h5_file(file_path, nn_param, task_param, z_np, theta_pre_np = None):
    """Create an HDF5 file whose datasets each start with a single row.

    The file is opened in ``'w'`` mode and receives the datasets
    ``'nn_params'``, ``'task_params'`` and ``'z'``, plus ``'pre_params'`` when
    ``theta_pre_np`` is given. Every dataset is created with an unbounded
    first axis (``maxshape=(None, width)``) so rows can be appended later.

    Args:
        file_path: Destination path of the HDF5 file.
        nn_param: 1-D array stored as the first row of ``'nn_params'``.
        task_param: 1-D array stored as the first row of ``'task_params'``.
        z_np: 1-D array stored as the first row of ``'z'``.
        theta_pre_np: Optional 1-D array stored as the first row of
            ``'pre_params'``; its column limit is ``len(nn_param)``.
    """
    nn_width = len(nn_param)
    task_width = len(task_param)

    def _seed(handle, name, vector, width):
        # Store ``vector`` as row 0 and leave the row axis growable.
        handle.create_dataset(name, data=vector[None, :], maxshape=(None, width))

    with h5py.File(file_path, 'w') as h5_out:
        _seed(h5_out, 'nn_params', nn_param, nn_width)
        _seed(h5_out, 'task_params', task_param, task_width)
        _seed(h5_out, 'z', z_np, len(z_np))

        if theta_pre_np is not None:
            # Shares the column limit of 'nn_params'.
            _seed(h5_out, 'pre_params', theta_pre_np, nn_width)

def append_h5_file(file_path, nn_param, task_param, z_np, theta_pre_np=None):
    """Append one row to each dataset of an existing HDF5 file.

    The file is opened in ``'a'`` mode. ``'nn_params'``, ``'task_params'`` and
    ``'z'`` each grow by one row; ``'pre_params'`` grows only when
    ``theta_pre_np`` is given. The datasets must already exist and be
    resizable along axis 0.

    Args:
        file_path: Path of the HDF5 file to extend.
        nn_param: 1-D array appended to ``'nn_params'``.
        task_param: 1-D array appended to ``'task_params'``.
        z_np: 1-D array appended to ``'z'``.
        theta_pre_np: Optional 1-D array appended to ``'pre_params'``.
            Defaults to ``None``, mirroring ``create_h5_file``.
    """
    # Build every row before touching the file, so a malformed input cannot
    # leave some datasets extended and others not.
    rows = [
        ('nn_params', nn_param[None, :]),
        ('task_params', task_param[None, :]),
        ('z', z_np[None, :]),
    ]
    if theta_pre_np is not None:
        rows.append(('pre_params', theta_pre_np[None, :]))

    with h5py.File(file_path, 'a') as hf:
        for name, row in rows:
            dataset = hf[name]
            n_new = row.shape[0]
            # Grow along the row axis, then fill the freshly added tail.
            dataset.resize(dataset.shape[0] + n_new, axis=0)
            dataset[-n_new:] = row


def get_accuracy_score(y_pred, y_true):
    """Return the fraction of predictions whose arg-max class equals the label.

    Args:
        y_pred: Tensor of class scores; the class axis is the last one.
        y_true: Tensor of class labels, either shaped like the arg-max of
            ``y_pred`` or 2-D with a trailing singleton axis, e.g. ``(N, 1)``.

    Returns:
        0-d NumPy array holding the mean accuracy.
    """
    class_pred = torch.argmax(y_pred, -1).float()
    # Squeeze 2-D column labels so they line up with the predictions, but
    # leave them alone when the shapes already agree: squeezing (B, 1) labels
    # against (B, 1) predictions would broadcast into a (B, B) comparison.
    if len(y_true.shape) == 2 and y_true.shape != class_pred.shape:
        y_true = y_true.squeeze(-1)
    accuracy = torch.mean((class_pred == y_true.float()) * 1.0)
    return accuracy.cpu().data.numpy()


def rapid_learning_reg(theta_grads, decay=0.7):
    """Return the negated, geometrically weighted mean-square of three gradients.

    Only ``theta_grads[0]``, ``theta_grads[1]`` and ``theta_grads[2]`` are
    used; they are weighted by ``1``, ``decay`` and ``decay ** 2``.

    Args:
        theta_grads: Indexable collection with at least three tensors.
        decay: Per-entry weight decay factor. Defaults to ``0.7``, the value
            that was previously hard-coded.

    Returns:
        0-d tensor equal to minus the weighted sum of mean squared entries.
    """
    first_layer, second_layer, third_layer = theta_grads[0], theta_grads[1], theta_grads[2]
    weighted_avg = (
        torch.mean(first_layer ** 2)
        + decay * torch.mean(second_layer ** 2)
        + decay ** 2 * torch.mean(third_layer ** 2)
    )
    # Negated, so minimising the returned value maximises the gradient energy.
    return -weighted_avg


def flatten_v_torch(self, v):
    """Gather every element of the tensors in ``v`` into one 1-D tensor.

    ``self`` is accepted but never read. Each tensor is moved to the CPU,
    converted to NumPy and flattened; the collected scalars are handed to
    ``torch.Tensor``.

    Args:
        self: Unused.
        v: Iterable of torch tensors.

    Returns:
        1-D ``torch.Tensor`` with the elements in iteration order.
    """
    values = [
        scalar
        for tensor in v
        for scalar in tensor.cpu().data.numpy().reshape(-1)
    ]
    return torch.Tensor(values)



def recompute_eigen_vectors(big_batch_size, learner, dloader, device):
    """Refill the learner's gradient memory from a big batch and recompute eigenvectors.

    For every item of the sampled big batch, the support inputs are run
    through ``learner``, the loss from ``learner.criterion`` is passed to
    ``learner.adapt_maml(..., return_grads=True)``, and the second value it
    returns is handed to ``learner.add_to_memory``. Once all items are
    processed, ``learner.compute_eigen_vectors()`` is called.

    Args:
        big_batch_size: Requested number of items; capped at
            ``len(dloader.dataset)``.
        learner: Callable model object providing ``criterion``,
            ``adapt_maml``, ``initialize_gradient_memory``, ``add_to_memory``
            and ``compute_eigen_vectors``.
        dloader: Object whose ``dataset`` supports ``len`` and
            ``sample_big_batch``.
        device: Device the support tensors are moved to.
    """
    big_batch_size = min(big_batch_size, len(dloader.dataset))

    big_batch = dloader.dataset.sample_big_batch(big_batch_size)
    learner.initialize_gradient_memory(big_batch_size, device)

    # Only the support set is used here, so the query tensors are neither
    # named nor copied to the device. The progress bar is kept disabled.
    for big_x_s_batch, big_y_s_batch, _, _, _ in tqdm(big_batch, total=big_batch_size, disable=True):
        x_s = big_x_s_batch.to(device)
        y_s = big_y_s_batch.to(device)

        predicted_y1 = learner(x_s)
        l1 = learner.criterion(predicted_y1, y_s)
        _, theta_grad_s = learner.adapt_maml(l1, return_grads=True)  # Return the gradients
        learner.add_to_memory(theta_grad_s)  # Do this for a burn in period

    learner.compute_eigen_vectors()


def reshape_time_series_x(x, traj_length, image_encoder, linear_layer):
    """Encode image trajectories and append their action values.

    ``x`` is indexed as a 5-D tensor ``(sequence, step, channel, height,
    width)``: channels ``0..2`` are treated as the image, and every channel
    from 3 onwards is read at pixel ``(0, 0)`` as an action value.

    Args:
        x: 5-D input tensor as described above.
        traj_length: Number of steps per sequence; must equal ``x.shape[1]``
            for the reshapes to succeed.
        image_encoder: Callable applied to the images shaped
            ``(len(x) * traj_length, 3, height, width)``.
        linear_layer: Callable applied to the per-sequence concatenation of
            the encoded steps.

    Returns:
        Tensor formed by concatenating ``linear_layer``'s output with the
        flattened actions along the last axis.
    """
    len_seq = len(x)
    # Read height and width separately so non-square images also work.
    img_height, img_width = x.shape[-2], x.shape[-1]
    x_imgs = x[:, :, :3, :, :].reshape(len_seq * traj_length, 3, img_height, img_width)

    # The action slice is strided; ``view`` rejects it once there is more
    # than one action channel, so use ``reshape``.
    actions = x[:, :, 3:, 0, 0].reshape(len_seq, -1)  # len_seq x (traj_length * action_dim)

    x_encoded = image_encoder(x_imgs)
    x_encoded_dim = x_encoded.shape[-1]
    x_encoded = linear_layer(x_encoded.reshape(len_seq, traj_length * x_encoded_dim))

    x_cat_action = torch.cat((x_encoded, actions), -1)
    return x_cat_action

def reshape_time_series_input(x, y, traj_length, image_encoder, linear_layer):
    """Encode a trajectory batch and its targets.

    ``x`` is processed by ``reshape_time_series_x``; ``y`` is passed through
    ``image_encoder`` afterwards.

    Returns:
        ``(encoded_x, encoded_y)``.
    """
    # Inputs are encoded first, then the targets.
    encoded_inputs = reshape_time_series_x(x, traj_length, image_encoder, linear_layer)
    encoded_targets = image_encoder(y)
    return encoded_inputs, encoded_targets


def support_to_img(support_x, support_y, img_size, device):
    """Paint support values onto a blank single-channel square image.

    Starting from zeros, the value ``support_y[i]`` is written at the pixel
    given by the first two entries of ``support_x[i]`` (cast to integer
    indices). Points are written in order, so a later point overwrites an
    earlier one at the same pixel.

    Args:
        support_x: Iterable of coordinate tensors with at least two entries.
        support_y: Indexable collection with one value per coordinate.
        img_size: Side length of the image.
        device: Device the canvas is moved to.

    Returns:
        Float tensor of shape ``(1, 1, img_size, img_size)``.
    """
    canvas = torch.zeros((img_size, img_size, 1)).to(device)

    for idx, point in enumerate(support_x):
        row_col = point.long()
        canvas[row_col[0], row_col[1], :] = support_y[idx]

    # Reorder (H, W, C) into (C, H, W) and add a leading batch axis.
    channels_first = canvas.float().permute(2, 0, 1)
    return channels_first.unsqueeze(0)


def create_hyper_network(num_hypernet_layers, hidden_encoder, num_output_params, hidden=128):
    """Build the hyper-network MLP as an ``nn.Sequential``.

    With ``num_hypernet_layers == 1`` the network is a single ``nn.Linear``.
    Otherwise it is an input projection to ``hidden`` units, followed by
    ``num_hypernet_layers - 1`` hidden ``Linear`` blocks and an output
    ``Linear``, with a ``ReLU`` after every layer except the last.

    NOTE(review): for ``n != 1`` this yields ``n + 1`` ``Linear`` layers, not
    ``n``. This is kept as is so existing architectures do not change.

    Args:
        num_hypernet_layers: Depth selector as described above.
        hidden_encoder: Input feature size.
        num_output_params: Output feature size.
        hidden: Width of the hidden layers. Defaults to ``128``, the value
            that was previously hard-coded.

    Returns:
        The assembled ``nn.Sequential``.
    """
    if num_hypernet_layers == 1:
        hypernet_modules = [nn.Linear(hidden_encoder, num_output_params)]
    else:
        hypernet_modules = [nn.Linear(hidden_encoder, hidden), nn.ReLU()]
        for _ in range(num_hypernet_layers - 1):
            hypernet_modules += [nn.Linear(hidden, hidden), nn.ReLU()]
        hypernet_modules.append(nn.Linear(hidden, num_output_params))

    return nn.Sequential(*hypernet_modules)