Shortcuts

Source code for torch.ao.ns.fx.utils

import enum
import operator

import torch
import torch.nn as nn
import torch.ao.nn.intrinsic.quantized as nniq
import torch.ao.nn.quantized as nnq

toq = torch.ops.quantized
from typing import Tuple, Callable, Dict, Set, List, Optional, Union

from torch.fx import GraphModule
from torch.fx.graph import Node
from torch.ao.quantization import (
    ObserverBase,
    FakeQuantizeBase,
)
from torch.ao.quantization.utils import getattr_from_fqn
from torch.ao.quantization.observer import _is_activation_post_process

from .ns_types import NSNodeTargetType, NSResultsType

# TODO(future PR): consider deleting this enum and using the torch types
# directly.  This might be tricky because it is not a one to one mapping.
class NodeInputOrOutputType(enum.Enum):
    """
    Coarse dtype category used to describe the first input or the output
    of a graph node.

    Member values are assigned by `enum.auto()`, so members should be
    compared by identity or name rather than by numeric value.
    """
    FP32 = enum.auto()  # torch.float
    INT8 = enum.auto()  # torch.qint8 or torch.quint8
    FP16 = enum.auto()  # torch.float16
    UNKNOWN = enum.auto()  # we cannot determine input/output dtype
    # TODO(future PR): while these functions can support multiple dtypes,
    #   for the purposes of numerical debugging we want to get the actual
    #   dtype used in the model. We will likely need some kind of dtype
    #   propagation to estimate this.
    FP32_OR_INT8 = enum.auto()  # either torch.float or torch.quint8 or torch.qint8
    # TODO(future PRs): dynamic quant, fake quant, etc


def get_node_first_input_and_output_type(
    node: Node,
    gm: GraphModule,
    logger_cls: Callable,
    node_type_to_io_type_map: Dict[str, Set[NSNodeTargetType]],
) -> Tuple[NodeInputOrOutputType, NodeInputOrOutputType]:
    """
    Returns the `(first_input_type, output_type)` pair for `node`.

    Targets listed in the fixed-dtype sets of `node_type_to_io_type_map`
    map directly to FP32 / FP16 / INT8.  Targets which work on either fp32
    or int8 (as well as loggers, observers and fake quants) take their type
    from the output type of the node producing their first input, found by
    recursing up the graph.  Anything not recognized is reported as UNKNOWN.

    Args:
        node: the node to classify
        gm: the graph module owning `node`, used to resolve module targets
        logger_cls: logger class; modules of this type pass dtypes through
        node_type_to_io_type_map: sets of function / module / method
          targets, keyed by names such as "funs_io_type_fp32"
    """
    # TODO(future PR): clean this up
    # All eight keys are read up front, whichever branch ends up being taken.
    (
        fp32_funs,
        fp16_funs,
        int8_funs,
        fp32_or_int8_funs,
        fp32_mods,
        int8_mods,
        fp32_or_int8_mods,
        fp32_or_int8_meths,
    ) = [
        node_type_to_io_type_map[key]
        for key in (
            "funs_io_type_fp32",
            "funs_io_type_fp16",
            "funs_io_type_int8",
            "funs_io_type_fp32_or_int8",
            "mods_io_type_fp32",
            "mods_io_type_int8",
            "mods_io_type_fp32_or_int8",
            "meths_io_type_fp32_or_int8",
        )
    ]

    io_type = NodeInputOrOutputType
    unknown = (io_type.UNKNOWN, io_type.UNKNOWN)

    def _output_type_of_first_input() -> NodeInputOrOutputType:
        # Output type of whichever node feeds the first input of `node`.
        producer = get_normalized_nth_input(node, gm, 0)
        assert isinstance(producer, Node)
        _producer_in, producer_out = get_node_first_input_and_output_type(
            producer, gm, logger_cls, node_type_to_io_type_map
        )
        return producer_out

    if node.op == "call_function":
        fn = node.target
        if fn in fp32_funs:
            return (io_type.FP32, io_type.FP32)
        if fn in fp16_funs:
            return (io_type.FP16, io_type.FP16)
        if fn in int8_funs:
            return (io_type.INT8, io_type.INT8)
        if fn in fp32_or_int8_funs:
            inherited = _output_type_of_first_input()
            return (inherited, inherited)
        return unknown

    if node.op == "call_module":
        assert isinstance(node.target, str)
        mod = getattr_from_fqn(gm, node.target)
        accepts_fp32_or_int8 = any(
            isinstance(mod, mod_type) for mod_type in fp32_or_int8_mods  # type: ignore[arg-type]
        )
        if (
            isinstance(mod, (logger_cls, ObserverBase, FakeQuantizeBase))  # type: ignore[arg-type]
            or accepts_fp32_or_int8
        ):
            # A logger or observer's input and output type is the output
            # type of the preceding node.
            inherited = _output_type_of_first_input()
            return (inherited, inherited)

        is_fp32_mod = any(
            isinstance(mod, mod_type) for mod_type in fp32_mods  # type: ignore[arg-type]
        )
        is_int8_mod = any(
            isinstance(mod, mod_type) for mod_type in int8_mods  # type: ignore[arg-type]
        )
        if is_fp32_mod:
            return (io_type.FP32, io_type.FP32)
        if is_int8_mod:
            return (io_type.INT8, io_type.INT8)
        return unknown

    if node.op == "call_method":
        if node.target == "dequantize":
            # dequantize accepts several input types, so the input type is
            # whatever the previous node produced; the output is always fp32.
            return (_output_type_of_first_input(), io_type.FP32)

        if node.target == "to":
            # `to` also accepts several input types.  The input type comes
            # from the previous node and the output type from the dtype
            # argument; only torch.float16 is handled so far.
            inherited = _output_type_of_first_input()
            requested_dtype = get_normalized_nth_input(node, gm, 1)
            assert (
                requested_dtype is torch.float16
            ), f"{requested_dtype} handling needs to be added"
            return (inherited, io_type.FP16)

        if node.target in fp32_or_int8_meths:
            inherited = _output_type_of_first_input()
            return (inherited, inherited)

        return unknown

    return unknown


def get_node_input_qparams(
    node: Node,
    gm: GraphModule,
    node_type_to_io_type_map: Dict[str, Set[NSNodeTargetType]],
) -> Optional[Tuple[Union[torch.Tensor, float], Union[torch.Tensor, int]]]:
    """
    Looks up the quantization parameters of the first input of `node`.

    The node producing that input is inspected:
    * for `torch.quantize_per_tensor` and the quantized add / mul
      functions, scale and zero_point are read from the attributes which
      the corresponding argument nodes point to
    * for the quantized modules listed below, they are read from the
      module's `scale` and `zero_point`
    * for modules in "mods_io_type_fp32_or_int8", the search continues
      with that module's own first input

    Returns:
        A `(scale, zero_point)` tuple, or None if neither case applies.
    """
    producer = get_normalized_nth_input(node, gm, 0)
    if not isinstance(producer, Node):
        return None

    passthrough_mod_types = node_type_to_io_type_map["mods_io_type_fp32_or_int8"]

    def _qparams_from_args(fn_node, scale_arg_idx, zp_arg_idx):
        # scale and zero_point are passed as nodes whose targets name
        # attributes on `gm`
        scale_node = get_normalized_nth_input(fn_node, gm, scale_arg_idx)
        zp_node = get_normalized_nth_input(fn_node, gm, zp_arg_idx)
        assert isinstance(scale_node, Node) and isinstance(scale_node.target, str)
        assert isinstance(zp_node, Node) and isinstance(zp_node.target, str)
        scale = getattr_from_fqn(gm, scale_node.target)
        zero_point = getattr_from_fqn(gm, zp_node.target)
        return (scale, zero_point)

    if producer.op == "call_function":
        # quantize - read the args directly
        if producer.target == torch.quantize_per_tensor:
            return _qparams_from_args(producer, 1, 2)
        if producer.target in (toq.add, toq.add_relu, toq.mul, toq.mul_relu):
            return _qparams_from_args(producer, 2, 3)
        # TODO(future PR): handle more functionals
        # TODO(future PR): handle functional ops which inherit qparams from input
        return None

    if producer.op != "call_module":
        return None

    assert isinstance(producer.target, str)
    producer_mod = getattr_from_fqn(gm, producer.target)

    mods_with_own_qparams = (
        nnq.Linear,
        nnq.Conv1d,
        nnq.Conv2d,
        nnq.Conv3d,
        nnq.BatchNorm2d,
        nnq.BatchNorm3d,
        nnq.ConvTranspose1d,
        nnq.ConvTranspose2d,
        nnq.ELU,
        nnq.GroupNorm,
        nnq.InstanceNorm1d,
        nnq.InstanceNorm2d,
        nnq.InstanceNorm3d,
        nnq.LayerNorm,
        nnq.Hardswish,
        nnq.LeakyReLU,
        nnq.ReLU6,
        nniq.BNReLU2d,
        nniq.BNReLU3d,
        nniq.ConvReLU1d,
        nniq.ConvReLU2d,
        nniq.ConvReLU3d,
        nniq.LinearReLU,
    )
    if isinstance(producer_mod, mods_with_own_qparams):
        return (producer_mod.scale, producer_mod.zero_point)  # type: ignore[return-value]

    if any(
        isinstance(producer_mod, mod_type) for mod_type in passthrough_mod_types  # type: ignore[arg-type]
    ):
        # this module does not change qparams, keep looking further up
        return get_node_input_qparams(producer, gm, node_type_to_io_type_map)

    return None


def return_first_non_observer_node(
    node: Node,
    gm: GraphModule,
) -> Node:
    """
    If node is not an observer, returns it.  If node is an observer,
    navigates up the graph and returns the first parent which is not an
    observer.  For example,

    graph: (node_non_obs), node = node_non_obs : returns node_non_obs
    graph: (node_non_obs -> obs0), node = obs0 : returns node_non_obs
    graph: (node_non_obs -> obs0 -> fq0), node = fq0 : returns node_non_obs

    Chains of any length are unwrapped.  Only `call_module` nodes are ever
    resolved on `gm`, so the observed node may be of any op type
    (`call_function`, `placeholder`, `get_attr`, ...).

    Args:
        node: the node to start from
        gm: the graph module owning `node`, used to resolve module targets

    Returns:
        `node` itself, or its closest ancestor (following the single
        positional arg of each observer) which is not an activation
        post process.
    """
    while node.op == "call_module":
        node_obj = getattr_from_fqn(gm, node.target)  # type: ignore[arg-type]
        if not _is_activation_post_process(node_obj):
            break
        # an activation post process has exactly one input: the observed node
        assert len(node.args) == 1
        assert isinstance(node.args[0], Node)
        node = node.args[0]
    return node


def get_number_of_non_param_args(
    node: Node,
    gm: GraphModule,
) -> int:
    """
    Returns how many leading args of `node` are non-param args, under the
    assumption that non-param args always come before params.

    Examples:

      F.linear(x, weight, bias)  ->  1  (only x is a non-param arg)
      lstm_mod(x, hid)           ->  2  (x and hid are both non-param args)

    Only `nn.LSTM` modules are special-cased; every other node yields 1.
    """
    if node.op != "call_module":
        # default is 1
        return 1
    target_mod = getattr_from_fqn(gm, node.target)  # type: ignore[arg-type]
    return 2 if isinstance(target_mod, nn.LSTM) else 1


def get_arg_indices_of_inputs_to_log(node: Node) -> List[int]:
    """
    Returns the indices of args of the node which we should attach
    loggers to, if input logging is enabled.

    For example,
    * for (x + y), returns [0, 1]
    * for (1 + y), returns [1]
    * for (x + 1), returns [0]
    * for (linear(x, w, b)) returns [0]
    * by default, returns [0]

    Only positional args are considered.  A node without positional args
    yields [], and an add / mul node with a single positional arg (the
    other operand being passed by keyword) is handled without indexing
    past the end of `node.args`.
    """
    if len(node.args) == 0:
        return []
    if node.op == "call_function":
        # TODO(future PR): use relationship map instead of hardcoding
        binary_ops = (
            torch.add,
            torch.ops.quantized.add,
            operator.add,
            torch.mul,
            torch.ops.quantized.mul,
            operator.mul,
        )
        if node.target in binary_ops:
            # Look at no more than the first two positional args; constants
            # such as the `1` in `x + 1` are not Nodes and are skipped.
            return [
                i for i, arg in enumerate(node.args[:2]) if isinstance(arg, Node)
            ]
    return [0]


def get_target_type_str(node: Node, gm: GraphModule) -> str:
    """
    Returns `torch.typename` of what `node` calls: the target itself for
    function and method calls, or the module instance looked up on `gm`
    for module calls.  All other node types yield ''.
    """
    if node.op == "call_module":
        assert isinstance(node.target, str)
        return torch.typename(getattr_from_fqn(gm, node.target))
    if node.op in ("call_function", "call_method"):
        return torch.typename(node.target)
    return ""


def rekey_logger_info_on_node_name_of_model(
    results: NSResultsType,
    model_name: str,
) -> NSResultsType:
    """
    Builds a copy of the top level of `results` in which each layer is
    keyed by the `ref_node_name` recorded for `model_name`, e.g.

        {'base_op_1_0': {'node_output': {'model_a':
          [{'ref_node_name': 'linear1', ...}]}}}

    becomes

        {'linear1': {'node_output': {'model_a':
          [{'ref_node_name': 'linear1', ...}]}}}

    Layers which have no entry for `model_name` keep their old key.  If
    several result types carry an entry for `model_name`, the name from
    the last one is used.  The nested dictionaries are reused, not copied.

    Note: node names cannot be used as keys from the start because they
    are not guaranteed to be consistent across models, so results are
    extracted first and rekeyed afterwards.
    """
    rekeyed = {}
    for layer_name, per_result_type in results.items():
        ref_node_name = None
        for per_model in per_result_type.values():
            if model_name in per_model:
                entries = per_model[model_name]
                assert len(entries)
                ref_node_name = entries[0]["ref_node_name"]
        key = layer_name if ref_node_name is None else ref_node_name
        rekeyed[key] = per_result_type
    return rekeyed


def maybe_add_missing_fqns(results: NSResultsType) -> None:
    """
    Copies `fqn` entries from a model which has them onto every other
    model in `results`, in place.

    The reference model is chosen by looking only at the first result type
    of the first layer: it is the first model there whose first entry has
    a non-None `fqn`.  If there is none, `results` is left untouched.

    A typical use case is comparing a model prepared by quantization
    (which has `fqn` entries) against the quantized model (which does not).
    """
    # Check in the first result to find any model with fqn entries defined.
    ref_model_name = None
    for per_result_type in results.values():
        for per_model in per_result_type.values():
            ref_model_name = next(
                (
                    name
                    for name, entries in per_model.items()
                    if len(entries) > 0 and entries[0]["fqn"] is not None
                ),
                None,
            )
            break
        break

    if not ref_model_name:
        return

    for per_result_type in results.values():
        for per_model in per_result_type.values():
            ref_entries = per_model[ref_model_name]
            for name, entries in per_model.items():
                if name == ref_model_name:
                    continue
                # entries are matched to the reference model by position
                for idx, entry in enumerate(entries):
                    entry["fqn"] = ref_entries[idx]["fqn"]


def maybe_dequantize_first_two_tensor_args_and_handle_tuples(f):
    """
    Decorator for comparison functions which take two tensors as their
    first two positional args.

    The wrapped function gains the following behavior:
    * if the first two args are both tuples or both lists, the function is
      applied to them element by element (recursively) and a list of the
      results is returned; `zip` stops at the shorter sequence
    * if they are both tensors, quantized ones are dequantized first
    * if, after that, either of them does not have dtype `torch.float`,
      None is returned and `f` is not called

    The name, docstring and other metadata of `f` are kept on the wrapper.
    """
    # Imported here because the module does not import functools at the top.
    import functools

    @functools.wraps(f)
    def inner(*args, **kwargs):
        a0, a1, *a_other = args

        if (isinstance(a0, tuple) and isinstance(a1, tuple)) or (
            isinstance(a0, list) and isinstance(a1, list)
        ):
            return [
                inner(el0, el1, *a_other, **kwargs) for el0, el1 in zip(a0, a1)
            ]

        if isinstance(a0, torch.Tensor) and isinstance(a1, torch.Tensor):
            if a0.is_quantized:
                a0 = a0.dequantize()
            if a1.is_quantized:
                a1 = a1.dequantize()

        # for the purposes of this util, only handle floats
        if a0.dtype != torch.float or a1.dtype != torch.float:
            return None

        return f(a0, a1, *a_other, **kwargs)

    return inner


@maybe_dequantize_first_two_tensor_args_and_handle_tuples
def compute_sqnr(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Computes the SQNR between `x` and `y`, in dB, as
    `20 * log10(norm(x) / norm(x - y))`.

    Args:
        x: Tensor or tuple of tensors
        y: Tensor or tuple of tensors

    Return:
        For tensor inputs, a tensor holding the SQNR.  Through the
        decorator, tuple or list inputs produce a list of per-element
        results, and inputs whose dtype is not `torch.float` (after
        dequantization) produce None.
    """
    signal_norm = torch.norm(x)
    noise_norm = torch.norm(x - y)
    return 20 * torch.log10(signal_norm / noise_norm)
@maybe_dequantize_first_two_tensor_args_and_handle_tuples
def compute_normalized_l2_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Computes the normalized L2 error between `x` and `y`, as
    `sqrt(sum((x - y) ** 2) / sum(x ** 2))`.

    Args:
        x: Tensor or tuple of tensors
        y: Tensor or tuple of tensors

    Return:
        For tensor inputs, a tensor holding the error.  Through the
        decorator, tuple or list inputs produce a list of per-element
        results, and inputs whose dtype is not `torch.float` (after
        dequantization) produce None.
    """
    return torch.sqrt(((x - y) ** 2).sum() / (x ** 2).sum())
@maybe_dequantize_first_two_tensor_args_and_handle_tuples
def compute_cosine_similarity(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Computes the cosine similarity between `x` and `y`, after flattening
    each of them into a single row.

    Args:
        x: Tensor or tuple of tensors
        y: Tensor or tuple of tensors

    Return:
        For tensor inputs, the result of
        `torch.nn.functional.cosine_similarity` on the two flattened rows.
        Through the decorator, tuple or list inputs produce a list of
        per-element results, and inputs whose dtype is not `torch.float`
        (after dequantization) produce None.
    """
    # For convolutions, the shape of the quantized weight has one additional
    # dimension compared to the shape of the fp32 weight. Match the shapes
    # to enable cosine similarity comparison.
    x = x.reshape(1, -1)
    y = y.reshape(1, -1)
    return torch.nn.functional.cosine_similarity(x, y)
def op_type_supports_shadowing(node: Node) -> bool:
    """
    Returns False for `call_function` nodes whose target is one of the
    hardcoded ops taking multiple tensor inputs (add, mul, cat, stack),
    and True for every other node.
    """
    if node.op == "call_function":
        if node.target in (
            torch.add,
            torch.mul,
            operator.add,
            operator.mul,
            torch.cat,
            torch.stack,
        ):
            # shadowing for ops with multiple tensor inputs is not implemented yet
            return False
    return True


def _nth_of_args_then_kwargs(args, kwargs, idx):
    """
    Returns element `idx` of the sequence formed by `args` followed by the
    values of `kwargs`, in insertion order.
    """
    assert len(args) + len(kwargs) > idx
    if idx < len(args):
        return args[idx]
    # note: in Python 3.7+ dicts are ordered
    # The position within kwargs is the overall index minus the number of
    # positional args that come before it.
    return list(kwargs.values())[idx - len(args)]


def get_normalized_nth_input(node: Node, gm: GraphModule, idx: int) -> Node:
    """
    Given a node, gets the n'th input to that node, normalizing
    args and kwargs to the best of its ability.

    If `node.normalized_arguments` yields a result, the input is taken
    from it.  If it returns None or raises RuntimeError, the input is
    taken from `node.args` followed by the values of `node.kwargs`.

    Args:
        node: the node whose input is requested
        gm: the graph module owning `node`
        idx: zero-based index of the input

    Returns:
        The idx'th input.  This is whatever was passed at that position;
        callers which need a Node check the type themselves.

    Raises:
        AssertionError: if the node has `idx` or fewer inputs
    """
    try:
        norm_args_and_kwargs = node.normalized_arguments(
            gm, normalize_to_only_use_kwargs=True
        )
    except RuntimeError:
        # this RuntimeError happens when node argument normalization
        # requires typehints to proceed, such as for torch.add where
        # either the first, second or both arguments could be tensors
        norm_args_and_kwargs = None

    if norm_args_and_kwargs is not None:
        norm_args, norm_kwargs = norm_args_and_kwargs
        return _nth_of_args_then_kwargs(norm_args, norm_kwargs, idx)
    return _nth_of_args_then_kwargs(node.args, node.kwargs, idx)  # type: ignore[return-value]

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources