본문 바로가기

데이터사이언스/Graph Neural Networks

DGL 튜토리얼 - 그래프 만들기와 메세지 전달하기

Deep Graph Library(DGL)은 Pytorch, Tensorflow 등 Deep Learning 프레임워크를 기반으로 만들어진 python 패키지이다.

 

1. 그래프 만들기

 DGL에서 각 노드와 엣지는 그래프에 추가된 순서에 따라 unique한 ID를 가진다. DGLGraph를 만드는 가장 기본적인 방법은 dgl.graph() 메소드를 사용하는 것이다. 이 메소드는 엣지 집합을 입력으로 갖는다. 다음 코드는 아래와 같은 구조를 가지는 그래프를 만드는 예제이다.

dgl.graph()를 사용해 만들 그래프

import dgl
import torch as th

u = th.tensor([0, 0, 0, 1]) # 출발 노드
v = th.tensor([1, 2, 3, 3]) # 도착 노드
g = dgl.graph((u, v))

print(g) # 노드 개수 >> 4
# Node IDs
print(f'nodes: {g.nodes()}')

# Edge end nodes
print(f'edges: {g.edges()}')

위 코드에서 u, v는 각각 edge의 시작 node, 도착 node를 나타내고 있다. 예를 들어 u[0], v[0]은 0번 node에서 시작해서 2번 node로 향하는 노드를 나타낸다. dgl.graph()의 인자로 엣지를 넣어주어 그래프를 생성하고 있다. 이렇게 만들어진 그래프의 노드 정보는 g.nodes(), 엣지 정보는 g.edges()로 확인할 수 있다. 방향성이 없는 그래프를 만들기 위해서는 양방향 엣지를 만들어야 한다. dgl.to_bidirected() 함수를 사용해서 그래프의 엣지를 양방향 엣지로 바꿀 수 있다.

>>bg = dgl.to_bidirected(g)
>>bg.edges()
(tensor([0, 0, 0, 1, 1, 2, 3, 3]), tensor([1, 2, 3, 0, 3, 0, 0, 1]))

 

2. 노드와 엣지의 feature 설정

 DGL에서는 ndata, edata 인터페이스를 이용해 노드, 엣지의 feature를 각각 지정할 수 있다.

g.ndata['x'] = th.ones(g.num_nodes(), 3)               # node feature of length 3

weights = th.tensor([0.1, 0.6, 0.9, 0.7])
g.edata['w'] = weights                                # w라는 scalar edge feature 할당

print(g)

g.ndata['y'] = th.randn(g.num_nodes(), 5) # node feature name y

print(g.ndata['x'][1])                  # get node 1's feature x
print(g.edata['w'][th.tensor([0, 3])])  # get features of edge 0 and 3

 

 

3. 메세지 전달하기

 DGL에서 메세지 전달은 메세지 함수, aggregation을 진행하는 축약함수, 실제로 feature를 업데이트하는 업데이트 함수 세가지 함수를 거쳐서 이루어진다. 메세지 함수와 aggregation은 dgl.funtion에 빌트인으로 구현되어 있다. 먼저 메세지 함수는 메세지가 출발하는 노드 u, 메세지의 목적지 노드 v, edge e를 인자로 받는다. 메세지 함수의 이름도 같은 규칙으로 정해진다. 예를 들어 출발 노드의 피처와 목적지 피처를 더해 저장하는 함수는 u_add_v, 출발 노드와 edge의 피처를 곱하는 함수는 u_mul_e이다. 아래 코드는 소스 노드 hu, 목적지 노드 hv의 피처를 더한 결과를 엣지의 he 필드에 저장하는 코드이다.

dgl.function.u_add_v('hu', 'hv', 'he')

 

다음으로 aggregation을 진행하는 축약함수는 sum, max, min, mean 4가지 연산을 지원한다. 아래 코드는 sum 연산을 사용해서 메세지 m을 aggreation하는 코드이다.

dgl.function.sum('m', 'h')

 

 마지막으로 위 함수들을 이용해 메세지를 만들고, aggregation하고 노드의 feature를 업데이트 하는 update_all() 함수가 있다.

def update_all_example(graph):

    # store the result in graph.ndata['ft']
    graph.update_all(fn.u_mul_e('ft', 'a', 'm'),
                     fn.sum('m', 'ft'))
                     
    # Call update function outside of update_all
    final_ft = graph.ndata['ft'] * 2
    return final_ft

 위 코드는 출발 노드의 feature ft와 엣지 featrue a를 곱해서 메세지 m을 생성한다. 생성된 메세지 m들을 모두 더해서 새로운 피처 ft를 구하고, 여기에 2를 곱해서 타겟 노드의 새로운 피처 final_ft를 구한다. 아래 식에서는 위 코드의 동작을 나타내고 있다.

위 코드를 나타내는 공식

출처

DGL 사용자 가이드

https://docs.dgl.ai/guide_ko/graph.html

 

1장: 그래프 — DGL 1.0.2 documentation

 

docs.dgl.ai