# Source code for torch_skeleton.datasets.base_dataset

import os
import os.path as osp
import shutil
import tempfile
from typing import Callable, Optional

import torch
from torch.utils.data import Dataset

import torch_skeleton.utils as skel_utils


class DiskCache(Dataset):
    """Cache ``Dataset`` instance to disk.

    Caches output of dataset to disk by creating a temporary directory at
    root. On first access, sample ``index`` is fetched from ``dataset`` and
    written to ``<temp dir>/<index>.pt`` with ``torch.save``; later accesses
    read it back with ``torch.load``. ``transform`` is applied to ``x`` after
    loading, so it runs on every access and its output is never cached.
    The temporary directory is removed when this object is deleted.

    Args:
        dataset (``Dataset``): dataset to cache; indexing it must yield an
            ``(x, y)`` pair
        root (str): directory in which the temporary cache directory is
            created
        transform (callable, optional): applied to ``x`` on every access
    """

    def __init__(
        self,
        dataset: Dataset,
        root: str = ".",
        transform: Optional[Callable] = None,
    ):
        super().__init__()

        skel_utils.makedirs(root)
        # TemporaryDirectory creates a fresh, empty, uniquely named directory,
        # so it does not need to be cleared or re-created afterwards.
        self.temp_dir = tempfile.TemporaryDirectory(dir=root)
        self.root = self.temp_dir.name

        self.transform = transform
        self.dataset = dataset

    def cache_path(self, index):
        """Return the path of the cache file for sample ``index``."""
        return osp.join(self.root, f"{index}.pt")

    def _save(self, sample, path):
        """Write ``sample`` to ``path`` without ever leaving a partial file there."""
        # Write to a unique temporary file in the cache directory, then rename
        # it into place. If torch.save fails (e.g. unpicklable sample, full
        # disk), nothing is left at ``path``. A later access then refetches
        # from ``dataset`` instead of loading a truncated file.
        fd, tmp_path = tempfile.mkstemp(suffix=".tmp", dir=self.root)
        try:
            with os.fdopen(fd, "wb") as f:
                torch.save(sample, f)
            os.replace(tmp_path, path)
        except BaseException:
            if osp.exists(tmp_path):
                os.remove(tmp_path)
            raise

    @staticmethod
    def _load(path):
        """Load a sample previously written by ``_save``."""
        # The files were written by this object into its own temporary
        # directory, so unpickling them is trusted. ``weights_only=False``
        # keeps non-tensor samples (e.g. NumPy arrays) loadable on torch
        # versions where ``weights_only`` defaults to True.
        try:
            return torch.load(path, weights_only=False)
        except TypeError:
            # Older torch releases do not accept ``weights_only``.
            return torch.load(path)

    def __getitem__(self, index):
        path = self.cache_path(index)
        if osp.exists(path):
            x, y = self._load(path)
        else:
            x, y = self.dataset[index]
            self._save([x, y], path)

        # Applied after caching so the cache holds untransformed samples.
        if self.transform is not None:
            x = self.transform(x)

        return x, y

    def __len__(self):
        return len(self.dataset)

    def __del__(self):
        # ``temp_dir`` is absent if ``__init__`` raised before creating it.
        temp_dir = getattr(self, "temp_dir", None)
        if temp_dir is not None:
            temp_dir.cleanup()
class Apply(Dataset):
    """Apply ``Transform`` to ``Dataset`` instance.

    Indexing returns ``(transform(x), y)`` where ``(x, y) = dataset[index]``;
    the second element ``y`` is passed through untouched. The length is that
    of the wrapped dataset.

    Args:
        dataset (``Dataset``): dataset to apply transform to
        transform (``Transform``): transform to apply
    """

    def __init__(self, dataset: Dataset, transform: Callable):
        super().__init__()
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        sample, target = self.dataset[index]
        return self.transform(sample), target