import torch
import numpy as np
import matplotlib.pyplot as plt
import os
import pickle
import ipdb



class AntDataset(torch.utils.data.Dataset):
    """Pair fixed-policy support rollouts with PPO-policy query rollouts.

    Two pickles are loaded from ``data_path/<subdir>/``: one with fixed-policy
    data and one with PPO-policy data. They are indexed in lock-step: item ``i``
    of one is always returned together with item ``i`` of the other. Presumably
    both describe the same task (TODO confirm against the data generator).

    Args:
        dataset: ``'ant'`` or ``'ant-8'``; selects the data sub-directory.
        data_path: Root directory containing the per-dataset sub-directories.
        dataset_size: Maximum number of tasks to keep after a deterministic
            (seed 42) shuffle of all task indices.
        noise_std: Scale of the standard-normal noise added to ``support_y`` on
            every ``__getitem__`` call (0 disables the noise).
        test: Load the test split (file size tag 1000) instead of the train
            split (file size tag 5000).
        imitation: If True, split the PPO ``x`` array column-wise into query
            inputs (columns before ``_IMITATION_SPLIT``) and query targets (the
            remaining columns), ignoring the PPO ``y`` array.

    Raises:
        ValueError: If ``dataset`` is not one of the known dataset names.
    """

    # Sub-directory of ``data_path`` holding each dataset's pickles.
    _SUBDIRS = {
        'ant': 'ant_ppo',
        'ant-8': 'ant_ppo_8_leg',
    }
    # Column at which ``query_policy_x`` is split in imitation mode. Presumably
    # the observation dimension of the 4-leg Ant env; TODO confirm for 'ant-8'.
    _IMITATION_SPLIT = 27

    def __init__(self, dataset, data_path, dataset_size, noise_std=0, test=False, imitation=True):
        super().__init__()
        self.imitation = imitation
        self.noise_std = noise_std
        self.dataset_size = dataset_size

        # Fail early and clearly instead of hitting an unbound path variable below.
        try:
            subdir = self._SUBDIRS[dataset]
        except KeyError:
            raise ValueError(
                f"unknown dataset {dataset!r}; expected one of {sorted(self._SUBDIRS)}"
            ) from None

        postfix = 'test' if test else 'train'
        # Number of tasks encoded in the pickle file names. The training split
        # was previously generated with 20000 tasks (`1000 if test else 20000`).
        size = 1000 if test else 5000

        fixed_path = os.path.join(data_path, subdir, f'fixed_{postfix}_30_{size}.pkl')
        ppo_path = os.path.join(data_path, subdir, f'ppo_{postfix}_30_{size}.pkl')

        self.fixed_policy_data = self._load_pickle(fixed_path)
        self.rl_policy_data = self._load_pickle(ppo_path)

        # Fixed seed so every run selects the same subset of tasks.
        indices = np.arange(len(self.rl_policy_data))
        rng = np.random.default_rng(42)
        rng.shuffle(indices)
        self.indices = indices[:self.dataset_size]

    @staticmethod
    def _load_pickle(path):
        """Load and return the object pickled at ``path``, printing the path."""
        # NOTE(security): pickle.load can execute arbitrary code; only load
        # files produced by a trusted data-generation pipeline.
        with open(path, 'rb') as f:
            data = pickle.load(f)
            print("load", path)
        return data

    def __len__(self):
        """Return the number of tasks kept after subsetting."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return ``(support_x, support_y, query_x, query_y, task_params)`` as float tensors.

        ``idx`` indexes the shuffled subset and is mapped through
        ``self.indices`` to a position in the loaded pickled sequences. Fresh
        noise is drawn for ``support_y`` on every call.
        """
        idx = self.indices[idx]
        support_x, support_y, _, _, task_params = self.fixed_policy_data[idx]
        query_policy_x, query_policy_y = self.rl_policy_data[idx]

        if self.imitation:
            split = self._IMITATION_SPLIT
            query_x = query_policy_x[:, :split]
            query_y = query_policy_x[:, split:]
        else:
            query_x = query_policy_x
            query_y = query_policy_y

        support_x = torch.from_numpy(support_x).float()
        # `support_y.shape` on the right-hand side is still the numpy array's shape.
        support_y = torch.from_numpy(support_y).float() + torch.randn(*support_y.shape) * self.noise_std

        query_x = torch.from_numpy(query_x).float()
        query_y = torch.from_numpy(query_y).float()

        task_params = torch.from_numpy(np.array(task_params)).float()

        return support_x, support_y, query_x, query_y, task_params




if __name__ == '__main__':
    # Smoke test: load the test split of the 4-leg Ant data from ./data and
    # fetch one item. Arguments follow AntDataset(dataset, data_path,
    # dataset_size, ...); 1000 matches the size of the test split.
    dset = AntDataset('ant', 'data', dataset_size=1000, test=True)
    support_x, support_y, query_x, query_y, task_params = dset[0]
    print(support_x.shape, support_y.shape, query_x.shape, query_y.shape, task_params.shape)