import math
from typing import Callable, Optional, Tuple

import torch
from mmcv import Config
from mmcv.runner import OPTIMIZER_BUILDERS, OPTIMIZERS, DefaultOptimizerConstructor
from mmcv.runner import build_optimizer as mm_build_optimizer
from mmcv.utils import _BatchNorm, _InstanceNorm
from torch.nn import GroupNorm, LayerNorm
from torch.optim.optimizer import Optimizer

from .logger import get_root_logger

def auto_scale_lr(effective_bs, optimizer_cfg, rule='linear', base_batch_size=256):
    assert rule in ['linear', 'sqrt']
    logger = get_root_logger()
    # scale lr by the ratio of the effective batch size to the base batch size
    if rule == 'sqrt':
        scale_ratio = math.sqrt(effective_bs / base_batch_size)
    elif rule == 'linear':
        scale_ratio = effective_bs / base_batch_size
    optimizer_cfg['lr'] *= scale_ratio
    logger.info(f'Automatically adapt lr to {optimizer_cfg["lr"]:.7f} (using {rule} scaling rule).')
    return scale_ratio

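# Illustrative usage sketch (not part of the original module; the numbers below
# are assumptions chosen for the example):
#
#   optimizer_cfg = dict(type='AdamW', lr=2e-5, weight_decay=3e-2)
#   # e.g. 8 GPUs x batch 64 = effective batch 512; with base_batch_size=256
#   # the 'linear' rule doubles the lr, while 'sqrt' multiplies it by ~1.41.
#   auto_scale_lr(8 * 64, optimizer_cfg, rule='linear', base_batch_size=256)
#   # optimizer_cfg['lr'] is now 4e-5
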
@OPTIMIZER_BUILDERS.register_module()
class MyOptimizerConstructor(DefaultOptimizerConstructor):

    def add_params(self, params, module, prefix='', is_dcn_module=None):
        """Add all parameters of ``module`` to the ``params`` list.

        The parameters of the given module will be added to the list of param
        groups, with specific rules defined by ``paramwise_cfg``.

        Args:
            params (list[dict]): A list of param groups; it will be modified
                in place.
            module (nn.Module): The module to be added.
            prefix (str): The prefix of the module.
            is_dcn_module (int|float|None): Whether ``module`` belongs to a DCN
                block; if so, the bias-specific lr/decay multipliers are not
                applied. Defaults to None.
        """
        # get param-wise options
        custom_keys = self.paramwise_cfg.get('custom_keys', {})
        # first sort with alphabet order and then sort with reversed len of str
        # sorted_keys = sorted(sorted(custom_keys.keys()), key=len, reverse=True)
        bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', 1.)
        bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', 1.)
        norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', 1.)
        bypass_duplicate = self.paramwise_cfg.get('bypass_duplicate', False)

        # special rules for norm layers and depth-wise conv layers
        is_norm = isinstance(module,
                             (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm))

        for name, param in module.named_parameters(recurse=False):
            base_lr = self.base_lr
            if name == 'bias' and not is_norm and not is_dcn_module:
                base_lr *= bias_lr_mult

            # apply weight decay policies
            base_wd = self.base_wd
            # norm decay
            if is_norm:
                if self.base_wd is not None:
                    base_wd *= norm_decay_mult
            # bias decay (skipped for norm biases and DCN modules)
            elif name == 'bias' and not is_dcn_module:
                if self.base_wd is not None:
                    # TODO: the current bias_decay_mult will have an effect on DCN
                    base_wd *= bias_decay_mult

            param_group = {'params': [param]}
            if not param.requires_grad:
                param_group['requires_grad'] = False
                params.append(param_group)
                continue
            if bypass_duplicate and self._is_in(param_group, params):
                logger = get_root_logger()
                logger.warning(f'{prefix} is duplicate. It is skipped since '
                               f'bypass_duplicate={bypass_duplicate}')
                continue

            # if the parameter matches one of the custom keys, ignore other rules
            is_custom = False
            for key in custom_keys:
                scope, key_name = key if isinstance(key, tuple) else (None, key)
                if scope is not None and scope not in f'{prefix}':
                    continue
                if key_name in f'{prefix}.{name}':
                    is_custom = True
                    if 'lr_mult' in custom_keys[key]:
                        # if 'base_classes' in f'{prefix}.{name}' or 'attn_base' in f'{prefix}.{name}':
                        #     param_group['lr'] = self.base_lr
                        # else:
                        param_group['lr'] = self.base_lr * custom_keys[key]['lr_mult']
                    elif 'lr' not in param_group:
                        param_group['lr'] = base_lr
                    if self.base_wd is not None:
                        if 'decay_mult' in custom_keys[key]:
                            param_group['weight_decay'] = self.base_wd * custom_keys[key]['decay_mult']
                        elif 'weight_decay' not in param_group:
                            param_group['weight_decay'] = base_wd

            if not is_custom:
                # bias_lr_mult affects all bias parameters
                # except for norm.bias and dcn.conv_offset.bias
                if base_lr != self.base_lr:
                    param_group['lr'] = base_lr
                if base_wd != self.base_wd:
                    param_group['weight_decay'] = base_wd
            params.append(param_group)

        for child_name, child_mod in module.named_children():
            child_prefix = f'{prefix}.{child_name}' if prefix else child_name
            self.add_params(
                params,
                child_mod,
                prefix=child_prefix,
                is_dcn_module=is_dcn_module)

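# Illustrative sketch (assumption, not from the original source) of the
# paramwise_cfg consumed by add_params above. custom_keys entries may be plain
# substrings or (scope, substring) tuples; a matching parameter takes the custom
# lr/decay multipliers and skips the generic bias/norm rules.
#
#   paramwise_cfg = dict(
#       custom_keys={
#           'pos_embed': dict(decay_mult=0.0),        # no weight decay
#           ('blocks.0', 'attn'): dict(lr_mult=0.1),  # scoped: only params under blocks.0
#       },
#       norm_decay_mult=0.0,
#       bias_lr_mult=2.0,
#   )
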
def build_optimizer(model, optimizer_cfg):
    # default parameter-wise config
    logger = get_root_logger()
    if hasattr(model, 'module'):
        model = model.module
    # set optimizer constructor
    optimizer_cfg.setdefault('constructor', 'MyOptimizerConstructor')
    # parameter-wise setting: cancel weight decay for some specific modules
    custom_keys = dict()
    for name, module in model.named_modules():
        if hasattr(module, 'zero_weight_decay'):
            custom_keys |= {
                (name, key): dict(decay_mult=0)
                for key in module.zero_weight_decay
            }

    paramwise_cfg = Config(dict(cfg=dict(custom_keys=custom_keys)))
    if given_cfg := optimizer_cfg.get('paramwise_cfg'):
        paramwise_cfg.merge_from_dict(dict(cfg=given_cfg))
    optimizer_cfg['paramwise_cfg'] = paramwise_cfg.cfg

    # build optimizer
    optimizer = mm_build_optimizer(model, optimizer_cfg)

    # log how parameters are grouped by lr and weight decay
    weight_decay_groups = dict()
    lr_groups = dict()
    for group in optimizer.param_groups:
        if not group.get('requires_grad', True):
            continue
        lr_groups.setdefault(group['lr'], []).append(group)
        weight_decay_groups.setdefault(group['weight_decay'], []).append(group)

    learnable_count, fix_count = 0, 0
    for p in model.parameters():
        if p.requires_grad:
            learnable_count += 1
        else:
            fix_count += 1

    fix_info = f"{learnable_count} are learnable, {fix_count} are fixed"
    lr_info = "Lr group: " + ", ".join([f'{len(group)} params with lr {lr:.5f}' for lr, group in lr_groups.items()])
    wd_info = "Weight decay group: " + ", ".join(
        [f'{len(group)} params with weight decay {wd}' for wd, group in weight_decay_groups.items()])
    opt_info = f"Optimizer: total {len(optimizer.param_groups)} param groups, {fix_info}. {lr_info}; {wd_info}."
    logger.info(opt_info)
    return optimizer

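# Illustrative usage (assumption): `model` is any nn.Module; a submodule may
# expose a `zero_weight_decay` iterable of parameter-name substrings, which
# build_optimizer turns into (module_name, key) custom_keys entries with
# decay_mult=0 before delegating to mmcv's optimizer builder.
#
#   optimizer_cfg = dict(type='AdamW', lr=1e-4, weight_decay=3e-2)
#   optimizer = build_optimizer(model, optimizer_cfg)
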
@OPTIMIZERS.register_module()
class Lion(Optimizer):

    def __init__(
        self,
        params,
        lr: float = 1e-4,
        betas: Tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.0,
    ):
        assert lr > 0.
        assert all(0. <= beta <= 1. for beta in betas)
        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @staticmethod
    def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2):
        # stepweight decay (decoupled, as in AdamW)
        p.data.mul_(1 - lr * wd)
        # weight update: sign of the interpolation between momentum and gradient
        update = exp_avg.clone().lerp_(grad, 1 - beta1).sign_()
        p.add_(update, alpha=-lr)
        # decay the momentum running average coefficient
        exp_avg.lerp_(grad, 1 - beta2)

    @staticmethod
    def exists(val):
        return val is not None

    @torch.no_grad()
    def step(
        self,
        closure: Optional[Callable] = None
    ):
        loss = None
        if self.exists(closure):
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in filter(lambda p: self.exists(p.grad), group['params']):
                grad, lr, wd, beta1, beta2, state = p.grad, group['lr'], group['weight_decay'], *group['betas'], \
                    self.state[p]

                # init state - exponential moving average of gradient values
                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)
                exp_avg = state['exp_avg']

                self.update_fn(
                    p,
                    grad,
                    exp_avg,
                    lr,
                    wd,
                    beta1,
                    beta2
                )

        return loss
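

if __name__ == '__main__':
    # Minimal, illustrative smoke test for Lion; not part of the original module.
    # Because of the relative `.logger` import above, run this as a module inside
    # its package (python -m <package>.<this_module>) rather than as a loose script.
    import torch.nn as nn

    net = nn.Linear(4, 2)
    opt = Lion(net.parameters(), lr=1e-3, weight_decay=1e-2)
    out = net(torch.randn(8, 4))
    out.pow(2).mean().backward()
    opt.step()       # applies the sign-based Lion update in-place
    opt.zero_grad()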