# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Global flags for aot autograd
"""
import os
import sys
from typing import Literal, Optional, TYPE_CHECKING

from torch.utils._config_module import Config, install_config_module


# Converts torch rng ops to their functional philox rng equivalents. Note that
# we functionalize only CUDA rng ops today.
functionalize_rng_ops = False

# Enabled unless the FAKE_ALLOW_META environment variable is set to "0".
# Turning it off can be useful for debugging if we are incorrectly creating
# meta fake tensors.
fake_tensor_allow_meta = os.environ.get("FAKE_ALLOW_META", "1") != "0"

# Enables optional asserts in hot-path code to check for errors. If
# you are seeing weird accuracy problems, try turning this on.
# This is currently off by default as it will harm tracing time,
# but it is on by default for aot_eager.
debug_assert = False

# Partitioner debugging switch. On whenever the AOT_PARTITIONER_DEBUG
# environment variable is set to anything other than "0" (unset means off).
debug_partitioner = os.environ.get("AOT_PARTITIONER_DEBUG", "0") != "0"

# See NOTE [Export custom triton op]
decompose_custom_triton_ops = True

# NOTE(review): undocumented; presumably tells the compiler to treat weight
# shapes as static — confirm where this flag is read.
static_weight_shapes = True

# See https://github.com/pytorch/pytorch/issues/141881
# Tells the partitioner that parameters are free to save for backward.
treat_parameters_as_free_to_save = True

# Applies CSE (common subexpression elimination) to the graph before
# partitioning.
cse = True

from torch._environment import is_fbcode


# Whether the local AOTAutogradCache is enabled; the default is True. The
# justknob and the TORCHINDUCTOR_AUTOGRAD_CACHE environment variable are
# additional sources for this value. See Config in torch.utils._config_module
# for how they take precedence over the default.
enable_autograd_cache: bool = Config(
    justknob="pytorch/remote_cache:enable_local_autograd_cache",
    env_name_force="TORCHINDUCTOR_AUTOGRAD_CACHE",
    default=True,
)

# NOTE(review): presumably allows graphs that use custom autograd.Functions to
# be cached — confirm in AOTAutogradCache. Off by default. The
# TORCHINDUCTOR_AUTOGRAD_CACHE_ALLOW_CUSTOM_AUTOGRAD environment variable is an
# additional source for the value.
autograd_cache_allow_custom_autograd_functions: bool = Config(
    env_name_force="TORCHINDUCTOR_AUTOGRAD_CACHE_ALLOW_CUSTOM_AUTOGRAD", default=False
)

# For now, this is just for enabling unit testing in test_aot_autograd_cache.py.
# We will either make this the default with AOTAutogradCache, or
# we'll just use it in the precompile flow, so there's no
# need to add env vars or make it configurable.
bundled_autograd_cache: bool = False


def remote_autograd_cache_default() -> Optional[bool]:
    """Return the remote AOTAutograd cache preference requested via the environment.

    Looks at ``TORCHINDUCTOR_AUTOGRAD_REMOTE_CACHE``. ``"1"`` yields ``True``
    and ``"0"`` yields ``False``. Any other value, or an unset variable,
    yields ``None``, meaning no explicit choice was made.
    """
    requested = os.environ.get("TORCHINDUCTOR_AUTOGRAD_REMOTE_CACHE")
    # Unrecognized strings and None both fall through to dict.get's None default.
    return {"1": True, "0": False}.get(requested)


# Tri-state switch for the remote autograd cache, evaluated once at import
# time. It is True if TORCHINDUCTOR_AUTOGRAD_REMOTE_CACHE is "1", False if it
# is "0", and None otherwise.
# NOTE(review): None presumably means "no explicit user choice, decide
# elsewhere" — confirm at the call site.
enable_remote_autograd_cache = remote_autograd_cache_default()


# When AOTAutograd regenerates aliased graph outputs,
# attempt to use functionalization's view-replay logic
# before falling back to the autograd engine's view replay or as_strided.
# This can have some perf implications
# (although for many models this will not matter):
# (1) If you have many view ops chained together, replaying all of them
#     at runtime can have more overhead compared to a single as_strided call.
# (2) If you are doing training, AsStridedBackward is quite slow,
#     and the individual view op backward formulas will likely be faster.
# (3) Some backends like XLA do not support as_strided.

# Temporary hack: disable this flag for internal (fbcode) builds
# (needed to fix an internal issue while avoiding bumping the XLA pin).
# Eventually we will either default this config to False completely
# once the XLA pin update works,
# or default it to True and fix the relevant bugs.


# View replay is currently not compatible with AOTAutogradCache, since
# FunctionalTensors are not serializable. We'll need to make them
# serializable before enabling warm cache with this config turned on.
# Net default: enabled outside fbcode, disabled inside it.
view_replay_for_aliased_outputs = not is_fbcode()

# Restricts the amount of computation AOTAutograd can do.
# NB: We have essentially disabled this heuristic now. However, this is kept
# here for now in case it's useful. Setting it low can artificially reduce the
# amount of recomputation AOTAutograd performs, although not in any kind of
# principled way.
max_dist_from_bw = 1000


# Bans recomputation of nodes that are reading from nodes that are far before
# the current node.
ban_recompute_used_far_apart = True
# Breaks up long chains of fusible ops, as otherwise we can have an arbitrarily
# long chain of recomputation in the backwards pass.
ban_recompute_long_fusible_chains = True
# Bans recomputation of nodes that must be materialized in the backwards pass
# (used by a non-fusible node).
ban_recompute_materialized_backward = True
# Chooses to ban recomputation of nodes based on an allowlist. Setting it to
# False changes it to use a denylist. The main difference is for operators like
# sort/pool that aren't cheap enough to be fusible for free but also aren't
# that expensive.
ban_recompute_not_in_allowlist = True
# Chooses to ban recomputation of reductions. This is generally a good idea, as
# the result of reductions is generally very small but recomputing reductions in
# a fusion can be expensive.
ban_recompute_reductions = True
# Prevents the partitioner from ever saving views (i.e. always recompute them).
# Generally a good idea since views are free to recompute. Off by default.
recompute_views = False

# By default, the partitioner is purely trying to optimize for runtime (although
# it should always use less memory than eager).
# This knob lets the partitioner make a memory/runtime tradeoff for you,
# choosing the fastest option that saves fewer activations than the memory
# budget.
# Specifically, 0.0 corresponds to the activation memory from applying
# activation checkpointing to the full compiled region, and 1.0 corresponds to
# the activation memory from the default runtime-optimized strategy. So, 0.4
# would result in a strategy that saves 40% of the activations compared to the
# default strategy.
# It solves a 0-1 knapsack to find the minimum recompute necessary to stay below
# the activation memory budget.
# NOTE: This *cannot* be treated as
# NOTE(review): the sentence above is truncated in the source; its intended
# caveat should be restored by the author.
activation_memory_budget = 1.0

# This controls how we estimate the runtime when deciding which operators are
# cheapest to recompute. The 3 options are:
# "flops": Bases it off of the flop count provided by torch.utils.flop_counter
# "profile": Benchmarks each operator to come up with a runtime
# "testing": Returns 1 for everything
activation_memory_budget_runtime_estimator = "flops"

# This controls the solver used for the 0-1 knapsack. By default we use a
# quantized DP solution ("dp"). The other approaches are "greedy" and "ilp"
# (which has a scipy dependency).
activation_memory_budget_solver = "dp"

# This dumps out an SVG visualization of the expected runtime vs. activation
# memory tradeoffs for all memory budget values from 0 to 1 in increments of
# 0.5. See an example here:
# https://github.com/pytorch/pytorch/pull/126320#discussion_r1625104015
# NOTE(review): a 0.5 step would yield only three points; "0.5" may be a typo
# for a finer step — confirm against the partitioner's sweep.
# Enabled by setting the PARTITIONER_MEMORY_BUDGET_PARETO environment variable
# to "1".
visualize_memory_budget_pareto = (
    os.environ.get("PARTITIONER_MEMORY_BUDGET_PARETO", "0") == "1"
)

# This controls the directory in which to dump the SVG plot with the pareto
# frontier of the activation checkpointing memory-vs-runtime tradeoffs.
# Read from PARTITIONER_MEMORY_BUDGET_PARETO_DIR; None when that is unset.
memory_budget_pareto_dir = os.environ.get("PARTITIONER_MEMORY_BUDGET_PARETO_DIR")

# Sets all of the ban_recompute heuristics to False except
# ban_recompute_reductions. Generally, this will probably result in some memory
# improvement, but at the cost of some performance.
aggressive_recomputation = False

# Controls whether FakeTensor.data_ptr() may be called.
# NOTE(review): judging by the name, True permits the (unsafe) access and False
# makes data_ptr() error — confirm in FakeTensor.
# This option is independent of AOTAutograd and torch.compile, but our policy
# is to turn it off during torch.compile.
fake_tensor_allow_unsafe_data_ptr_access = True

# Unlifts effect tokens from the inputs/outputs in the traced graph and instead
# inserts make_token/sink_token calls in the graph to create tokens and then
# sink them at the end. Note that this means the graph is no longer functional,
# which may lead to silent errors unless the backend knows how to handle the
# tokens.
unlift_effect_tokens = False

# NOTE: [The default layout constraint for custom operators.]
# This must be the name of one of the layout constraint tags
# (that is, one of {"needs_exact_strides", "needs_fixed_stride_order",
# "flexible_layout"}).
# If the custom op does not already have a layout constraint tag,
# then we assume the constraint named here applies.
#
# This config is respected by Inductor and we recommend other backends also
# respect it.
# This config is in torch._functorch and not torch._inductor because it affects
# ProxyTensor tracing.
custom_op_default_layout_constraint: Literal[
    "needs_exact_strides", "needs_fixed_stride_order", "flexible_layout"
] = "needs_exact_strides"


# Run aot eager decomp partition with CrossRefFakeMode.
# Accepted values: False, "all", "custom_ops".
fake_tensor_crossref = False

# This mode specifies that we should also keep track of the real
# tensor along with the fake tensor, and do real compute. While
# seemingly this eliminates the whole point of fake tensors, there are
# two obvious use cases for it:
#
#   1. When users call item()/other data dependent operations,
#      if we propagate_real_tensors we are able to determine what
#      the true value is and keep going.
#
#   2. It can be useful for testing, when you want to see if the fake
#      and real tensors agree with each other. (Note that there are
#      currently known inaccuracies in how we clone real tensors, that
#      would have to be tightened up for this to be useful in this
#      case.)
#
# Note that fake tensors are typically understood to be cheap to store
# indefinitely, so we tend to hold on to them longer than we would
# hold onto the real tensors. So we also support you explicitly
# deallocating the real tensor associated with a fake tensor, at which
# point we will stop propagating real tensors.
#
# One more thing: when you provide a real tensor to fakeify, we will
# clone it, so that we can safely perform mutations on it if necessary.
# This will increase live memory usage. This could potentially be
# optimized by using COW. We also currently do not faithfully
# maintain autograd metadata on the real tensor; this is fine because
# AOTAutograd will only use the fake tensor to determine leafness/etc
# of tensors in question.
fake_tensor_propagate_real_tensors = False

# This controls whether we collect donated buffers. This flag must be set to
# False if a user wants to use retain_graph=True for backward.
# Off by default in fbcode builds, on elsewhere.
donated_buffer = False if is_fbcode() else True

# Controls the default graph output format used by draw_graph.
# Supported formats are defined here: https://graphviz.org/docs/outputs/
# Overridable via the TORCH_COMPILE_GRAPH_FORMAT environment variable
# (default "svg").
torch_compile_graph_format = os.environ.get("TORCH_COMPILE_GRAPH_FORMAT", "svg")

# Valid only if fake_tensor_propagate_real_tensors = True; if a fake-real
# kernel mismatch is detected, bypasses it by making a fake kernel from the
# real tensor outputs.
generate_fake_kernels_from_real_mismatches = False

# CUDAGraph-safe run_with_rng functionalization.
# TODO: turn on by default
# NOTE(review): the TODO above looks stale; the flag already defaults to True.
graphsafe_rng_functionalization = True


# Error on BypassAOTAutogradCache instead of just a warning.
# Used for tests.
strict_autograd_cache = False

# Note [Recomputing collectives in the partitioner]
# The purpose of this config is as follows:
# - We have many passes in the compiler (min-cut partitioning, DCE, etc)
#   which can reorder, delete, or duplicate nodes in the graph.
# - If any of these passes reorder/delete/duplicate a collective
#   in a setting where the compiler is being run independently on multiple
#   ranks, we run the risk that the compiler will make a different decision on
#   different ranks, resulting in an NCCL hang when using torch.compile.
# To handle this, we will (by default) ensure that collectives are not modified
# by the compiler.
#
# A few examples:
# - don't dead-code-eliminate collectives
#   (in case they are dead on rank i but not rank j)
# - don't recompute collectives in partitioning
#   (in case we recompute on rank i but not rank j)
#
# Today this flag **must** be set to False, but eventually
# we want the option to set it to True.
# In order to potentially optimize collectives, we'll need the compiler
# to broadcast information across ranks at compile time to ensure
# that any decisions on collectives are made consistently.
unsafe_allow_optimization_of_collectives = False

# See Note [AOTAutograd Tangent Subclassness for mutated inputs]
# TODO(ivankobzarev): Remove this config once it can be deduced at compile time.
disable_guess_zero_tangent_for_mutated_input_subclass = False

# See Note [Tangents memory format]
# By default, tangent strides are guessed to be contiguous;
# at runtime, non-contiguous tangents will be coerced to be contiguous.
# This config changes the guess so that tangent strides match the outputs'
# strides.
# TODO(ivankobzarev): Remove this config once extra memory usage is investigated.
guess_tangent_strides_as_outputs = False

# This is a temporary config to ensure all ranks make the same decision in the
# partitioner. It will ultimately be removed once we share size_hints across
# ranks through compiler collectives.
_broadcast_rank0_decision = False

# By default, inlined saved_tensors_hooks are applied only to "donated" buffers.
# "donated" buffers are invisible to the user; they are intermediates of the
# forward graph. Applying saved tensors hooks for memory optimizations only to
# intermediates guarantees that the original saved tensors can be deallocated.
# This config selects which saved tensors the hooks are applied to. It can
# extend them to **all** saved tensors, which may include inputs, parameters,
# and outputs:
# "donated" - applied only to saved intermediates of the graph
# "no_static" - applied to all saved tensors except "static" ones
#   ("static" includes parameters and tensors the user marked as static)
# "all" - no filtering; applied to everything saved for backward.
saved_tensors_hooks_filtering_mode = "donated"


# Static type checking only (never executed at runtime). Presumably this pulls
# in type stubs for the helpers that install_config_module attaches to this
# module — confirm in torch/utils/_config_typing.py.
if TYPE_CHECKING:
    from torch.utils._config_typing import *  # noqa: F401, F403


# Adds patch, save_config, invalid config checks, etc.
# NOTE(review): presumably this only registers config attributes defined above
# this call, so keep it as the last statement in the module — confirm.
install_config_module(sys.modules[__name__])
