import torch
import matplotlib.pyplot as plt
import numpy as np
from IPython import embed
import random
import os
import pickle
import argparse

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


class TrajDataset(torch.utils.data.Dataset):
    """In-memory dataset of pickled trajectories served as dicts of tensors.

    Every trajectory loaded from ``path`` is indexed with the keys
    ``'rollin_xs'``, ``'rollin_us'``, ``'rollin_xps'``, ``'rollin_rs'``,
    ``'state'`` and ``'action'``. All trajectories are stacked into tensors
    once, at construction time, and moved to the module-level ``device``.

    Args:
        path: Path to a pickle file, or a list of such paths whose contents
            are concatenated in order.
        config: Mapping that must contain ``'shuffle'``. When it is truthy,
            ``__getitem__`` applies one fresh random permutation to the first
            axis of the ``rollin_*`` tensors of the returned sample.

    Attributes:
        trajs: The raw trajectories as loaded from disk.
        H: Size of axis 1 of the stacked ``rollin_xs`` array.
        ds: Dict of stacked tensors, indexed per sample by ``__getitem__``.
    """

    # Entries whose first per-sample axis is permuted when shuffling is on.
    _ROLLIN_KEYS = ('rollin_xs', 'rollin_us', 'rollin_xps', 'rollin_rs')

    @staticmethod
    def _load_pickle(filepath):
        """Unpickle and return the content of ``filepath``.

        SECURITY: ``pickle.load`` can execute arbitrary code, so only point
        this at dataset files from a trusted source.
        """
        # The context manager closes the file even if unpickling raises.
        with open(filepath, 'rb') as file:
            return pickle.load(file)

    def __init__(self, path, config):
        self.config = config
        self.filepath = path

        if isinstance(self.filepath, list):
            self.trajs = []
            for filepath in self.filepath:
                self.trajs += self._load_pickle(filepath)
        else:
            self.trajs = self._load_pickle(self.filepath)

        rollin_xs = []
        rollin_us = []
        rollin_xps = []
        rollin_rs = []
        states, actions = [], []
        # NOTE: per-trajectory 'Q' values were loaded here at some point; that
        # code path is currently disabled.

        print("Running through trajs...")
        for i, traj in enumerate(self.trajs):
            if i % 500 == 0:
                print(i)

            # Handle PPO data separately: when the actions and the states
            # disagree on their first dimension, the last action and the last
            # (squeezed) reward are dropped.
            # NOTE(review): presumably PPO rollouts carry exactly one extra
            # action/reward entry -- confirm against the data generator.
            if traj['rollin_xs'].shape[0] != traj['rollin_us'].shape[0]:
                rollin_xs.append(traj['rollin_xs'])
                rollin_us.append(traj['rollin_us'][:-1])
                rollin_xps.append(traj['rollin_xps'])
                rollin_rs.append(traj['rollin_rs'].squeeze()[:-1])
            else:
                rollin_xs.append(traj['rollin_xs'])
                rollin_us.append(traj['rollin_us'])
                rollin_xps.append(traj['rollin_xps'])
                rollin_rs.append(traj['rollin_rs'])

            states.append(traj['state'])
            actions.append(traj['action'])

        rollin_xs = np.array(rollin_xs)
        rollin_us = np.array(rollin_us)
        rollin_xps = np.array(rollin_xps)
        rollin_rs = np.array(rollin_rs)
        # Give the rewards a trailing axis so that every rollin_* array has
        # three dimensions.
        if len(rollin_rs.shape) < 3:
            rollin_rs = rollin_rs[:, :, None]
        print('Shape of rollin_xs: ', rollin_xs.shape)

        states = np.array(states)
        actions = np.array(actions)

        dx = rollin_xs.shape[-1]
        du = rollin_us.shape[-1]

        self.H = rollin_xs.shape[1]

        self.ds = {
            'states': torch.tensor(states).float().to(device),
            'actions': torch.tensor(actions).float().to(device),
            'rollin_xs': torch.tensor(rollin_xs).float().to(device),
            'rollin_us': torch.tensor(rollin_us).float().to(device),
            'rollin_xps': torch.tensor(rollin_xps).float().to(device),
            'rollin_rs': torch.tensor(rollin_rs).float().to(device),
            # All-zero tensors handed out with every sample.
            'zeros': torch.zeros(len(states), dx**2 + du + 1).float().to(device),
            'zerosQ': torch.zeros(len(states), self.H, dx**2).float().to(device),
        }

    def __len__(self):
        """Return the total number of samples (one per trajectory)."""
        return len(self.ds['states'])

    def __getitem__(self, i):
        """Return sample ``i`` as a dict of tensors.

        The dict has the same keys as ``self.ds``. When ``config['shuffle']``
        is truthy, a single random permutation of length ``self.H`` is drawn
        and applied to the first axis of all four ``rollin_*`` entries, so
        they stay aligned with each other.
        """
        shuffle = self.config['shuffle']
        res = {key: tensor[i] for key, tensor in self.ds.items()}

        if shuffle:
            permutation = torch.randperm(self.H)
            for key in self._ROLLIN_KEYS:
                res[key] = res[key][permutation, :]

        return res


if __name__ == '__main__':
    # Manual smoke test: load the training split, drop into an IPython shell
    # for inspection, then fetch the first sample once the shell exits.
    n_envs, n_hists, n_samples = 1000, 1, 1
    H, dim = 10, 4

    path_train = f'datasets/trajs_envs{n_envs}_hists{n_hists}_samples{n_samples}_H{H}_d{dim}_train.pkl'
    path_test = f'datasets/trajs_envs{n_envs}_hists{n_hists}_samples{n_samples}_H{H}_d{dim}_test.pkl'

    # Shuffling is enabled so that indexing goes through the permuted path.
    config = {'shuffle': True}
    ds = TrajDataset(path_train, config)

    embed()
    ds[0]

