Source code for torch_skeleton.datasets.nw_ucla

import os.path as osp

import json
import numpy as np

from torch.utils.data import Dataset

import torch_skeleton.utils as skel_utils

from typing import Callable, Optional


[docs]class UCLA(Dataset): """`NW-UCLA <http://wangjiangb.github.io/my_data.html>`_ Dataset. Args: root (str): root directory of dataset transform (``Transform``): transform to apply to dataset """ def __init__( self, root=".", transform: Optional[Callable] = None, ): super().__init__() self.root = osp.join(root, "NW-UCLA") self.transform = transform path = osp.join(self.root, "all_sqe.zip") if not skel_utils.downloaded(path): skel_utils.download_url( "https://www.dropbox.com/s/10pcm4pksjy6mkq/all_sqe.zip?dl=1", path ) skel_utils.extract_zip(path, self.root) self.file_paths = skel_utils.listdir(self.root, ext="json") def __getitem__(self, index): path = self.file_paths[index] with open(path) as f: data = json.load(f) x = np.array(data["skeletons"]).astype(float) x = np.expand_dims(x, axis=0) y = data["label"] if self.transform is not None: x = self.transform(x) return x, y def __len__(self): return len(self.file_paths)