[pytorch] gather와 scatter
gather reference : https://pytorch.org/docs/stable/generated/torch.gather.html
torch.gather — PyTorch 2.0 documentation
Shortcuts
pytorch.org
scatter reference : https://pytorch.org/docs/stable/generated/torch.Tensor.scatter_.html#torch.Tensor.scatter_
torch.Tensor.scatter_ — PyTorch 2.0 documentation
Shortcuts
pytorch.org
gather는 input의 텐서에서 index에 따라 값들을 가져오는 텐서로 만드는 함수이다.
scatter는 src의 텐서에서 값들을 가져와 index에 따라 뿌려주는 함수이다.
즉, index를 target의 기준인지 src의 기준인지가 차이점이다.
scatter에는 reduce라는 parameter가 있는데, add나 mul연산이 가능하다. 그러나 scatter_add와 같은 함수로 구현되어 있고, 추후 릴리즈에서 scatter의 reduce 인자는 사용되지 않을 예정이므로 되도록 사용하지 말자.
gather 함수
torch.gather(input,dim,index,*,sparse_grad=False,out=None) → Tensor
torch.gather 혹은 input.gater로 사용 가능하다.dim을 기준으로 index텐서의 나머치 차원의 index는 그대로 맵핑된다.
gather의 동작원리는 dim으로 정해진 축을 기준으로 나머지 dimension의 인덱스는 고정된다.
즉, index.shape=(2,4,4), dim=1 이라면 output[m][j][k] input[m][index[m][j][l]][l]의 값을 가져온다.
이 결과, output.shape=(2,4,4)이 된다.
바로 이해하기가 어렵지만, 차근차근 접근해보자.
다시 풀이하자면, index에 들어가는 index값들은 dim으로 설정된 차원의 인덱스를 의미한다.
그러므로 input.shape=(m,n,l)이라면 index 텐서 내부의 값들은 0~n-1의 값만 가질 수 있다.
dim으로 설정되지 않은 차원은 index가 1:1 맵핑되고, dim으로 설정된 차원은 index텐서의 값에 따라 target dimension의 index를 가져온다.
좀 더 이해하기 위해 1~3차원으로 예시를 들어보자.
1차원 gather
src = torch.arange(12,0,-1) #tensor([12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1])
index_tensor = torch.tensor([2,4,6])
src.gather(0,index_tensor) #tensor([10, 8, 6])
index_tensor=[2,4,6]이다.
src에서 index가 2,4,6인 10,8,6을 가져오게 된다.
2차원 gather
2차원 src에서 대각 성분만 가져오도록 구현해보자.
src.shape=(3,4)라면, output.shape=(3,1) 또는 (1,3)이 되어야 한다.
dim=0일 때
[0,0],[1,1],[2,2]를 가져와야 한다.
그러므로 index_tensor.shape=(1,3)이 되어야 한다.
a=torch.arange(12).reshape(3,4)
#tensor([[ 0, 1, 2, 3],[ 4, 5, 6, 7],[ 8, 9, 10, 11]])
index_tensor = torch.tensor([[0,1,2]])
result = a.gather(0,index_tensor)
#tensor([[ 0, 5, 10]])
result = result.squeeze(0) #2차원으로 나온 결과를 1차원으로 압축
#tensor([ 0, 5, 10])
dim=1일 때
dim=0의 인덱스를 0,1,2로 접근해야 하므로, index_tensor.shape=(3,1)이 되어야 한다.
a=torch.arange(12).reshape(3,4) #위의 예제와 같은 src
index_tensor = torch.tensor([[0],[1],[2]])
result = a.gather(1,index_tensor) #index_tensor.shape=(3,1), src.shape=(3,4)
#tensor([[0],[5],[10]])
result.reshape(-1,) #이렇게 한줄로 펼칠수도 있다. view를 써도 된다.
#tensor([ 0, 5, 10])
여기까지는 한눈에 들어오고 이해하기 쉽다.
3차원 gather
3차원에서 gather를 제대로 이해했다면, 그 이후의 차원도 쉽게 다룰 수 있을 것이다.
src.shape=(3,2,4)를 사용해서 대각행렬을 추출해보자.
기준으로 정해진 부분을 제외하고 나머지 dimension의 대각성분만 가져오면 된다.
여기서부터는 원하는 결과에 따라 dim을 잘 설정해야하므로, dim=0의 텐서에서 나머지 차원의 대각 성분을 추출하는 예제로 진행하고자 한다,
src=torch.arange(24).reshape(3,2,4)
'''
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7]],
[[ 8, 9, 10, 11],
[12, 13, 14, 15]],
[[16, 17, 18, 19],
[20, 21, 22, 23]]])
'''
위의 텐서를 기준으로 [[0,5],[8,13],[16,21]]을 추출해보자
원하는 값을 추출하려면 [0,0,0],[0,1,1],[1,0,0],[1,1,1],[2,0,0],[2,1,1]을 추출해야한다.
이때, dim=0을 기준으로 잡으면 dim=2가 0인 값은 구하기 쉬우나 1인 값을 구하기 어려워진다.
[?,0,0]과 [?,1,1]을 구하려면 dim2의 길이가 2가 되어야 하는데, 그렇게 하면 필요없는 [?,0,1]과 [?,1,0]도 추출하게 될 것이고, 이 값만 제거하기 매우 까다로워진다.
dim=1을 기준으로 잡으면 어떨까?
dim0과 dim2의 index만 보면 [0,0,0],[0,1,1],[1,0,0],[1,1,1],[2,0,0],[2,1,1]이 된다.
dim0은 0~2, dim2는 0~1로 깔끔하게 잡힌다.
그러면 index_tensor.shape=(3,1,2)로 잡으면 된다는 것을 알 수 있다.
src = torch.arange(24).reshape(3,2,4)
'''
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7]],
[[ 8, 9, 10, 11],
[12, 13, 14, 15]],
[[16, 17, 18, 19],
[20, 21, 22, 23]]])
'''
idx_tensor = torch.tensor([[[0,1]],[[0,1]],[[0,1]]]) #shape=(3,1,2)
src.gather(1,idx_tensor)
'''
tensor([[[ 0, 5]],
[[ 8, 13]],
[[16, 21]]])
'''
그리고 dim=2의 관점에서 보자
[0,0,0],[0,1,1],[1,0,0],[1,1,1],[2,0,0],[2,1,1]
dim=2 또한, 추출하기 쉬운 형태로 나옴을 알 수 있고, index_tensor.shape=(3,2,1)로 잡으면 될 것이다.
src = torch.arange(24).reshape(3,2,4)
idx_tensor = torch.tensor([[[0],[1]],[[0],[1]],[[0],[1]]]) #shape(3,2,1)
src.gather(2,idx_tensor)
'''
tensor([[[ 0],
[ 5]],
[[ 8],
[13]],
[[16],
[21]]])
'''
scatter
scatter는 src의 값들을 target 텐서로 뿌려준다고 생각하면 된다.
만약 src값이 상수라면 해당 값을 뿌리고, tensor라면 텐서에 있는 값들을 뿌린다.
gather를 잘 이해했다면, scatter는 쉬울 것이므로 간단한 예제만 남기고 끝내고자 한다.
src = torch.arange(5)
target = torch.zeros(10,dtype=torch.int64)
idx_tensor = torch.tensor([1,3,7,8,2])
target.scatter_(0,idx_tensor,b) #tensor([0, 0, 4, 1, 0, 0, 0, 2, 3, 0])
src = torch.arange(1, 11).reshape((2, 5))
index = torch.tensor([[0, 1, 2, 0],[1,2,0,1]])
torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
'''
tensor([[1, 0, 8, 4, 0],
[6, 2, 0, 9, 0],
[0, 7, 3, 0, 0]])
'''