Noise Stability Optimization for Flat Minima with Optimal Convergence Rates
We consider finding flat local minimizers by averaging over random weight perturbations. Given a nonconvex function f: ℝ^d → ℝ and a d-dimensional distribution 𝒫 that is symmetric about zero, we perturb the weights of f and define F(W) = 𝔼[f(W + U)], where U is a random sample from 𝒫. For small isotropic Gaussian perturbations, this noise injection induces a regularization term proportional to the Hessian trace of f. Thus, minimizing the weight-perturbed function biases the solution toward minimizers with a low Hessian trace. Several prior works have studied settings related to this weight-perturbed function and designed algorithms to improve generalization. Still, convergence rates for finding approximate minimizers of F under average perturbations are not known. This paper considers an SGD-like algorithm that injects random noise before computing gradients while leveraging the symmetry of 𝒫 to reduce variance. We then provide a rigorous analysis, showing matching upper and lower bounds on the rate of our algorithm for finding an approximate first-order stationary point of F when the gradient of f is Lipschitz-continuous. We empirically validate our algorithm on several image classification tasks with various architectures. Compared to sharpness-aware minimization, the minima found by our method have a 12.6% smaller largest Hessian eigenvalue, averaged over eight datasets. Ablation studies confirm the benefits of our algorithm's design.
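To make the algorithmic idea concrete, below is a minimal sketch in plain NumPy of one step of a symmetrized, noise-injected SGD update of the kind described above. It assumes a Gaussian perturbation distribution and a toy objective; the names grad_f and nso_step, the objective, and all hyperparameter values are illustrative placeholders rather than the paper's actual method or settings.

```python
import numpy as np

def grad_f(w):
    """Gradient of a toy nonconvex objective f(w) = sum(w**2 * cos(w))."""
    return 2 * w * np.cos(w) - w**2 * np.sin(w)

def nso_step(w, lr=0.01, sigma=0.1, rng=None):
    """One SGD-like step on F(W) = E[f(W + U)] with U ~ N(0, sigma^2 I).

    The two antithetic perturbations +u and -u exploit the symmetry of the
    noise distribution: their averaged gradient is an unbiased estimate of
    grad F(w) with lower variance than a single perturbed gradient.
    """
    rng = np.random.default_rng() if rng is None else rng
    u = sigma * rng.standard_normal(w.shape)   # symmetric perturbation
    g = 0.5 * (grad_f(w + u) + grad_f(w - u))  # antithetic gradient average
    return w - lr * g

# Usage: run a few steps from a random initialization.
rng = np.random.default_rng(0)
w = rng.standard_normal(10)
for _ in range(1000):
    w = nso_step(w, rng=rng)
```

Averaging the gradients at W + U and W − U is a standard variance-reduction device for symmetric noise; in practice the perturbation scale sigma and the number of noise samples per step would be tuned per task.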