"""Source code for ``torch_skeleton.datasets.ucla`` (NW-UCLA dataset)."""

import os.path as osp

import json
import numpy as np

from torch.utils.data import Dataset

import torch_skeleton.utils as skel_utils

from typing import Callable, Optional


class UCLA(Dataset):
    """`NW-UCLA <http://wangjiangb.github.io/my_data.html>`_ Dataset.

    On construction, the archive ``all_sqe.zip`` is fetched into
    ``<root>/NW-UCLA`` and extracted there when ``skel_utils.downloaded``
    reports it as missing. The JSON sample files found under that directory
    are then filtered by camera into the requested split.

    Args:
        root (str): root directory of dataset
        split (str): split type, either ``"train"`` or ``"val"``. Any value
            other than ``"train"`` selects the validation split.
        transform (``Transform``): transform to apply to dataset
    """

    def __init__(
        self,
        root: str = ".",
        split: str = "train",
        transform: Optional[Callable] = None,
    ):
        super().__init__()

        self.root = osp.join(root, "NW-UCLA")
        self.transform = transform

        path = osp.join(self.root, "all_sqe.zip")
        # Fetch and unpack the archive only when it is not already present.
        if not skel_utils.downloaded(path):
            skel_utils.download_url(
                "https://www.dropbox.com/s/10pcm4pksjy6mkq/all_sqe.zip?dl=1", path
            )
            skel_utils.extract_zip(path, self.root)

        paths = skel_utils.listdir(self.root, ext="json")
        self.file_paths = filter_split(paths, split)

    def __getitem__(self, index):
        """Load the sample stored at ``self.file_paths[index]``.

        Returns:
            tuple: ``(x, y)``. ``x`` is the float array built from the
            sample's ``"skeletons"`` entry, with a new leading axis of
            length 1 added and ``transform`` applied if one was given.
            ``y`` is ``int(label) - 1`` for the sample's ``"label"`` entry.
        """
        path = self.file_paths[index]
        # JSON text is UTF-8 by specification; do not depend on the
        # platform's locale-dependent default encoding.
        with open(path, encoding="utf-8") as f:
            data = json.load(f)

        # Convert straight to float instead of building an array and copying it.
        x = np.array(data["skeletons"], dtype=float)
        # presumably the new leading axis is a single-person dimension — TODO confirm
        x = np.expand_dims(x, axis=0)
        # Shift the stored label down by one (presumably 1-based in the files).
        y = int(data["label"]) - 1

        if self.transform is not None:
            x = self.transform(x)

        return x, y

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.file_paths)
def filter_split(paths, split, *, val_camera=3):
    """Select the sample paths that belong to ``split``.

    Samples recorded by camera ``val_camera`` form the validation split and
    samples from every other camera form the training split. With the default
    ``val_camera=3`` this is the first two cameras for training and the third
    camera for validation.

    Args:
        paths (Iterable[str]): sample file paths; see ``get_camera`` for the
            expected file-name pattern.
        split (str): ``"train"`` selects the training split; any other value
            selects the validation split.
        val_camera (int): camera id held out for validation. Defaults to 3.

    Returns:
        list[str]: the selected paths, in their original order.
    """
    want_train = split == "train"
    # A path is kept when its train/val membership matches the requested split.
    return [
        path for path in paths if (get_camera(path) != val_camera) == want_train
    ]


def get_camera(path):
    """Return the camera id encoded in a sample's file name.

    The base name is cut at its first ``.``, the remainder is split on ``_``,
    and the fourth field without its first character is parsed as an int,
    e.g. ``"a01_s01_e00_v03.json"`` -> ``3``.

    Raises:
        IndexError: if the name has fewer than four ``_``-separated fields.
        ValueError: if that field does not end in an integer.
    """
    stem = osp.basename(path).split(".")[0]
    return int(stem.split("_")[3][1:])