如何在pytorch中使用大型数据集

时间 2019-02-19
阅读 96
点赞 0
收藏 0
连接guillaumea

我有一个巨大的数据集,不适合内存(150g),我正在寻找在pytorch中使用它的最佳方法。数据集由多个.npz每个10万个样本的文件。我试图建立一个Dataset

class MyDataset(Dataset):
    def __init__(self, path):
        self.path = path
        self.files = os.listdir(self.path)
        self.file_length = {}
        for f in self.files:
            # Load file in as a nmap
            d = np.load(os.path.join(self.path, f), mmap_mode='r')
            self.file_length[f] = len(d['y'])

    def __len__(self):
        raise NotImplementedException()

    def __getitem__(self, idx):                
        # Find the file where idx belongs to
        count = 0
        f_key = ''
        local_idx = 0
        for k in self.file_length:
            if count < idx < count + self.file_length[k]:
                f_key = k
                local_idx = idx - count
                break
            else:
                count += self.file_length[k]
        # Open file as numpy.memmap
        d = np.load(os.path.join(self.path, f_key), mmap_mode='r')
        # Actually fetch the data
        X = np.expand_dims(d['X'][local_idx], axis=1)
        y = np.expand_dims((d['y'][local_idx] == 2).astype(np.float32), axis=1)
        return X, y

但是当一个样本被实际提取时,它需要30多秒,看起来像整个.npz在RAM中打开、存储并访问正确的索引。 如何提高效率?

编辑

这似乎是对.npz文件夹see post但是有更好的方法吗?

解决方案建议

正如@covariantmonkey所建议的,lmdb是一个不错的选择。现在,因为问题来自.npz文件而不是memmap,我通过拆分来重新构建数据集.npz将文件打包为多个.npy文件夹。我现在可以用同样的逻辑memmap完全有意义,而且速度非常快(加载一个样本需要几毫秒)。

建议答案

个人有多大.npz文件夹?一个月前我也处于同样的困境。各种各样forum帖子,谷歌搜索后来我去了lmdb路线。这是我做的

  1. 把大数据集分成足够小的文件,我可以放在GPU中——每个文件基本上都是我的小批量文件。我在这个阶段没有优化加载时间只是记忆。
  2. 使用创建LMDB索引key = filenamedata = np.savez_compressed(stff)

lmdb帮你保管好MMAP,而且装得太快了。

当做,

PS:savez_compessed需要一个byte对象,以便您可以执行类似的操作

output = io.BytesIO()
np.savez_compressed(output, x=your_np_data)
#cache output in lmdb
👍 0