Published at ICLR 19'

Arxiv: 1806.09055

Code: Code URL

Abstract

将nas问题的搜索空间进行连续松弛,使得搜索空间连续化,即将问题转化为了可微的问题,进而使用梯度下降法进行优化,取得了速度优势. ## Intro 根本难点之一: 架构搜索被建模为离散空间上的黑箱优化问题,导致离散空间的优化很困难将搜索空间放宽到连续空间,这样可以使用梯度下降法,提升了搜索速度

Darts对于整体空间进行连续化假设,并不限制于简单的conv filter或者分支结构,也不受限于特定的结构限制,即适用于卷积网络和循环网络.

DARTS

搜索空间建模

为每个模块寻找一个共同的essential building cell,cell顺序化堆叠即可获得卷积网络,循环堆叠获得RNN.

每个cell建模为有\(N\)个节点的有向无环图,节点代表着隐式表示(\(x^{(i)}\)), 边\((i,j)\)代表了\(i\)\(j\)的某种操作\(o(i,j)\),

假定cell有两个输入一个输出,对于卷积网络,输入节点为上一层的输出及上上层的输出,对于循环网络,输入节点为上一层的隐状态和上一层的输出,为了统一维度,中间有一些降维操作\(x^{(j)} = \sum_{i<j}o^{(i,j)}(x^{(i)})\). 边的操作包含特殊的0操作(无链接)

这样将模块内建模为对DAG中边的学习

连续松弛

\(\mathcal{O}\)代表着可能的操作集合: \[\bar{o}^{i,j}(x)=\sum_{o\in\mathcal{O}}\frac{\exp(\alpha_{o}^{i,j})}{\sum_{o^{\prime}\in\mathcal(O)}\exp(\alpha_{o^{\prime}}^{i,j})}o(x)\]

可以看到这里引入了新的符号\(\alpha_{o}^{i,j}\),其含义为:第\(i\)个特征图到第\(j\)个特征图之间的操作 \(o(i,j)\)的权重.这也是我们之后需要搜索的架构参数.

举个例子,如果这个操作的权重\(\alpha_{o}^{i,j}=0\),那么就可以认为我们完全不需要这个操作. \(\alpha\)对应着每个操作的权重,一旦\(\alpha\)确定,网络也就确定了

决定网络performance的两个因素:

结构建模: \(\alpha\)和网络权重: \(\omega\)

\(\omega\)是随着\(\alpha\)确定的,即inner-loop优化\(\omega\),outer-loop优化\(\alpha\)

问题转化为:

\[\min_{\alpha}\quad\mathcal{L}_{val}(\omega^{\star}(\alpha),\alpha)\]

\[s.t.\qquad \omega^{\star}(\alpha) = \arg\min_{\omega} \mathcal{L}_{train}(\omega,\alpha)\]

实际上,交替进行: - 训练集上梯度下降优化权重 - val集上梯度下降优化网络结构

近似计算优化

一旦\(\alpha\)发生了更新,理应更新对应的\(\omega\),这样会带来巨大的计算成本.

先做近似: \[\nabla_{\alpha} \mathcal{L}_{val}(w^*(\alpha), \alpha)\approx \nabla_{\alpha} \mathcal{L}_{val}(w - \xi \nabla_{w} \mathcal{L}_{train}(w, \alpha), \alpha)\]

根据链式法则:

\[\frac{\partial f}{\partial x} = \frac{\partial f}{\partial g_1}\cdot\frac{\partial g_1}{\partial x} + \frac{\partial f}{\partial g_2}\cdot \frac{\partial g_2}{\partial x}+ \cdots +\frac{\partial f}{\partial g_n}\cdot \frac{\partial g_n}{\partial x} \]

问题转化为化简:

\[\nabla_\alpha \mathcal{L}_{val}(w', \alpha) - \xi \nabla^2_{\alpha, w} \mathcal{L}_{train}(w, \alpha) \nabla_{w'} \mathcal{L}_{val}(w', \alpha)\]

Thinkings

  1. 如何进行连续化建模是使得搜索空间可微的关键点之一,应该不限于超参数搜索,而是结构搜索也很关键,结构搜索如何连续化?

  2. 超参数搜索如何连续化?

  3. 本身就是一个经典的bi-level optimization问题,关键还是问题的建模方式,对于一个问题采用合适的建模方式是解决问题的第零步