Source code for torch_skeleton.transforms.transforms

from . import functional as F

__all__ = [
    "Compose",
    "PadBodies",
    "PadFrames",
    "SampleFrames",
    "SelectKBodies",
    "SplitFrames",
    "CenterJoint",
    "ParallelBone",
    "RandomShift",
    "RandomRotate",
    "SortByMotion",
    "DenoiseByLength",
    "DenoiseBySpread",
    "DenoiseByMotion",
    "MergeBodies",
    "RemoveZeroFrames",
    "Denoise",
]


[docs]class Compose: """Composes several transforms together. This transform does not support torchscript. Please, see the note below. Args: transforms (list of ``Transform`` objects): list of transforms to compose. Example: >>> transforms.Compose([ >>> transforms.CenterCrop(10), >>> transforms.PILToTensor(), >>> transforms.ConvertImageDtype(torch.float), >>> ]) .. note:: In order to script the transformations, please use ``torch.nn.Sequential`` as below. >>> transforms = torch.nn.Sequential( >>> transforms.CenterCrop(10), >>> transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), >>> ) >>> scripted_transforms = torch.jit.script(transforms) Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require `lambda` functions or ``PIL.Image``. """ def __init__(self, transforms): self.transforms = transforms def __call__(self, x): for t in self.transforms: x = t(x) return x
[docs]class PadBodies: """Pad each frame with zeros to match number of bodies. Args: max_bodies (int): number of bodies to pad to """ def __init__(self, max_bodies, **kwargs): self.max_bodies = max_bodies self.pad_kwargs = kwargs def __call__(self, x): return F.pad_bodies(x, max_bodies=self.max_bodies, **self.pad_kwargs)
[docs]class PadFrames: """Pad sequence with zeros to match number of frames. Args: max_frames (int): number of frames to pad to axis (int): time dimension axis, default is 1 """ def __init__(self, max_frames, axis=1, **kwargs): self.max_frames = max_frames self.pad_kwargs = kwargs self.axis = axis def __call__(self, x): return F.pad_frames( x, max_frames=self.max_frames, axis=self.axis, **self.pad_kwargs )
[docs]class SampleFrames: """Uniformly sample frames from sequence. Args: num_frames (int): number of frames to sample num_clips (int): number of clips to sample axis (int): axis to stack sampled clips """ def __init__(self, num_frames, num_clips=1, axis=0) -> None: self.num_frames = num_frames self.num_clips = num_clips self.axis = axis def __call__(self, x): return F.sample_frames( x, num_frames=self.num_frames, num_clips=self.num_clips, axis=self.axis )
[docs]class SelectKBodies: """Select K bodies from sequeqnce based on motion. Args: k (int): number of bodies to sort """ def __init__(self, k): self.k = k def __call__(self, x): return F.select_k_bodies(x, k=self.k)
[docs]class SplitFrames: """Split frame into N frames each containing one skeleton.""" def __call__(self, x): return F.split_frames(x)
[docs]class CenterJoint: """Center skeleton on joint. Args: joint_id (int): joint id to center skeleton all (bool): set to False to center along initial frame """ def __init__(self, joint_id=1, all=False): self.joint_id = joint_id self.all = all def __call__(self, x): return F.sub_center_joint(x, joint_id=self.joint_id, all=self.all)
[docs]class ParallelBone: """Transform skeleton so two joints are parallel to a certain axis. Args: first_id (int): first joint id second_id (int): second joint id axis (int): axis to be parallel to 0, 1, 2, corresponds to x, y, z """ def __init__(self, first_id, second_id, axis=2) -> None: super().__init__() self.first_id = first_id self.second_id = second_id self.axis = axis def __call__(self, x): return F.parallel_bone( x, first_id=self.first_id, second_id=self.second_id, axis=self.axis )
[docs]class RandomShift: """Randomly shift skeleton sequences. Args: low (int): minimum shift high (int): maximum shift """ def __init__(self, low, high): self.low = low self.high = high def __call__(self, x): return F.random_shift(x, low=self.low, high=self.high)
[docs]class RandomRotate: """Randomly rotate skeleton sequences. Args: degrees (int): range of shift in degrees Example: # randomly select angles from -30 ~ 30 degrees and rotate >>> transforms.RandomRotate(30) """ def __init__(self, degrees): self.degrees = degrees def __call__(self, x): return F.random_rotate(x, self.degrees)
[docs]class SortByMotion: """Sort skeletons in a single frame by motion.""" def __call__(self, x): return F.select_k_bodies(x, k=x.shape[0])
[docs]class DenoiseByLength: """Denoise bodies by length of non zero frames. Removes bodies whose length is under a minimum length """ def __call__(self, x): return F.denoising_by_length(x)
[docs]class DenoiseBySpread: """Denoise bodies by length of frames under a certain x, y spread. Removes bodies where noisy frames exceed a certain ratio. Frames with higher x, y spread than the threshold are considered noisy. """ def __call__(self, x): return F.denoising_by_spread(x)
[docs]class DenoiseByMotion: """Denoise bodies by filtering frames within a certain range of motion. Valid frames by spread are selected to compute the motion of a body. Bodies whose motion is outside a certain low, high range is removed. """ def __call__(self, x): return F.denoising_by_motion(x)
[docs]class MergeBodies: """Merges different bodies who don't overlap into 2 actors. Bodies are expected to be sorted by motion. First body is selected as the main actor as it has the highest motion. subsequent bodies are compared with the main actor. If there aren't overlapping frames, the body is merged with the main actor. If there are overlapping frames, try to merge with the second actor. If there are overlapping frames with the second actor, the body is removed. """ def __call__(self, x): if x.shape[0] > 1: x = F.merge_bodies(x) x = F.remove_zero_frames(x) return x
[docs]class RemoveZeroFrames: """Remove frames where all bodies are zero from the sequence.""" def __call__(self, x): return F.remove_zero_frames(x)
[docs]class Denoise: """Denoise skeleton sequence used on NTU.""" def __call__(self, x): if x.shape[0] > 1: x = F.select_k_bodies(x, k=x.shape[0]) x = F.denoising_by_length(x) if x.shape[0] > 1: x = F.denoising_by_spread(x) if x.shape[0] > 1: x = F.denoising_by_motion(x) if x.shape[0] > 1: x = F.merge_bodies(x) x = F.remove_zero_frames(x) return x