
davda54 / Sam

Licence: MIT
SAM: Sharpness-Aware Minimization (PyTorch)

Projects that are alternatives of or similar to Sam

sam.pytorch
A PyTorch implementation of Sharpness-Aware Minimization for Efficiently Improving Generalization
Stars: ✭ 96 (-70.19%)
Mutual labels:  sam, optimizer
Aws Cognito Apigw Angular Auth
A simple/sample AngularV4-based web app that demonstrates different API authentication options using Amazon Cognito and API Gateway with an AWS Lambda and Amazon DynamoDB backend that stores user details in a complete end to end Serverless fashion.
Stars: ✭ 278 (-13.66%)
Mutual labels:  sam
falcon
A WordPress cleanup and performance optimization plugin.
Stars: ✭ 17 (-94.72%)
Mutual labels:  optimizer
soar
SQL Optimizer And Rewriter
Stars: ✭ 7,786 (+2318.01%)
Mutual labels:  optimizer
goga
Go evolutionary algorithm is a computer library for developing evolutionary and genetic algorithms to solve optimisation problems with (or not) many constraints and many objectives. Also, a goal is to handle mixed-type representations (reals and integers).
Stars: ✭ 39 (-87.89%)
Mutual labels:  optimizer
pheniqs
Fast and accurate sequence demultiplexing
Stars: ✭ 14 (-95.65%)
Mutual labels:  sam
madam
👩 Pytorch and Jax code for the Madam optimiser.
Stars: ✭ 46 (-85.71%)
Mutual labels:  optimizer
Adamp
AdamP: Slowing Down the Slowdown for Momentum Optimizers on Scale-invariant Weights (ICLR 2021)
Stars: ✭ 306 (-4.97%)
Mutual labels:  optimizer
faaskit
A lightweight middleware framework for functions as a service
Stars: ✭ 24 (-92.55%)
Mutual labels:  sam
sam
SAM: Software Automatic Mouth (Ported from https://github.com/vidarh/SAM)
Stars: ✭ 33 (-89.75%)
Mutual labels:  sam
gfsopt
Convenient hyperparameter optimization
Stars: ✭ 12 (-96.27%)
Mutual labels:  optimizer
pigosat
Go (golang) bindings for Picosat, the satisfiability solver
Stars: ✭ 15 (-95.34%)
Mutual labels:  optimizer
nibbler
Runtime Python bytecode optimizer. ⚡️
Stars: ✭ 21 (-93.48%)
Mutual labels:  optimizer
simplesam
Simple pure Python SAM parser and objects for working with SAM records
Stars: ✭ 50 (-84.47%)
Mutual labels:  sam
Lookahead.pytorch
lookahead optimizer (Lookahead Optimizer: k steps forward, 1 step back) for pytorch
Stars: ✭ 279 (-13.35%)
Mutual labels:  optimizer
Windows11-Optimization
Community repository, to improve security and performance of Windows 10 and windows 11 with tweaks, commands, scripts, registry keys, configuration, tutorials and more
Stars: ✭ 17 (-94.72%)
Mutual labels:  optimizer
simplu3D
A library to generate buildings from local urban regulations.
Stars: ✭ 18 (-94.41%)
Mutual labels:  optimizer
Booster
🚀Optimizer for mobile applications
Stars: ✭ 3,741 (+1061.8%)
Mutual labels:  optimizer
A
A graphical text editor
Stars: ✭ 280 (-13.04%)
Mutual labels:  sam
BioD
A D library for computational biology and bioinformatics
Stars: ✭ 45 (-86.02%)
Mutual labels:  sam

SAM Optimizer

Sharpness-Aware Minimization for Efficiently Improving Generalization

~ in PyTorch ~



SAM simultaneously minimizes the loss value and the loss sharpness. In particular, it seeks parameters that lie in neighborhoods having uniformly low loss. SAM improves model generalization and yields state-of-the-art (SoTA) performance on several datasets. Additionally, it provides robustness to label noise on par with that of SoTA procedures that specifically target learning with noisy labels.
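For reference, the objective can be written as follows, following the formulation in the SAM paper with the weight-decay term omitted. Here w are the weights, L is the training loss, and ρ is the neighborhood radius (the `rho` argument documented below). As we read the API, the perturbation ε̂ corresponds to what `first_step` computes, and the gradient taken at w + ε̂ is what `second_step` passes to the base optimizer.

```latex
\begin{aligned}
L^{\mathrm{SAM}}(w) &= \max_{\|\epsilon\|_2 \le \rho} L(w + \epsilon) \\
\hat{\epsilon}(w) &= \rho \, \frac{\nabla_w L(w)}{\|\nabla_w L(w)\|_2} \\
\nabla_w L^{\mathrm{SAM}}(w) &\approx \nabla_w L(w)\big|_{w + \hat{\epsilon}(w)}
\end{aligned}
```

The last line is a first-order approximation: second-order terms from differentiating through ε̂(w) are dropped, which is why only two forward-backward passes are needed.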

This is an unofficial repository for Sharpness-Aware Minimization for Efficiently Improving Generalization. Implementation-wise, the SAM class is a light wrapper that computes the regularized "sharpness-aware" gradient, which the underlying optimizer (such as SGD with momentum) then uses to update the weights. This repository also includes a simple Wide ResNet (WRN) for CIFAR-10; as a proof of concept, training it with SAM beats plain SGD with momentum on this dataset.
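To make the wrapper idea concrete, here is a framework-free sketch under our own assumptions: parameters are a plain Python list of floats, `ToySGD` and `ToySAM` are hypothetical classes (not this repository's code), and gradients are passed in explicitly rather than stored on tensors. `ToySAM` builds the base optimizer from a class plus forwarded keyword arguments, perturbs the weights in `first_step`, and in `second_step` restores them and hands the perturbed-point gradient to the base optimizer.

```python
# Framework-free sketch of the SAM wrapper idea (hypothetical, not this repository's code).
import math


class ToySGD:
    """Hypothetical base optimizer: SGD with momentum on a list of floats."""

    def __init__(self, params, lr=0.1, momentum=0.0):
        self.params = params
        self.lr = lr
        self.momentum = momentum
        self.velocity = [0.0] * len(params)

    def step(self, grads):
        for i, g in enumerate(grads):
            self.velocity[i] = self.momentum * self.velocity[i] + g
            self.params[i] -= self.lr * self.velocity[i]


class ToySAM:
    """Hypothetical wrapper that feeds the base optimizer a sharpness-aware gradient."""

    def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
        self.params = params
        self.rho = rho
        # The base optimizer is given as a class; extra kwargs go to its constructor.
        self.base_optimizer = base_optimizer(params, **kwargs)
        self.e_w = [0.0] * len(params)

    def first_step(self, grads):
        # Move to the (approximately) worst point in the rho-ball: w + rho * g / ||g||.
        norm = math.sqrt(sum(g * g for g in grads)) + 1e-12
        self.e_w = [self.rho * g / norm for g in grads]
        for i, e in enumerate(self.e_w):
            self.params[i] += e

    def second_step(self, grads):
        # Undo the perturbation, then apply the gradient computed at w + e to the original w.
        for i, e in enumerate(self.e_w):
            self.params[i] -= e
        self.base_optimizer.step(grads)


def loss_and_grad(w):
    # Toy quadratic L(w) = w0^2 + 10 * w1^2 and its gradient.
    return w[0] ** 2 + 10 * w[1] ** 2, [2 * w[0], 20 * w[1]]


w = [1.0, 1.0]
# momentum is kept at 0 so the toy run is easy to reason about.
opt = ToySAM(w, ToySGD, rho=0.05, lr=0.02, momentum=0.0)
for _ in range(200):
    _, g = loss_and_grad(w)  # first "forward-backward" pass, at w
    opt.first_step(g)
    _, g = loss_and_grad(w)  # second pass, at the perturbed weights w + e
    opt.second_step(g)

final_loss = loss_and_grad(w)[0]
print(final_loss)
```

On this toy quadratic the loss ends up close to zero but can keep a small residual wobble, because the perturbation length stays at `rho` even when the gradient becomes tiny.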

Loss landscape with and without SAM

ResNet loss landscape at the end of training with and without SAM. Sharpness-aware updates lead to a significantly wider minimum, which then leads to better generalization properties.
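As a small numeric illustration of what a "wider" minimum means (a toy example of ours, not data from the figure), the sketch below compares two one-dimensional quadratic minima with the same loss value but different curvature. Sharpness is measured as the worst-case loss increase within radius `rho`, i.e. the term the SAM objective adds on top of the plain loss; `sharpness` is a hypothetical helper.

```python
# Hypothetical 1-D illustration: two minima at w* = 0 with L(w*) = 0 but different curvature.
# Sharpness = max_{|e| <= rho} L(w* + e) - L(w*), found by brute force on a grid.

def sharpness(curvature, rho, num_points=1001):
    def loss(w):
        return curvature * w * w

    # Evenly spaced perturbations covering [-rho, rho], endpoints included.
    eps = [-rho + 2 * rho * k / (num_points - 1) for k in range(num_points)]
    return max(loss(e) for e in eps) - loss(0.0)


rho = 0.05
sharp = sharpness(curvature=100.0, rho=rho)  # narrow valley
flat = sharpness(curvature=1.0, rho=rho)     # wide valley
print(sharp, flat)  # about 0.25 vs 0.0025
```

Both minima have zero training loss, yet the narrow one loses a hundred times more under a perturbation of size `rho`; SAM prefers the wide one.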


Usage

It should be straightforward to use SAM in your training pipeline. Just keep in mind that training will run roughly twice as slowly, because SAM needs two forward-backward passes to estimate the "sharpness-aware" gradient. If you're using gradient clipping, make sure to change only the magnitude of the gradients, not their direction.
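The difference between the two kinds of clipping can be checked on a toy vector. In this pure-Python sketch, `clip_by_norm` and `clip_by_value` are hypothetical helpers that mimic the behavior of `torch.nn.utils.clip_grad_norm_` (rescale the whole gradient) and `torch.nn.utils.clip_grad_value_` (clamp each component). Since SAM's perturbation uses the normalized gradient ρ·g/‖g‖, a pure rescaling leaves that direction untouched, while value clipping changes it.

```python
import math


def clip_by_norm(grad, max_norm):
    # Rescale the whole vector if its L2 norm exceeds max_norm: direction is preserved.
    norm = math.sqrt(sum(g * g for g in grad))
    scale = min(1.0, max_norm / (norm + 1e-6))
    return [g * scale for g in grad]


def clip_by_value(grad, clip_value):
    # Clamp each component independently: direction can change.
    return [max(-clip_value, min(clip_value, g)) for g in grad]


def unit(v):
    # Normalized direction g / ||g||, the part of the gradient SAM's perturbation depends on.
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]


grad = [3.0, 0.5]
print(unit(grad))                        # original direction
print(unit(clip_by_norm(grad, 1.0)))     # same direction
print(unit(clip_by_value(grad, 1.0)))    # different direction
```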

import torch
from sam import SAM
...

model = YourModel()
base_optimizer = torch.optim.SGD  # optimizer class used for the "sharpness-aware" update (pass the class, not an instance)
optimizer = SAM(model.parameters(), base_optimizer, lr=0.1, momentum=0.9)
...

for inputs, targets in data:

    # first forward-backward pass
    loss = loss_function(model(inputs), targets)  # use this loss for any training statistics
    loss.backward()
    optimizer.first_step(zero_grad=True)

    # second forward-backward pass
    loss_function(model(inputs), targets).backward()  # make sure to do a full forward pass
    optimizer.second_step(zero_grad=True)
...

Alternatively, you can use a single closure-based step function. This offers an API similar to that of native PyTorch optimizers such as LBFGS (kindly suggested by @rmcavoy):

import torch
from sam import SAM
...

model = YourModel()
base_optimizer = torch.optim.SGD  # optimizer class used for the "sharpness-aware" update
optimizer = SAM(model.parameters(), base_optimizer, lr=0.1, momentum=0.9)
...

for inputs, targets in data:
    # the closure performs the additional full forward-backward pass that SAM needs
    def closure():
        loss = loss_function(model(inputs), targets)
        loss.backward()
        return loss

    # first forward-backward pass, done outside the closure
    loss = loss_function(model(inputs), targets)
    loss.backward()
    optimizer.step(closure)
    optimizer.zero_grad()
...

Documentation

SAM.__init__

| Argument | Description |
| --- | --- |
| params (iterable) | iterable of parameters to optimize or dicts defining parameter groups |
| base_optimizer (torch.optim.Optimizer subclass) | underlying optimizer class that performs the "sharpness-aware" update |
| rho (float, optional) | size of the neighborhood for computing the max loss (default: 0.05) |
| **kwargs | keyword arguments passed to the __init__ method of base_optimizer |

SAM.first_step

Performs the first optimization step that finds the weights with the highest loss in the local rho-neighborhood.

| Argument | Description |
| --- | --- |
| zero_grad (bool, optional) | set to True if you want to automatically zero out all gradients after this step (default: False) |
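The perturbation computed in the first step can be worked out by hand. This sketch assumes a toy loss L(w) = w0² + w1² (our choice, hypothetical) and the normalized-gradient perturbation e = rho · g / ‖g‖ from the SAM paper; `sam_perturbation` is a hypothetical helper, not part of this repository's API. With g = (3, 4) we get ‖g‖ = 5 and e = 0.05 · (3/5, 4/5) = (0.03, 0.04).

```python
import math


def loss(w):
    # Toy loss L(w) = w0^2 + w1^2 (hypothetical, for illustration only).
    return sum(x * x for x in w)


def grad(w):
    return [2.0 * x for x in w]


def sam_perturbation(g, rho=0.05):
    # e = rho * g / ||g||_2: a step of length rho in the gradient (ascent) direction.
    # Assumes a nonzero gradient.
    norm = math.sqrt(sum(x * x for x in g))
    return [rho * x / norm for x in g]


w = [1.5, 2.0]
g = grad(w)                  # [3.0, 4.0], so ||g|| = 5
e = sam_perturbation(g)      # about [0.03, 0.04]
w_perturbed = [x + d for x, d in zip(w, e)]
print(e, loss(w), loss(w_perturbed))  # loss rises from 6.25 to about 6.5025
```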

SAM.second_step

Performs the second optimization step that updates the original weights with the gradient from the (locally) highest point in the loss landscape.

| Argument | Description |
| --- | --- |
| zero_grad (bool, optional) | set to True if you want to automatically zero out all gradients after this step (default: False) |
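A one-dimensional hand calculation shows the key point of the second step: the gradient is evaluated at the perturbed weights but applied to the original ones. The toy loss w², plain SGD as the base optimizer, `rho = 0.05` and `lr = 0.1` are all illustrative assumptions of ours.

```python
# Hypothetical 1-D hand calculation: toy loss L(w) = w**2, plain SGD as the
# base optimizer, rho = 0.05 and lr = 0.1 (chosen only for illustration).

def grad(w):
    # Gradient of L(w) = w**2.
    return 2.0 * w


w, rho, lr = 1.0, 0.05, 0.1

# first step: move to w + e, with e = rho * g / |g| (the 1-D normalized gradient).
g = grad(w)
e = rho * g / abs(g)          # 0.05
w_perturbed = w + e           # 1.05

# second step: the gradient is taken at w + e but applied to the ORIGINAL w.
g_sam = grad(w_perturbed)     # about 2.1
w_sam = w - lr * g_sam        # about 1.0 - 0.21 = 0.79

# Plain SGD from the same starting point, for comparison.
w_sgd = w - lr * grad(w)      # 0.8
print(w_sam, w_sgd)
```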

SAM.step

Performs both optimization steps in a single call. This function is an alternative to explicitly calling SAM.first_step and SAM.second_step.

| Argument | Description |
| --- | --- |
| closure (callable) | the closure should do an additional full forward and backward pass on the optimized model (default: None) |
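The order of operations inside the single-call variant can be pictured with the sketch below. `ClosureDemo` is a hypothetical stand-in that only records calls; the ordering (first step, then the closure's extra forward-backward pass, then second step) and rejecting a missing closure are our assumptions about how such a step works, not a copy of this repository's code. Leaving the gradients in place after the second step would also explain why the closure example above calls `optimizer.zero_grad()` after `optimizer.step(closure)`.

```python
class ClosureDemo:
    """Hypothetical stand-in that only records the order of calls made by step()."""

    def __init__(self):
        self.calls = []

    def first_step(self, zero_grad=False):
        self.calls.append("first_step")

    def second_step(self, zero_grad=False):
        self.calls.append("second_step")

    def step(self, closure=None):
        # Both updates need a second forward-backward pass, so a closure is required here.
        if closure is None:
            raise ValueError("step() needs a closure that redoes the forward-backward pass")
        self.first_step(zero_grad=True)  # perturb weights, clear the old gradients
        closure()                        # recompute gradients at the perturbed weights
        self.second_step()               # restore weights and apply the base update


opt = ClosureDemo()
opt.step(lambda: opt.calls.append("closure"))
print(opt.calls)  # ['first_step', 'closure', 'second_step']
```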

Experiments

I've verified that SAM works on a simple WRN-16-8 model trained on CIFAR-10; you can replicate the experiment by running train.py. The Wide ResNet is enhanced only by label smoothing and the most basic image augmentations with Cutout, so the errors are higher than those reported in the SAM paper. In principle, you could get even lower errors by training for longer (1800 epochs instead of 200), because SAM should be less prone to overfitting.

| Optimizer | Test error rate |
| --- | --- |
| SGD + momentum | 3.35 % |
| SAM + SGD + momentum | 2.98 % |
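The size of the improvement in the table can be expressed both in absolute and relative terms; this only restates the two reported numbers.

```python
sgd_error = 3.35  # % test error, SGD + momentum (from the table above)
sam_error = 2.98  # % test error, SAM + SGD + momentum

absolute = sgd_error - sam_error       # percentage points
relative = absolute / sgd_error * 100  # share of the SGD error removed, in %
summary = f"{absolute:.2f} pp absolute, {relative:.1f}% relative"
print(summary)  # 0.37 pp absolute, 11.0% relative
```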