LoRA#

init_lora_adapter#

def init_lora_adapter(weight: 'Tensor', rank: 'int', init_std: 'float' = 0.01, dtype: 'DType | None' = None) -> 'dict[str, Tensor]':

Initialize LoRA adapter matrices for a 2D linear weight.
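
A minimal usage sketch. The Tensor backend is not specified by this page; torch is used here only to construct an example weight, and the key layout of the returned dict is left opaque:

import torch  # assumed backend, for illustration only

frozen_w = torch.zeros(1024, 1024)  # illustrative square weight, so the layout convention does not matter below
adapter = init_lora_adapter(frozen_w, rank=8, init_std=0.01)
# adapter is a dict[str, Tensor] holding the two low-rank factors; the exact
# keys are implementation-defined, so downstream calls just pass the dict through.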


lora_delta#

def lora_delta(adapter: 'dict[str, Tensor]', alpha: 'float' = 1.0) -> 'Tensor':

Compute scaled LoRA low-rank delta: (alpha / rank) * (A @ B).
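
Continuing the sketch above, the delta can be materialized on its own:

delta = lora_delta(adapter, alpha=16.0)
# delta == (alpha / rank) * (A @ B); it has the same shape as frozen_w,
# so it can be added to the frozen weight directly.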


lora_linear#

def lora_linear(x: 'Tensor', frozen_weight: 'Tensor', adapter: 'dict[str, Tensor]', alpha: 'float' = 1.0) -> 'Tensor':

Linear projection that sums the frozen-weight path and the LoRA adapter path.
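
A forward-pass sketch, continuing the example above. Whether the weight is laid out as (out_features, in_features) or its transpose is an assumption of the illustrative shapes, not something this signature pins down:

x = torch.randn(4, 1024)  # illustrative batch of inputs
y = lora_linear(x, frozen_weight=frozen_w, adapter=adapter, alpha=16.0)
# y is the frozen projection of x plus the contribution of the low-rank adapter path.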


merge_lora_weight#

def merge_lora_weight(frozen_weight: 'Tensor', adapter: 'dict[str, Tensor]', alpha: 'float' = 1.0) -> 'Tensor':

Return the merged weight: W + (alpha / rank) * (A @ B).
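
Merging folds the adapter into the weight so inference can use a plain linear layer with no extra matmul. Continuing the example:

merged_w = merge_lora_weight(frozen_w, adapter, alpha=16.0)
# merged_w = frozen_w + (alpha / rank) * (A @ B); a plain projection with merged_w
# should match lora_linear(x, frozen_w, adapter, alpha=16.0) up to numerical precision.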


unmerge_lora_weight#

def unmerge_lora_weight(merged_weight: 'Tensor', adapter: 'dict[str, Tensor]', alpha: 'float' = 1.0) -> 'Tensor':

Recover the frozen weight from a merged weight and its adapter: merged_weight - (alpha / rank) * (A @ B).
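
Unmerging is the inverse operation, useful for swapping adapters after a merge. Continuing the example:

recovered_w = unmerge_lora_weight(merged_w, adapter, alpha=16.0)
# recovered_w equals frozen_w up to floating-point round-off, since the same
# scaled delta added by merge_lora_weight is subtracted back out.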


tree_lora_delta#

def tree_lora_delta(adapters: 'Any', alpha: 'float' = 1.0, *, is_leaf: 'Any' = None) -> 'Any':

Map a pytree of LoRA adapter dicts to their low-rank deltas.
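
A sketch for computing deltas across several adapters at once. The nesting layout is an assumption, and is_leaf would only be needed if the surrounding pytree machinery does not already treat each adapter dict as a single leaf:

adapters = {
    "q_proj": init_lora_adapter(wq, rank=8),  # wq, wv: hypothetical frozen weights
    "v_proj": init_lora_adapter(wv, rank=8),
}
deltas = tree_lora_delta(adapters, alpha=16.0)
# deltas mirrors the structure of adapters, with each adapter dict replaced by
# its (alpha / rank) * (A @ B) delta.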