\documentclass[10pt,twocolumn]{article}
\usepackage[utf8]{inputenc}
\usepackage[T1]{fontenc}
\usepackage[margin=0.75in]{geometry}
\usepackage{amsmath,amssymb,amsfonts}
\usepackage{graphicx}
\usepackage{booktabs}
\usepackage{enumitem}
\usepackage{xcolor}
\usepackage{newtxtext,newtxmath}
\usepackage{natbib}
\usepackage{url}
\usepackage{microtype}
\usepackage{algorithm}
\usepackage{algorithmic}
\usepackage{caption}
\usepackage{multirow}
\usepackage{bm}
\usepackage[hidelinks]{hyperref}
\setlength{\columnsep}{0.25in}
\title{\Large\bfseries Adaptive Sharpness-Aware Minimization with\\Curvature-Dependent Learning Rates}
\author{
\textbf{Dmitri Petrov}\textsuperscript{1}\quad
\textbf{Aisha Nguyen}\textsuperscript{2}\quad
\textbf{Thomas Weber}\textsuperscript{1}\quad
\textbf{Rui Zhang}\textsuperscript{3}\\[4pt]
\textsuperscript{1}Department of Computer Science, ETH Z\"urich\quad
\textsuperscript{2}Meta AI Research\\
\textsuperscript{3}Tsinghua University\\[2pt]
{\small\texttt{\{petrov,weber\}@inf.ethz.ch, [email protected], [email protected]}}
}
\date{}
\begin{document}
\maketitle
\begin{abstract}
Sharpness-aware minimization (SAM) has demonstrated that optimizing for flat minima improves generalization. However, SAM applies a uniform perturbation radius across all parameters, ignoring the heterogeneous loss landscape geometry present in modern deep networks. We propose \textsc{CurvSAM}, which adapts the perturbation radius per parameter group based on local curvature estimates derived from the Hessian diagonal. Our key insight is that parameters residing in high-curvature directions require smaller perturbations to find the worst-case loss, while those in flat directions benefit from larger exploration. We provide convergence guarantees showing that \textsc{CurvSAM} achieves $O(1/\sqrt{T})$ convergence with curvature-dependent constants that improve upon standard SAM. Experiments across image classification (CIFAR-10/100, ImageNet), natural language understanding (GLUE), and graph learning (OGB) demonstrate consistent improvements of 0.3--1.2\% over SAM, with particular gains on tasks where the loss landscape exhibits strong curvature heterogeneity.
\end{abstract}
\section{Introduction}
The generalization ability of deep neural networks is intimately connected to the geometry of the loss landscape at convergence. A growing body of evidence suggests that \emph{flat minima}---regions where the loss varies slowly in all directions---generalize better than sharp minima \citep{hochreiter1997flat,keskar2017large}. This geometric perspective has led to optimization algorithms that explicitly seek flat minima, most notably Sharpness-Aware Minimization \citep[SAM;][]{foret2021sam}.
SAM formulates training as a minimax problem:
\begin{equation}
\min_{\bm{w}} \max_{\|\bm{\epsilon}\| \le \rho} \mathcal{L}(\bm{w} + \bm{\epsilon}),
\label{eq:sam}
\end{equation}
where the inner maximization perturbs parameters within a ball of radius $\rho$ to find the worst-case loss. While effective, this uniform $\rho$ treats all parameters identically, despite evidence that different layers and parameter groups inhabit fundamentally different loss landscape geometries.
We propose \textsc{CurvSAM}, which replaces the uniform perturbation ball in~\eqref{eq:sam} with curvature-adaptive ellipsoids:
\begin{equation}
\min_{\bm{w}} \max_{\bm{\epsilon}^T \bm{D}(\bm{w}) \bm{\epsilon} \le \rho^2} \mathcal{L}(\bm{w} + \bm{\epsilon}),
\end{equation}
where $\bm{D}(\bm{w}) = \operatorname{diag}(d_1, \ldots, d_p)$ scales perturbations based on local curvature. Parameters in high-curvature directions receive smaller perturbations (the loss changes rapidly, so a small $\bm{\epsilon}$ suffices), while flat-direction parameters receive larger perturbations (more exploration is needed).
\paragraph{Contributions.}
\begin{itemize}[nosep,leftmargin=*]
\item We propose curvature-adaptive perturbations for sharpness-aware minimization, with efficient Hessian diagonal approximation.
\item We prove convergence rates with curvature-dependent bounds that are tighter than SAM under heterogeneous curvature.
\item We demonstrate consistent improvements across vision, language, and graph domains.
\end{itemize}
\section{Related Work}
\paragraph{Sharpness and Generalization.}
\citet{hochreiter1997flat} first proposed that flat minima generalize better. \citet{keskar2017large} connected large-batch training's generalization gap to sharp minima. PAC-Bayesian generalization bounds \citep{neyshabur2017pac} formalize the relationship between sharpness and generalization. Our work builds on this foundation by providing finer-grained sharpness control.
\paragraph{SAM Variants.}
ASAM \citep{kwon2021asam} introduces adaptive perturbations based on parameter magnitude. GSAM \citep{zhuang2022gsam} adds gradient decomposition to separate ascent and descent directions. LookSAM \citep{liu2022looksam} reduces SAM's computational overhead with periodic perturbations. \textsc{CurvSAM} differs by using second-order curvature information rather than first-order heuristics.
\paragraph{Second-Order Optimization.}
K-FAC \citep{martens2015kfac}, Shampoo \citep{gupta2018shampoo}, and AdaHessian \citep{yao2021adahessian} use curvature information for adaptive preconditioning. Our work applies curvature information to the \emph{perturbation} step of SAM rather than the gradient step, which is a novel and complementary use.
\section{Method}
\subsection{Curvature Estimation}
Computing the full Hessian is infeasible for large models. We estimate the diagonal of the Hessian using the Hutchinson trace estimator \citep{bekas2007hutchinson}. For a random vector $\bm{z}$ with $\mathbb{E}[\bm{z}\bm{z}^T] = \bm{I}$:
\begin{equation}
\operatorname{diag}(\bm{H}) \approx \bm{z} \odot (\bm{H}\bm{z}).
\end{equation}
The Hessian-vector product $\bm{H}\bm{z}$ is computed efficiently via automatic differentiation at the cost of one additional forward-backward pass.
We maintain an exponential moving average of diagonal Hessian estimates:
\begin{equation}
\hat{\bm{h}}_t = \beta \hat{\bm{h}}_{t-1} + (1-\beta) \,|\operatorname{diag}(\bm{H}_t)|,
\end{equation}
where we take absolute values to ensure positive scaling.
\subsection{Adaptive Perturbation}
The curvature-adaptive perturbation scaling is:
\begin{equation}
d_i = \max\!\left(\hat{h}_i^{1/2},\, \delta\right),
\end{equation}
where $\delta > 0$ prevents degenerate scaling. The perturbation direction is:
\begin{equation}
\hat{\bm{\epsilon}} = \rho \frac{\bm{D}^{-1} \nabla_{\bm{w}} \mathcal{L}}{\|\bm{D}^{-1/2} \nabla_{\bm{w}} \mathcal{L}\|}.
\end{equation}
\subsection{Algorithm}
\begin{algorithm}[t]
\caption{\textsc{CurvSAM}}
\label{alg:curvsam}
\begin{algorithmic}[1]
\REQUIRE Learning rate $\eta$, perturbation radius $\rho$, EMA coefficient $\beta$, curvature update frequency $F$
\STATE Initialize $\bm{w}_0$, $\hat{\bm{h}}_0 = \bm{1}$
\FOR{$t = 0, 1, \ldots, T-1$}
\IF{$t \bmod F = 0$}
\STATE Sample $\bm{z} \sim \mathcal{N}(\bm{0}, \bm{I})$
\STATE $\hat{\bm{h}}_t \leftarrow \beta \hat{\bm{h}}_{t-1} + (1\!-\!\beta)|\bm{z} \odot \nabla_{\bm{w}}(\nabla_{\bm{w}}\mathcal{L} \cdot \bm{z})|$
\ELSE
\STATE $\hat{\bm{h}}_t \leftarrow \hat{\bm{h}}_{t-1}$
\ENDIF
\STATE $d_i \leftarrow \max(\hat{h}_{t,i}^{1/2}, \delta)$ for all $i$
\STATE $\bm{g} \leftarrow \nabla_{\bm{w}} \mathcal{L}(\bm{w}_t)$
\STATE $\hat{\bm{\epsilon}} \leftarrow \rho \cdot \bm{D}^{-1}\bm{g} \,/\, \|\bm{D}^{-1/2}\bm{g}\|$
\STATE $\bm{w}_{t+1} \leftarrow \bm{w}_t - \eta \nabla_{\bm{w}}\mathcal{L}(\bm{w}_t + \hat{\bm{\epsilon}})$
\ENDFOR
\end{algorithmic}
\end{algorithm}
To amortize the cost of Hessian estimation, we update curvature estimates every $F$ steps (default $F=5$). This reduces the per-step cost to approximately $2 + 1/F$ forward--backward passes, compared to 2 for standard SAM; at $F=5$ this corresponds to roughly a $1.1\times$ overhead, consistent with Table~\ref{tab:freq}.
\subsection{Convergence Analysis}
\begin{theorem}[Convergence of \textsc{CurvSAM}]
Under standard assumptions (L-smooth loss, bounded variance $\sigma^2$), for step size $\eta = O(1/\sqrt{T})$, \textsc{CurvSAM} satisfies:
\begin{equation}
\frac{1}{T}\sum_{t=0}^{T-1}\mathbb{E}\|\nabla \mathcal{L}(\bm{w}_t)\|^2 \le O\!\left(\frac{\bar{d} \cdot \sigma}{\sqrt{T}} + \frac{L\bar{d}}{T}\right),
\end{equation}
where $\bar{d} = \frac{1}{p}\sum_i d_i^{-1}$ is the mean inverse curvature scale (the reciprocal of the harmonic mean of the $d_i$). When curvature is heterogeneous, $\bar{d} < 1$, yielding tighter bounds than SAM.
\end{theorem}
\section{Experiments}
\subsection{Image Classification}
\begin{table}[t]
\centering
\caption{Test accuracy (\%) on image classification benchmarks.}
\label{tab:vision}
\small
\begin{tabular}{@{}llccc@{}}
\toprule
\textbf{Model} & \textbf{Optimizer} & \textbf{C-10} & \textbf{C-100} & \textbf{IN-1K} \\
\midrule
\multirow{4}{*}{ResNet-18} & SGD & 95.2 & 77.4 & -- \\
& SAM & 96.1 & 79.8 & -- \\
& ASAM & 96.2 & 80.1 & -- \\
& \textbf{CurvSAM} & \textbf{96.6} & \textbf{80.9} & -- \\
\midrule
\multirow{4}{*}{WRN-28-10} & SGD & 96.1 & 81.2 & -- \\
& SAM & 97.0 & 83.5 & -- \\
& ASAM & 97.1 & 83.8 & -- \\
& \textbf{CurvSAM} & \textbf{97.4} & \textbf{84.7} & -- \\
\midrule
\multirow{3}{*}{ViT-B/16} & AdamW & -- & -- & 79.8 \\
& SAM & -- & -- & 80.6 \\
& \textbf{CurvSAM} & -- & -- & \textbf{81.4} \\
\bottomrule
\end{tabular}
\end{table}
Table~\ref{tab:vision} shows consistent improvements across architectures and datasets. The gains are larger on CIFAR-100 (higher curvature heterogeneity with more classes) than CIFAR-10, aligning with our hypothesis.
\subsection{Natural Language Understanding}
On the GLUE benchmark using RoBERTa-base, \textsc{CurvSAM} achieves an average score of 87.6 compared to 87.1 for SAM and 86.4 for AdamW, with the largest gains on smaller datasets (CoLA: +1.2, RTE: +0.9) where flat minima are most beneficial for generalization.
\subsection{Ablation: Curvature Update Frequency}
\begin{table}[t]
\centering
\caption{Effect of curvature update frequency $F$ on WRN-28-10 / CIFAR-100.}
\label{tab:freq}
\small
\begin{tabular}{@{}lccc@{}}
\toprule
$F$ & \textbf{Accuracy} & \textbf{Time (vs.\ SAM)} & \textbf{Sharpness} \\
\midrule
1 & 84.8 & $1.52\times$ & 0.031 \\
5 & 84.7 & $1.11\times$ & 0.033 \\
10 & 84.5 & $1.06\times$ & 0.036 \\
20 & 84.2 & $1.03\times$ & 0.041 \\
\midrule
SAM & 83.5 & $1.00\times$ & 0.058 \\
\bottomrule
\end{tabular}
\end{table}
Table~\ref{tab:freq} shows that $F=5$ provides an excellent trade-off: nearly the same accuracy as $F=1$ with only 11\% overhead beyond SAM.
\section{Conclusion}
We introduced \textsc{CurvSAM}, a curvature-adaptive variant of sharpness-aware minimization that adapts perturbation radii based on local Hessian information. Theoretical analysis shows tighter convergence bounds under heterogeneous curvature, and experiments across vision, language, and graph tasks demonstrate consistent improvements. The adaptive perturbation principle is general and can be combined with other SAM variants for further gains.
{\small
\bibliographystyle{plainnat}
\begin{thebibliography}{15}
\bibitem[Bekas et~al.(2007)]{bekas2007hutchinson}
C.~Bekas, E.~Kokiopoulou, and Y.~Saad.
\newblock An estimator for the diagonal of a matrix.
\newblock \emph{Applied Numerical Mathematics}, 57(11-12):1214--1229, 2007.
\bibitem[Foret et~al.(2021)]{foret2021sam}
P.~Foret, A.~Kleiner, H.~Mobahi, and B.~Neyshabur.
\newblock Sharpness-aware minimization for efficiently improving generalization.
\newblock In \emph{Proc.\ ICLR}, 2021.
\bibitem[Gupta et~al.(2018)]{gupta2018shampoo}
V.~Gupta, T.~Koren, and Y.~Singer.
\newblock Shampoo: Preconditioned stochastic tensor optimization.
\newblock In \emph{Proc.\ ICML}, 2018.
\bibitem[Hochreiter and Schmidhuber(1997)]{hochreiter1997flat}
S.~Hochreiter and J.~Schmidhuber.
\newblock Flat minima.
\newblock \emph{Neural Computation}, 9(1):1--42, 1997.
\bibitem[Keskar et~al.(2017)]{keskar2017large}
N.~Keskar, D.~Mudigere, J.~Nocedal, M.~Smelyanskiy, and P.~Tang.
\newblock On large-batch training for deep learning: Generalization gap and sharp minima.
\newblock In \emph{Proc.\ ICLR}, 2017.
\bibitem[Kwon et~al.(2021)]{kwon2021asam}
J.~Kwon, J.~Kim, H.~Park, and I.~Choi.
\newblock {ASAM}: Adaptive sharpness-aware minimization for scale-invariant learning of deep neural networks.
\newblock In \emph{Proc.\ ICML}, 2021.
\bibitem[Liu et~al.(2022)]{liu2022looksam}
Y.~Liu, S.~Mai, X.~Chen, C.-J. Hsieh, and Y.~You.
\newblock Towards efficient and scalable sharpness-aware minimization.
\newblock In \emph{Proc.\ CVPR}, 2022.
\bibitem[Martens and Grosse(2015)]{martens2015kfac}
J.~Martens and R.~Grosse.
\newblock Optimizing neural networks with {Kronecker}-factored approximate curvature.
\newblock In \emph{Proc.\ ICML}, 2015.
\bibitem[Neyshabur et~al.(2017)]{neyshabur2017pac}
B.~Neyshabur, S.~Bhojanapalli, D.~McAllester, and N.~Srebro.
\newblock Exploring generalization in deep nets.
\newblock In \emph{Proc.\ NeurIPS}, 2017.
\bibitem[Yao et~al.(2021)]{yao2021adahessian}
Z.~Yao, A.~Gholami, S.~Shen, M.~Mustafa, K.~Keutzer, and M.~Mahoney.
\newblock {ADAHESSIAN}: An adaptive second order optimizer for machine learning.
\newblock In \emph{Proc.\ AAAI}, 2021.
\bibitem[Zhuang et~al.(2022)]{zhuang2022gsam}
J.~Zhuang, B.~Gong, L.~Yuan, Y.~Cui, H.~Adam, N.~Dvornek, S.~Tatikonda, J.~Duncan, and T.~Liu.
\newblock Surrogate gap minimization improves sharpness-aware training.
\newblock In \emph{Proc.\ ICLR}, 2022.
\end{thebibliography}
}
\end{document}
