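PyG (PyTorch Geometric) is a PyTorch-based library for processing irregular data such as graphs. It is a framework for quickly implementing representation learning on graphs and similar data, and it is currently one of the most popular and widely used GNN (Graph Neural Network) libraries.
Graph Neural Networks (GNNs) are an area of deep learning that has attracted a lot of attention in recent years. A GNN extracts features by passing, transforming and aggregating information between nodes. In this respect it resembles a conventional CNN, except that a CNN can only handle regular inputs (for an image, the height, width and number of channels are fixed), whereas a GNN can also handle irregular inputs such as point clouds.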
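To make "passing, transforming and aggregating" concrete, here is one message-passing step written by hand in plain PyTorch. It is a sketch, not PyG's MessagePassing API: the 1x1 weight W and the choice of mean aggregation are assumptions made for the example.

```python
import torch

# Toy undirected path graph 0-1-2, stored as directed (source -> target) pairs.
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])
x = torch.tensor([[1.0], [2.0], [4.0]])  # one feature per node
W = torch.tensor([[2.0]])                # hypothetical 1x1 weight for the "transform" step

src, dst = edge_index
# Pass + transform: every edge carries the transformed feature of its source node.
messages = x[src] @ W
# Aggregate: sum the incoming messages per target node, then divide by in-degree (mean).
summed = torch.zeros_like(x).index_add_(0, dst, messages)
deg = torch.zeros(x.size(0)).index_add_(0, dst, torch.ones(dst.size(0)))
out = summed / deg.unsqueeze(-1)
print(out.tolist())  # [[4.0], [5.0], [4.0]]
```

Node 1 receives messages from nodes 0 and 2 (2 and 8 after scaling), so its new feature is their mean, 5.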
Installation
pip install torch-geometric
pip install torch-sparse
pip install torch-scatter
pip install pytorch-fid
torch_geometric.data.Data
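Nodes and the edges between them make up a graph, so building a graph in PyG needs two ingredients: nodes and edges. PyG provides torch_geometric.data.Data (Data for short) for this. It has five main attributes, none of which is mandatory; any of them can be left empty:
- x: node features, shape [num_nodes, num_node_features].
- edge_index: the edges between nodes in COO format, shape [2, num_edges].
- pos: node coordinates, shape [num_nodes, num_dimensions].
- y: labels. If every node has its own label, the shape is [num_nodes, *]; if the whole graph has a single label, the shape is [1, *].
- edge_attr: edge features, shape [num_edges, num_edge_features].
Data is not limited to these attributes. For example, you can add data.face, a tensor that stores the connectivity of the triangles of a 3D mesh (shape [3, num_faces]).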
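Unlike in PyTorch, the label is stored inside Data. In PyTorch you override a Dataset's __getitem__() to return the sample and its label for a given index. In PyG you implement get() instead, which returns a torch_geometric.data.Data object for the given index; that object holds both the data and the label.
Example: an unweighted, undirected graph (unweighted means the edges carry no weights) with 3 nodes and a one-dimensional feature per node. Each undirected edge is stored as two directed edges, so the two edges 0-1 and 1-2 become four entries: (0->1), (1->0), (1->2), (2->1).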
import torch
from torch_geometric.data import Data
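# The graph is undirected, so each edge appears in both directions: (0->1), (1->0), (1->2), (2->1)
# Method 1 (the usual one): edge_index holds two lists; the first lists the source node of every edge, the second the target node
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
# Node features, one per node
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
data = Data(x=x, edge_index=edge_index)
# Method 2: list the edges as (source, target) pairs, then transpose to shape [2, num_edges] and call contiguous()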
edge_index = torch.tensor([[0, 1],[1, 0], [1, 2], [2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
data = Data(x=x, edge_index=edge_index.t().contiguous())
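A quick check in plain PyTorch that both methods yield the same edge_index, and why method two calls contiguous(): t() only swaps the strides and returns a non-contiguous view, while contiguous() copies the data into row-major memory.

```python
import torch

# Method 1: two rows, sources on top, targets below.
edge_index_rows = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)

# Method 2: one (source, target) pair per row, shape [num_edges, 2].
edge_pairs = torch.tensor([[0, 1], [1, 0], [1, 2], [2, 1]], dtype=torch.long)
transposed = edge_pairs.t()                # view with swapped strides, shape [2, 4]
print(transposed.is_contiguous())          # False
edge_index_pairs = transposed.contiguous() # row-major copy

print(torch.equal(edge_index_rows, edge_index_pairs))  # True
print(edge_index_pairs.shape)                          # torch.Size([2, 4])
```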
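Example: a directed graph with 4 nodes, where each node has two features and a class label.
import torch
from torch_geometric.data import Data
x = torch.tensor([[2, 1], [5, 6], [3, 7], [12, 0]], dtype=torch.float)
# Class labels are integer indices, so torch.long is the usual dtype (e.g. for a cross-entropy loss)
y = torch.tensor([0, 1, 0, 1], dtype=torch.long)
# The order in which the edges are listed does not matter
edge_index = torch.tensor([[0, 1, 2, 0, 3], [1, 0, 1, 3, 2]], dtype=torch.long)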
data = Data(x=x, y=y, edge_index=edge_index)
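Dataset and DataLoader
With Data in hand, you can build your own Dataset that reads and returns Data objects.
Custom Dataset
Although PyG ships with many useful datasets, you can also use your own data by subclassing torch_geometric.data.Dataset. PyG provides two kinds of dataset base classes:
- InMemoryDataset: loads the entire dataset into memory at once.
- Dataset: loads one sample into memory at a time, on demand; commonly used when the data does not fit in memory.
The constructor of a custom Dataset receives root, the directory where the data is stored, and PyG uses two folders under it:
- raw_dir (root/raw): the raw data, typically files such as csv or mat.
- processed_dir (root/processed): the processed data, typically .pt files, produced by your overridden process() method.
Besides root, __init__() accepts three callables, transform, pre_transform and pre_filter, all defaulting to None. transform converts a data object on the fly, every time it is accessed. pre_transform converts a data object once, before it is saved to disk. pre_filter decides whether a data object is kept in the dataset at all. In the base-class code further below, transform is applied by __getitem__() itself, while pre_transform and pre_filter are only recorded; calling them is the job of your process().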
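In-memory datasets
To create an InMemoryDataset, you implement the following four methods:
- raw_file_names(): returns the list of files in raw_dir that must be present for downloading to be skipped.
- processed_file_names(): returns the list of files in processed_dir that must be present for processing to be skipped.
- download(): downloads the raw files into raw_dir.
- process(): processes the raw data and saves the result in processed_dir.
In process(), you read the raw data, build a list of Data objects and save them to processed_dir. Because saving a large Python list is slow, the list is first merged by collate() into one large Data object before it is written to disk; collate() also returns a slices dictionary that records where each example starts and ends, so that individual examples can be reconstructed later. The dataset's constructor then loads these two objects back from disk into self.data and self.slices.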
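Creating larger datasets
Some datasets are too large to be loaded into memory at once. For these you subclass torch_geometric.data.Dataset directly, which requires implementing two additional methods:
- len(): returns the number of examples in the dataset.
- get(): your logic for loading a single graph.
PyTorch has no raw and processed folders, so it is worth looking at what these two folders actually mean in PyG and how they are handled.
torch_geometric.data.Dataset inherits from torch.utils.data.Dataset. Its __init__() calls _download() and _process(), provided the subclass defines download() and process() respectively.
_download() first checks whether all files in self.raw_paths exist; if they do, it returns, otherwise it calls self.download() to fetch them.
_process() first checks whether a pre_transform was recorded in self.processed_dir; if so, it compares it with the pre_transform passed in and, if they differ, warns the user to delete the self.processed_dir folder first. pre_filter is checked the same way.
It then checks whether all files in self.processed_paths exist; if they do, it returns, otherwise it calls self.process() to generate them.
In general you do not need to implement download().
If you put already processed .pt files directly into self.processed_dir, you do not need to implement process() either.
In a PyTorch dataset you implement __getitem__(), which returns the sample and label for an index. torch_geometric.data.Dataset already overrides __getitem__() and calls get() inside it to fetch the data.
So what you implement is get(), which returns the torch_geometric.data.Data object for the given index.
process() exists because the raw data may come in formats such as csv or mat. process() converts it into .pt files once, so that get() can simply read a .pt file with torch.load() and return a torch_geometric.data.Data object, without having to convert other formats into torch_geometric.data.Data on every call. Alternatively, you can convert the data to torch_geometric.data.Data ahead of time and save it as .pt files in self.processed_dir yourself.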
#torch_geometric/data/dataset.py
from typing import List, Optional, Callable, Union, Any, Tuple
import sys
import re
import copy
import warnings
import numpy as np
import os.path as osp
from collections.abc import Sequence
import torch.utils.data
from torch import Tensor
from torch_geometric.data import Data
from torch_geometric.data.makedirs import makedirs
IndexType = Union[slice, Tensor, np.ndarray, Sequence]
class Dataset(torch.utils.data.Dataset):
    r"""Dataset base class for creating graph datasets.
    See `here <https://pytorch-geometric.readthedocs.io/en/latest/notes/
    create_dataset.html>`__ for the accompanying tutorial.
    Args:
        root (string, optional): Root directory where the dataset should be
            saved. (optional: :obj:`None`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
    """
    @property
    def raw_file_names(self) -> Union[str, List[str], Tuple]:
        r"""The name of the files in the :obj:`self.raw_dir` folder that must
        be present in order to skip downloading."""
        raise NotImplementedError
    @property
    def processed_file_names(self) -> Union[str, List[str], Tuple]:
        r"""The name of the files in the :obj:`self.processed_dir` folder that
        must be present in order to skip processing."""
        raise NotImplementedError
    def download(self):
        r"""Downloads the dataset to the :obj:`self.raw_dir` folder."""
        raise NotImplementedError
    def process(self):
        r"""Processes the dataset to the :obj:`self.processed_dir` folder."""
        raise NotImplementedError
    def len(self) -> int:
        r"""Returns the number of graphs stored in the dataset."""
        raise NotImplementedError
    def get(self, idx: int) -> Data:
        r"""Gets the data object at index :obj:`idx`."""
        raise NotImplementedError
    def __init__(self, root: Optional[str] = None,
                 transform: Optional[Callable] = None,
                 pre_transform: Optional[Callable] = None,
                 pre_filter: Optional[Callable] = None):
        super().__init__()
        if isinstance(root, str):
            root = osp.expanduser(osp.normpath(root))
        self.root = root
        self.transform = transform
        self.pre_transform = pre_transform
        self.pre_filter = pre_filter
        self._indices: Optional[Sequence] = None
        if 'download' in self.__class__.__dict__:
            self._download()
        if 'process' in self.__class__.__dict__:
            self._process()
    def indices(self) -> Sequence:
        return range(self.len()) if self._indices is None else self._indices
    @property
    def raw_dir(self) -> str:
        return osp.join(self.root, 'raw')
    @property
    def processed_dir(self) -> str:
        return osp.join(self.root, 'processed')
    @property
    def num_node_features(self) -> int:
        r"""Returns the number of features per node in the dataset."""
        data = self[0]
        data = data[0] if isinstance(data, tuple) else data
        if hasattr(data, 'num_node_features'):
            return data.num_node_features
        raise AttributeError(f"'{data.__class__.__name__}' object has no "
                             f"attribute 'num_node_features'")
    @property
    def num_features(self) -> int:
        r"""Returns the number of features per node in the dataset.
        Alias for :py:attr:`~num_node_features`."""
        return self.num_node_features
    @property
    def num_edge_features(self) -> int:
        r"""Returns the number of features per edge in the dataset."""
        data = self[0]
        data = data[0] if isinstance(data, tuple) else data
        if hasattr(data, 'num_edge_features'):
            return data.num_edge_features
        raise AttributeError(f"'{data.__class__.__name__}' object has no "
                             f"attribute 'num_edge_features'")
    @property
    def raw_paths(self) -> List[str]:
        r"""The absolute filepaths that must be present in order to skip
        downloading."""
        files = to_list(self.raw_file_names)
        return [osp.join(self.raw_dir, f) for f in files]
    @property
    def processed_paths(self) -> List[str]:
        r"""The absolute filepaths that must be present in order to skip
        processing."""
        files = to_list(self.processed_file_names)
        return [osp.join(self.processed_dir, f) for f in files]
    def _download(self):
        if files_exist(self.raw_paths):  # pragma: no cover
            return
        makedirs(self.raw_dir)
        self.download()
    def _process(self):
        f = osp.join(self.processed_dir, 'pre_transform.pt')
        if osp.exists(f) and torch.load(f) != _repr(self.pre_transform):
            warnings.warn(
                f"The `pre_transform` argument differs from the one used in "
                f"the pre-processed version of this dataset. If you want to "
                f"make use of another pre-processing technique, make sure to "
                f"delete '{self.processed_dir}' first")
        f = osp.join(self.processed_dir, 'pre_filter.pt')
        if osp.exists(f) and torch.load(f) != _repr(self.pre_filter):
            warnings.warn(
                "The `pre_filter` argument differs from the one used in the "
                "pre-processed version of this dataset. If you want to make "
                "use of another pre-filtering technique, make sure to delete "
                f"'{self.processed_dir}' first")
        if files_exist(self.processed_paths):  # pragma: no cover
            return
        print('Processing...', file=sys.stderr)
        makedirs(self.processed_dir)
        self.process()
        path = osp.join(self.processed_dir, 'pre_transform.pt')
        torch.save(_repr(self.pre_transform), path)
        path = osp.join(self.processed_dir, 'pre_filter.pt')
        torch.save(_repr(self.pre_filter), path)
        print('Done!', file=sys.stderr)
    def __len__(self) -> int:
        r"""The number of examples in the dataset."""
        return len(self.indices())
    def __getitem__(
        self,
        idx: Union[int, np.integer, IndexType],
    ) -> Union['Dataset', Data]:
        r"""In case :obj:`idx` is of type integer, will return the data object
        at index :obj:`idx` (and transforms it in case :obj:`transform` is
        present).
        In case :obj:`idx` is a slicing object, *e.g.*, :obj:`[2:5]`, a list, a
        tuple, or a :obj:`torch.Tensor` or :obj:`np.ndarray` of type long or
        bool, will return a subset of the dataset at the specified indices."""
        if (isinstance(idx, (int, np.integer))
                or (isinstance(idx, Tensor) and idx.dim() == 0)
                or (isinstance(idx, np.ndarray) and np.isscalar(idx))):
            data = self.get(self.indices()[idx])
            data = data if self.transform is None else self.transform(data)
            return data
        else:
            return self.index_select(idx)
    def index_select(self, idx: IndexType) -> 'Dataset':
        r"""Creates a subset of the dataset from specified indices :obj:`idx`.
        Indices :obj:`idx` can be a slicing object, *e.g.*, :obj:`[2:5]`, a
        list, a tuple, or a :obj:`torch.Tensor` or :obj:`np.ndarray` of type
        long or bool."""
        indices = self.indices()
        if isinstance(idx, slice):
            indices = indices[idx]
        elif isinstance(idx, Tensor) and idx.dtype == torch.long:
            return self.index_select(idx.flatten().tolist())
        elif isinstance(idx, Tensor) and idx.dtype == torch.bool:
            idx = idx.flatten().nonzero(as_tuple=False)
            return self.index_select(idx.flatten().tolist())
        elif isinstance(idx, np.ndarray) and idx.dtype == np.int64:
            return self.index_select(idx.flatten().tolist())
        # np.bool was removed in NumPy 1.24; np.bool_ works on all versions
        elif isinstance(idx, np.ndarray) and idx.dtype == np.bool_:
            idx = idx.flatten().nonzero()[0]
            return self.index_select(idx.flatten().tolist())
        elif isinstance(idx, Sequence) and not isinstance(idx, str):
            indices = [indices[i] for i in idx]
        else:
            raise IndexError(
                f"Only slices (':'), list, tuples, torch.tensor and "
                f"np.ndarray of dtype long or bool are valid indices (got "
                f"'{type(idx).__name__}')")
        dataset = copy.copy(self)
        dataset._indices = indices
        return dataset
    def shuffle(
        self,
        return_perm: bool = False,
    ) -> Union['Dataset', Tuple['Dataset', Tensor]]:
        r"""Randomly shuffles the examples in the dataset.
        Args:
            return_perm (bool, optional): If set to :obj:`True`, will also
                return the random permutation used to shuffle the dataset.
                (default: :obj:`False`)
        """
        perm = torch.randperm(len(self))
        dataset = self.index_select(perm)
        return (dataset, perm) if return_perm is True else dataset
    def __repr__(self) -> str:
        arg_repr = str(len(self)) if len(self) > 1 else ''
        return f'{self.__class__.__name__}({arg_repr})'
def to_list(value: Any) -> Sequence:
    if isinstance(value, Sequence) and not isinstance(value, str):
        return value
    else:
        return [value]
def files_exist(files: List[str]) -> bool:
    # NOTE: We return `False` in case `files` is empty, leading to a
    # re-processing of files on every instantiation.
    return len(files) != 0 and all([osp.exists(f) for f in files])
def _repr(obj: Any) -> str:
    if obj is None:
        return 'None'
    return re.sub('(<.*?)\\s.*(>)', r'\1\2', obj.__repr__())
DataLoader
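torch_geometric.data.DataLoader makes it easy to work with mini-batches of graphs. Since PyG 2.0 the class lives in torch_geometric.loader.DataLoader, and the old torch_geometric.data path is deprecated.
from torch_geometric.loader import DataLoader  # PyG < 2.0: from torch_geometric.data import DataLoader
# get_dataset, train_args and args come from the LayoutGAN project code, which is not shown here
dataset = get_dataset(train_args['dataset'], 'test')
dataloader = DataLoader(dataset,
                        batch_size=args.batch_size,
                        num_workers=1,
                        pin_memory=True,
                        shuffle=False)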
Implementing LayoutGAN with PyG
Parameter settings
Data processing
Model definition
Generator(
  (fc_z): Linear(in_features=4, out_features=128, bias=True)
  (emb_label): Embedding(13, 128)
  (fc_in): Linear(in_features=256, out_features=256, bias=True)
  (transformer): TransformerEncoder(
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): _LinearWithBias(in_features=256, out_features=256, bias=True)
        )
        (linear1): Linear(in_features=256, out_features=128, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=128, out_features=256, bias=True)
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
      (1): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): _LinearWithBias(in_features=256, out_features=256, bias=True)
        )
        (linear1): Linear(in_features=256, out_features=128, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=128, out_features=256, bias=True)
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
      (2): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): _LinearWithBias(in_features=256, out_features=256, bias=True)
        )
        (linear1): Linear(in_features=256, out_features=128, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=128, out_features=256, bias=True)
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
      (3): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): _LinearWithBias(in_features=256, out_features=256, bias=True)
        )
        (linear1): Linear(in_features=256, out_features=128, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=128, out_features=256, bias=True)
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
      (4): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): _LinearWithBias(in_features=256, out_features=256, bias=True)
        )
        (linear1): Linear(in_features=256, out_features=128, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=128, out_features=256, bias=True)
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
      (5): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): _LinearWithBias(in_features=256, out_features=256, bias=True)
        )
        (linear1): Linear(in_features=256, out_features=128, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=128, out_features=256, bias=True)
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
      (6): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): _LinearWithBias(in_features=256, out_features=256, bias=True)
        )
        (linear1): Linear(in_features=256, out_features=128, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=128, out_features=256, bias=True)
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
      (7): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): _LinearWithBias(in_features=256, out_features=256, bias=True)
        )
        (linear1): Linear(in_features=256, out_features=128, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=128, out_features=256, bias=True)
        (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
  )
)
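The printout shows only layer types and sizes. As a sketch under stated assumptions, the module below rebuilds the same layers with torch.nn: fc_z maps a 4-dimensional noise vector to 128 dimensions, emb_label embeds one of 13 labels into 128 dimensions, fc_in maps 256 to 256, and an 8-layer TransformerEncoder uses d_model=256, dim_feedforward=128 and dropout=0.1. Everything not visible in the printout is an assumption: the number of attention heads (nhead=4), concatenating the two 128-dimensional embeddings into fc_in's 256 inputs, the ReLU after fc_in, and the whole forward(). This is not the original project's code, and any output head of the real generator is not shown in the printout.

```python
import torch
from torch import nn


class GeneratorSketch(nn.Module):
    """Layer sizes copied from the printed Generator; nhead, the activation
    and forward() are assumptions, not the original project's code."""

    def __init__(self, z_dim=4, num_labels=13, d_model=256, nhead=4,
                 dim_feedforward=128, num_layers=8, dropout=0.1):
        super().__init__()
        self.fc_z = nn.Linear(z_dim, d_model // 2)               # Linear(4 -> 128)
        self.emb_label = nn.Embedding(num_labels, d_model // 2)  # Embedding(13, 128)
        self.fc_in = nn.Linear(d_model, d_model)                 # Linear(256 -> 256)
        layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout)
        self.transformer = nn.TransformerEncoder(layer, num_layers)

    def forward(self, z, label):
        # Hypothetical shapes: z [batch, num_elements, z_dim], label [batch, num_elements]
        h = torch.cat([self.fc_z(z), self.emb_label(label)], dim=-1)  # [B, N, 256]
        h = torch.relu(self.fc_in(h))
        # TransformerEncoder defaults to [seq_len, batch, d_model], so transpose around it.
        h = self.transformer(h.transpose(0, 1)).transpose(0, 1)
        return h  # [B, N, 256]


model = GeneratorSketch()
out = model(torch.randn(2, 5, 4), torch.randint(0, 13, (2, 5)))
print(out.shape)  # torch.Size([2, 5, 256])
```

Printing model gives a module tree with the same layer sizes as the one above, though newer PyTorch versions label some submodules differently (for example, the attention output projection is no longer shown as _LinearWithBias).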
Prediction results