import pickle

import hydra
import numpy as np
import torch
from omegaconf import DictConfig, OmegaConf

from utils.utils import evaluate, get_args


def _split_into_trajectories(transitions, max_len=50):
    """Cut flat per-step arrays into zero-padded, fixed-length trajectories.

    A trajectory ends at every step whose ``dones`` flag is set; the flagged
    step is the last step of that trajectory. Steps after the last set flag
    form an unfinished trajectory and are dropped.

    Args:
        transitions: dict with the keys ``states``, ``actions``,
            ``next_states``, ``dones`` and ``rewards``. Every value is an
            array indexed by step along axis 0. ``actions``, ``dones`` and
            ``rewards`` are expected to hold one scalar per step.
        max_len: number of step slots reserved per trajectory.

    Returns:
        dict with the five input keys, each zero-padded to
        ``(num_traj, max_len, ...)`` with the input dtype, plus ``lengths``,
        an int array holding the real number of steps of each trajectory.

    Raises:
        ValueError: if any trajectory has more than ``max_len`` steps.
    """
    # Exclusive end index of every trajectory: one past each terminal step.
    # Scanning from index 0 matters: a terminal flag on the very first step
    # is a valid one-step trajectory.
    ends = np.flatnonzero(transitions['dones']) + 1
    boundaries = np.concatenate(([0], ends))
    starts = boundaries[:-1]
    lengths = np.diff(boundaries).astype(int)

    if lengths.size and lengths.max() > max_len:
        raise ValueError(
            f"trajectory {int(lengths.argmax())} has {int(lengths.max())} steps, "
            f"which exceeds max_len={max_len}"
        )

    # Only states/next_states keep their per-step shape; the other arrays are
    # stored as one scalar per step, i.e. (num_traj, max_len).
    per_step_shape = {
        'states': transitions['states'].shape[1:],
        'actions': (),
        'next_states': transitions['next_states'].shape[1:],
        'dones': (),
        'rewards': (),
    }

    trajectories = {}
    for key, step_shape in per_step_shape.items():
        flat = transitions[key]
        padded = np.zeros((len(ends), max_len, *step_shape), dtype=flat.dtype)
        for row, (start, end) in enumerate(zip(starts, ends)):
            padded[row, :end - start] = flat[start:end]
        trajectories[key] = padded
    trajectories['lengths'] = lengths
    return trajectories


@hydra.main(config_path="conf", config_name="config_trans_pt2pkl")
def main(args: DictConfig):
    """Convert a saved dict of transition tensors into padded trajectories.

    Loads the dict stored at ``args.load_path`` with ``torch.load``, renames
    ``obs`` / ``next_obs`` / ``done`` / ``ep_found_goal`` to ``states`` /
    ``next_states`` / ``dones`` / ``rewards``, splits the flat transition
    arrays into trajectories at every done flag and writes the result to
    ``args.save_path`` with ``np.save``.

    Args:
        args: Hydra config providing ``load_path`` and ``save_path``.

    Raises:
        ValueError: if ``load_path`` or ``save_path`` is unset, or if a
            trajectory is longer than the padded length.
        KeyError: if the loaded dict lacks one of the expected keys.
    """
    load_path = args.load_path
    save_path = args.save_path
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if load_path is None or save_path is None:
        raise ValueError("both load_path and save_path must be set in the config")

    # NOTE(security): torch.load unpickles the file; only use it on trusted files.
    model = torch.load(load_path, map_location=torch.device('cpu'))

    # The loaded object is assumed to be a flat dict of tensors.
    # detach() lets tensors that still require grad be converted as well.
    weights = {k: v.detach().numpy() for k, v in model.items()}

    transitions = {
        'states': weights['obs'],
        'next_states': weights['next_obs'],
        # squeeze() drops singleton axes; atleast_1d keeps the step axis when
        # the file holds a single transition.
        'actions': np.atleast_1d(weights['actions'].squeeze()),
        'dones': weights['done'].astype(bool),
        'rewards': weights['ep_found_goal'],
    }

    # States keep their stored layout. A channel swap from (H, W, C) to
    # (C, H, W) is intentionally left disabled:
    # transitions['states'] = transitions['states'].transpose(0, 3, 1, 2)
    # transitions['next_states'] = transitions['next_states'].transpose(0, 3, 1, 2)

    max_len = 50
    new_weights = _split_into_trajectories(transitions, max_len=max_len)

    # np.save pickles the dict into a 0-d object array (and appends ".npy" to
    # a string path that lacks it); read it back with
    # np.load(path, allow_pickle=True).item().
    np.save(save_path, new_weights)


if __name__ == "__main__":
    # Called without arguments: main is wrapped by @hydra.main, which is
    # expected to build the config and pass it in as `args`.
    main()