1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
def angle_axis_to_rotation_matrix(angle_axis):
    """Convert a batch of axis-angle vectors to 4x4 homogeneous rotation matrices.

    Each input row ``r`` encodes a rotation by the angle ``theta = ||r||``
    about the axis ``r / theta``. The 3x3 rotation block is computed with
    Rodrigues' formula. For very small angles (``theta**2 <= 1e-6``) the
    first-order approximation ``I + [r]_x`` is used instead.

    Args:
        angle_axis (Tensor): tensor of 3d vector of axis-angle rotations.

    Returns:
        Tensor: tensor of 4x4 rotation matrices with the same dtype and device
        as ``angle_axis``. The translation column is zero and the bottom row
        is ``[0, 0, 0, 1]``.

    Shape:
        - Input: :math:`(N, 3)` (``N`` may be 0)
        - Output: :math:`(N, 4, 4)`

    Example:
        >>> input = torch.rand(1, 3)  # Nx3
        >>> output = tgm.angle_axis_to_rotation_matrix(input)  # Nx4x4
    """
    eps = 1e-6
    # Use the explicit batch size rather than -1 in view(): -1 cannot be
    # inferred for zero-element tensors, which broke empty batches.
    batch_size = angle_axis.shape[0]

    def _compute_rotation_matrix(axis_angle, theta2):
        # Rodrigues: R = cos(t) I + sin(t) [w]_x + (1 - cos(t)) w w^T.
        # theta2 must be strictly positive here; the caller substitutes a
        # safe value for near-zero rows so sqrt/division never see zero.
        theta = torch.sqrt(theta2)
        wx, wy, wz = torch.chunk(axis_angle / theta, 3, dim=1)
        cos_theta = torch.cos(theta)
        sin_theta = torch.sin(theta)
        one_minus_cos = 1.0 - cos_theta
        r00 = cos_theta + wx * wx * one_minus_cos
        r10 = wz * sin_theta + wx * wy * one_minus_cos
        r20 = -wy * sin_theta + wx * wz * one_minus_cos
        r01 = wx * wy * one_minus_cos - wz * sin_theta
        r11 = cos_theta + wy * wy * one_minus_cos
        r21 = wx * sin_theta + wy * wz * one_minus_cos
        r02 = wy * sin_theta + wx * wz * one_minus_cos
        r12 = -wx * sin_theta + wy * wz * one_minus_cos
        r22 = cos_theta + wz * wz * one_minus_cos
        rotation_matrix = torch.cat(
            [r00, r01, r02, r10, r11, r12, r20, r21, r22], dim=1)
        return rotation_matrix.view(batch_size, 3, 3)

    def _compute_rotation_matrix_taylor(axis_angle):
        # First-order expansion R ~= I + [r]_x (skew-symmetric cross matrix).
        rx, ry, rz = torch.chunk(axis_angle, 3, dim=1)
        k_one = torch.ones_like(rx)
        rotation_matrix = torch.cat(
            [k_one, -rz, ry, rz, k_one, -rx, -ry, rx, k_one], dim=1)
        return rotation_matrix.view(batch_size, 3, 3)

    # Squared rotation angle per row, shape (N, 1).
    theta2 = (angle_axis * angle_axis).sum(dim=1, keepdim=True)
    mask = theta2 > eps

    # Feed a harmless value (1) into the Rodrigues branch for near-zero rows.
    # Evaluating sqrt(0) there yields an infinite local derivative that turns
    # into NaN gradients even though that branch is not selected.
    safe_theta2 = torch.where(mask, theta2, torch.ones_like(theta2))

    rotation_matrix_normal = _compute_rotation_matrix(angle_axis, safe_theta2)
    rotation_matrix_taylor = _compute_rotation_matrix_taylor(angle_axis)

    rotation_matrix = torch.eye(
        4, dtype=angle_axis.dtype, device=angle_axis.device)
    rotation_matrix = rotation_matrix.repeat(batch_size, 1, 1)
    # torch.where selects per row without arithmetic blending, so non-finite
    # values in the unselected branch cannot leak into the result.
    rotation_matrix[..., :3, :3] = torch.where(
        mask.view(batch_size, 1, 1),
        rotation_matrix_normal,
        rotation_matrix_taylor)
    return rotation_matrix
|