Solving Reproducibility Challenges in Deep Learning and LLMs: Our Journey

Published on: Sep 22, 2024

Reproducibility in deep learning has long been a critical concern. Whether it’s in autonomous vehicles, medical systems, or simply replicating the results of a scientific experiment, ensuring consistency is key. This challenge has become even more prominent with the rise of large language models (LLMs), particularly when we want our models to consistently follow a specific pattern, for example in safety-critical applications.

This is a well-known problem (e.g., Determinism in Deep Learning, Reproducible Deep Learning Using PyTorch), and to the best of our knowledge there is still no canonical solution for the general case, only best practices to mitigate its impact. The PyTorch documentation on reproducibility states:

“Completely reproducible results are not guaranteed across PyTorch releases, individual commits, or different platforms. Furthermore, results may not be reproducible between CPU and GPU executions, even when using identical seeds.”

Figure: “Create an image of a red apple in studio lighting”

Our Motivation

Non-determinism is a major obstacle when incorporating a zero-knowledge proof (ZKP) into inference or training for accountability, regulatory, or privacy purposes:

Classical ZK provers rely on perfectly deterministic arithmetic constraints (e.g., R1CS, Plonkish, GKR). Without determinism, the proof would be for a trace that can’t be reproduced or verified.

ZK proving is a highly parallel computation, so provers rely extensively on hardware accelerators, with NVIDIA GPUs (commonly used in AI) playing a leading role. Unfortunately, the cost of proving a non-deterministic computation is orders of magnitude higher than the already expensive ZKP process. Most ZKP use cases require a public verifier, meaning any CUDA-enabled device must be able to reproduce the AI computation, without enforcing a specific hardware architecture. For more details see Tomer’s Encode talk.

In this blog post we detail our journey toward ensuring that our deep learning models consistently produce reproducible results.

The Root Cause: Non-Associative Nature of Floating-Point Arithmetic

At the heart of our reproducibility issues was the non-associative nature of floating-point arithmetic. Floating-point numbers have finite precision, so most real values cannot be represented exactly in binary and every operation rounds its result. Operations such as addition and multiplication can therefore produce slightly different results depending on the order of execution and the hardware used. As a result, associativity is not guaranteed, and in general (a + b) + c ≠ a + (b + c).
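A minimal sketch of this effect in plain Python (which uses IEEE 754 double precision), with illustrative values of our choosing:

```python
# Three ordinary decimal values, none of which is exactly representable in binary.
a, b, c = 0.1, 0.2, 0.3

left = (a + b) + c   # rounds after a + b, then again after adding c
right = a + (b + c)  # rounds after b + c, then again after adding a

print(left)           # 0.6000000000000001
print(right)          # 0.6
print(left == right)  # False
```

The two results differ only in the last bit, which is exactly the kind of discrepancy that a bitwise comparison between machines will flag.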

By controlling the order in which operations are performed, we can ensure that the computations are consistent across runs and hardware.

This non-associativity means that even with identical initial conditions, small computational differences can accumulate over time, resulting in noticeable discrepancies in model outputs. This is especially problematic in deep learning, where models perform billions of such operations.
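The same effect appears when a sum is split across parallel workers. The sketch below is a hypothetical model of a reduction, not actual GPU code: `chunked_sum` stands in for a kernel that sums fixed-size chunks independently and then combines the partial results. The input values are deliberately extreme so the difference is easy to see.

```python
def sequential_sum(values):
    """Add the values strictly left to right."""
    total = 0.0
    for v in values:
        total += v
    return total

def chunked_sum(values, chunk_size):
    """Sum each chunk on its own, then add the partial sums left to right."""
    partials = [
        sequential_sum(values[i:i + chunk_size])
        for i in range(0, len(values), chunk_size)
    ]
    return sequential_sum(partials)

values = [1e16, 1.0, -1e16, 1.0]  # the exact sum is 2.0

one_worker = chunked_sum(values, 4)   # a single chunk: plain sequential order
two_workers = chunked_sum(values, 2)  # two chunks combined afterwards

print(one_worker)   # 1.0
print(two_workers)  # 0.0
```

Neither order yields the exact answer, and the two orders disagree with each other. Repeating either call always returns the same value, which is why fixing the order is enough to get repeatable results.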

In our experiments we focus on running DNNs on NVIDIA hardware using CUDA and libraries built on it, such as cuBLAS. This setup, although very common, does not guarantee reproducibility.

Our Setup

We conducted our tests on three different machines, each with a different GPU and operating system:

  • Ubuntu 20 with an NVIDIA RTX 3090
  • Ubuntu 22.04 with an NVIDIA RTX 4080
  • CentOS with an NVIDIA L4

We selected these GPUs to test our thesis on cards from different generations and architectures (Ampere and Ada Lovelace), as well as for different use cases (gaming vs. professional). All machines were running CUDA Toolkit 12.0.

We investigated the issue across multiple frameworks, starting with PyTorch, a leading platform for deep learning development, and continuing with llama.cpp, which is gaining traction as a popular framework for LLM inference, particularly for quantized models.

Using Reproducibility Flags

Our first approach was to employ reproducibility flags provided by deep learning frameworks like PyTorch and control random number generators. Below is the code snippet we used:

Env variable (set in the shell before launching Python):
export CUBLAS_WORKSPACE_CONFIG=:4096:8

Python:
import random
import numpy as np
import torch

random.seed(0)  # Sets the seed for Python's built-in random module
np.random.seed(0)  # Sets the seed for NumPy's random number generator

torch.manual_seed(0)  # Sets the seed for PyTorch's CPU random number generator
torch.cuda.manual_seed(0)  # Sets the seed for the current GPU device
torch.cuda.manual_seed_all(0)  # Sets the seed for all available GPU devices

torch.use_deterministic_algorithms(True)  # Ensures that only deterministic algorithms are used

torch.backends.cuda.matmul.allow_tf32 = False  # Disables TensorFloat32 (TF32) on matmul ops
torch.backends.cudnn.allow_tf32 = False  # Disables TF32 on cuDNN
torch.backends.cudnn.benchmark = False  # Disables the cuDNN auto-tuner
torch.backends.cudnn.deterministic = True  # Forces cuDNN to use deterministic algorithms
# As a last resort, if results still vary:
torch.backends.cudnn.enabled = False  # Disables cuDNN entirely

Explanation of Each Flag:

  • random.seed(0): Sets the seed for Python’s built-in random number generator, ensuring that any use of random produces the same results each run.
  • np.random.seed(0): Sets the seed for NumPy’s random number generator for consistency in NumPy operations.
  • torch.manual_seed(0): Sets the seed for PyTorch’s CPU random number generator.
  • torch.cuda.manual_seed(0) and torch.cuda.manual_seed_all(0): Set the seed for the random number generators of the current GPU and of all GPUs, respectively. Seeding makes random number generation repeatable; it does not by itself make CUDA operations deterministic.
  • torch.use_deterministic_algorithms(True): Forces PyTorch to use deterministic algorithms when available and raises an error when an operation has no deterministic implementation. With CUDA 10.2 or later, some cuBLAS operations additionally require the CUBLAS_WORKSPACE_CONFIG environment variable shown above.
  • torch.backends.cuda.matmul.allow_tf32 = False and torch.backends.cudnn.allow_tf32 = False: Disable TensorFloat32 (TF32) in matmul and cuDNN operations, so computations use full FP32 precision instead of TF32’s reduced 10-bit mantissa.
  • torch.backends.cudnn.benchmark = False: Disables the cuDNN auto-tuner that selects the fastest convolution algorithm; this can introduce non-determinism due to varying algorithm choices.
  • torch.backends.cudnn.deterministic = True: Forces cuDNN to use deterministic algorithms, which may be slower but are consistent across runs.
  • torch.backends.cudnn.enabled = False: Disables cuDNN entirely, so PyTorch falls back to its own native kernels. This can significantly impact performance but removes cuDNN as a source of variability.

Outcome: While these flags helped reduce variability, they did not completely eliminate differences across machines with varying hardware configurations. The models still produced slightly different outputs on our platforms.

Despite setting all possible seeds and disabling non-deterministic behaviors, the inherent non-associativity of floating-point operations, combined with hardware differences, continued to cause inconsistencies.
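Checking for this kind of inconsistency requires a bitwise comparison, because a tolerance-based check hides exactly the differences we care about. The following is a minimal sketch of one way to do that; `fingerprint` is a hypothetical helper of our own, not part of PyTorch, and it operates on plain Python floats. For real tensors one would hash the raw bytes instead, for example `tensor.cpu().numpy().tobytes()`.

```python
import hashlib
import math
import struct

def fingerprint(values):
    """Return a SHA-256 digest of the exact IEEE 754 double bit patterns."""
    raw = struct.pack(f"<{len(values)}d", *values)
    return hashlib.sha256(raw).hexdigest()

# Stand-ins for the same model output computed on two machines.
output_machine_a = [0.1 + 0.2, 1.0]
output_machine_b = [0.3, 1.0]

# A tolerance-based comparison reports the outputs as equal...
close = all(math.isclose(a, b) for a, b in zip(output_machine_a, output_machine_b))
# ...while the fingerprints show that they are not bit-identical.
identical = fingerprint(output_machine_a) == fingerprint(output_machine_b)

print(close)      # True
print(identical)  # False
```

Comparing short digests across machines is convenient because only the hash has to be exchanged, not the full output.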

Running Quantized Models with Llama.cpp

Next we turned to llama.cpp, a project that allows running LLMs with quantized weights to significantly reduce computational load. Llama.cpp has defined a popular quantization format called GGUF, which supports various levels of quantization, including int8.

Expectation: We expected that running the model in a quantized format, such as int8, would resolve the reproducibility issue. Our assumption was that using integers instead of floating-point numbers would eliminate the inconsistencies.

Reality: However, we encountered two critical issues:

  1. Incomplete Quantization: Not all layers in the model are quantized. Specifically, 1D normalization layers (such as LayerNorm) remain unquantized, continuing to perform computations in floating-point precision, which reintroduces non-determinism. Source
  2. Runtime Dequantization: Even for the quantized layers, weights are dequantized during inference to accommodate operations like calculating logits. As a result, despite storing weights in a quantized format, the model still relies on floating-point computations at runtime, which are subject to non-associativity and hardware variability.

As a result, despite using quantization and the GGUF format, we were unable to achieve reproducibility across different machines. The combination of quantized and unquantized computations, along with runtime dequantization, reintroduced the very floating-point operations we sought to eliminate!

Conclusion from This Attempt: Quantization alone does not ensure reproducibility. The presence of unquantized layers and the need to dequantize weights during inference mean that floating-point operations — and their inherent non-determinism — remain part of the computation pipeline.
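To make this concrete, here is a toy sketch of block-wise quantization in Python. It is not llama.cpp code: the helper names, block contents, and scales are all hypothetical, and the scales are exaggerated so the effect is visible with only four weights. Integer accumulation is exact in any order, but once each block is multiplied by its floating-point scale, the order of the remaining additions matters again.

```python
def int_dot(q_weights, q_inputs):
    """Dot product on integers: exact, so the order of terms is irrelevant."""
    return sum(w * x for w, x in zip(q_weights, q_inputs))

def dequantize(q_block, scale):
    """Convert one block of integer weights back to floats using its scale."""
    return [q * scale for q in q_block]

def float_sum(values):
    """Add floats strictly left to right."""
    total = 0.0
    for v in values:
        total += v
    return total

# Pure integer arithmetic gives the same result in either order.
q, x = [3, -7, 5, 2], [4, 1, -6, 9]
same_int_result = int_dot(q, x) == int_dot(q[::-1], x[::-1])

# Two blocks with very different (exaggerated) scales.
block_a = dequantize([100, -100], 1e14)  # [1e16, -1e16]
block_b = dequantize([1, 1], 1.0)        # [1.0, 1.0]

order_1 = float_sum(block_a + block_b)                      # 2.0
order_2 = float_sum([block_a[0]] + block_b + [block_a[1]])  # 0.0

print(same_int_result, order_1, order_2)  # True 2.0 0.0
```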

The Breakthrough: Rewriting the GEMM CUDA Kernel

Realizing the issue might stem from the matrix multiplication operations, we decided to address the problem at the kernel level.

After compiling the code for different architectures, we traced the source of non-determinism down to the PTX files (NVIDIA’s intermediate assembly representation for CUDA kernels). We observed that some kernels generate different instructions depending on the target architecture. As a result, we began investigating the matrix multiplication kernels, which are central to our AI computations.

We conducted an experiment isolating the GEMM kernel.

The GEMM kernels provided by the cuBLAS library are among the most widely used and optimized kernels for deep learning and parallel computation acceleration. However, this high level of optimization sometimes comes at the cost of reproducibility, leading to different results on different hardware.

When we ran these kernels on different platforms (L4, 3090, 4080), the results differed with an error margin of 1e-4. While this may seem like a small discrepancy, in the context of LLMs, where each token depends on the previous one, such errors quickly accumulate and result in divergent outputs.
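A small sketch of why such a tiny error matters, assuming greedy (argmax) decoding. The logit values are invented for illustration: the two machines disagree by only 5e-5 per logit, yet they select different tokens, and from that point on the two generations no longer share the same context.

```python
def argmax(values):
    """Index of the largest value, as greedy decoding would select it."""
    return max(range(len(values)), key=lambda i: values[i])

# Hypothetical logits for the same decoding step on two machines.
logits_machine_a = [1.50002, 1.50000, 0.30000]
logits_machine_b = [1.49997, 1.50005, 0.30000]

max_deviation = max(abs(a - b) for a, b in zip(logits_machine_a, logits_machine_b))

token_a = argmax(logits_machine_a)
token_b = argmax(logits_machine_b)

print(max_deviation < 1e-4)  # True
print(token_a, token_b)      # 0 1
```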

Implementing a Deterministic GEMM CUDA Kernel

We rewrote the General Matrix Multiply (GEMM) CUDA kernels used by llama.cpp, which had been identified as non-deterministic. Our goal was to provide a simpler, deterministic version of each kernel, without compromising on performance. Our strategy was:

  • Avoiding Tensor Cores and Using Only CUDA Cores: NVIDIA revises the architecture and capabilities of its Tensor Cores with every generation, whereas CUDA cores offer reliable backward compatibility. Because different architectures may ship different generations of Tensor Cores, restricting computations to CUDA cores ensured that operations were executed consistently across architectures. Further reading material is provided here.
  • Ensuring Deterministic Order of Operations: We carefully managed the order of floating-point operations to eliminate non-associativity effects.

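Example of one of our kernels (CUDA C++):

#include <cstdio>
#include <cstdlib>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

// Minimal error-check helper. Its definition is not shown in the original code; this version is an assumption.
static void checkCudaError(const char * msg) {
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        fprintf(stderr, "CUDA error (%s): %s\n", msg, cudaGetErrorString(err));
        exit(EXIT_FAILURE);
    }
}

// Each thread computes one element of C and accumulates over k in a fixed, sequential order.
__global__ void ingo_mul_mat_fp16_fp16_kernel(const half * __restrict__ A, const half * __restrict__ B, half * __restrict__ C, int m, int n, int k, int lda, int ldb, int ldc, bool transpose_a) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row < m && col < n) {
        half sum = __float2half(0.0f);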
        if (!transpose_a) {
            // No transpose: Normal matrix multiplication A * B
            // A is accessed in column-major, B is also accessed in column-major
            for (int i = 0; i < k; i++) {
                sum += A[i * lda + row] * B[col * ldb + i];
            }
        } else {
            // Transpose A: Multiply A^T * B
            // A^T is actually accessed as if A were in row-major
            for (int i = 0; i < k; i++) {
                sum += A[row * lda + i] * B[col * ldb + i];
            }
        }
        C[col * ldc + row] = sum;  // Storing result in column-major order
    }
}

void ingo_mul_mat(const half * A, const half * B, half * C,
                           int m, int n, int k, int lda, int ldb, int ldc, bool transpose_a, cudaStream_t stream) {
    dim3 blockSize(16, 16);
    dim3 gridSize((n + blockSize.x - 1) / blockSize.x, (m + blockSize.y - 1) / blockSize.y);

    ingo_mul_mat_fp16_fp16_kernel<<<gridSize, blockSize, 0, stream>>>(A, B, C, m, n, k, lda, ldb, ldc, transpose_a);
    checkCudaError("kernel launch");

    cudaDeviceSynchronize();
    checkCudaError("synchronizing after kernel execution");
}
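For checking the indexing and loop order of a kernel like this on the host, a plain Python reference can help. The sketch below is our own illustration, with `gemm_reference` as a hypothetical name; it mirrors the kernel's column-major layout and its sequential accumulation over k. It accumulates in double precision rather than FP16, so it validates the logic and should not be expected to match the GPU output bit for bit.

```python
def gemm_reference(A, B, m, n, k, lda, ldb, ldc, transpose_a=False):
    """Column-major GEMM with one sequential accumulation per output element."""
    C = [0.0] * (ldc * n)
    for col in range(n):
        for row in range(m):
            total = 0.0
            for i in range(k):  # fixed order: i = 0, 1, ..., k - 1
                a = A[row * lda + i] if transpose_a else A[i * lda + row]
                total += a * B[col * ldb + i]
            C[col * ldc + row] = total
    return C

# A = [[1, 2], [3, 4]] and B = [[5, 6], [7, 8]], both stored column by column.
A = [1.0, 3.0, 2.0, 4.0]
B = [5.0, 7.0, 6.0, 8.0]

C = gemm_reference(A, B, m=2, n=2, k=2, lda=2, ldb=2, ldc=2)
print(C)  # [19.0, 43.0, 22.0, 50.0], i.e. [[19, 22], [43, 50]]

C_t = gemm_reference(A, B, m=2, n=2, k=2, lda=2, ldb=2, ldc=2, transpose_a=True)
print(C_t)  # [26.0, 38.0, 30.0, 44.0], i.e. [[26, 30], [38, 44]]
```

Because every output element depends only on its own loop, the result does not depend on how rows and columns are distributed across threads or blocks.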

Outcome: This approach successfully resolved our reproducibility issues! Models running on all three machines produced identical outputs, confirming that we had mitigated hardware-induced variability. For more insights on GEMM kernel optimization, check out this excellent blog post by Simon Boehm from Anthropic.

We have tested our results across the following models:

  • llama2-7b-fp16
  • llama2-7b-Q8
  • mistral-7b-instruct-v0.2.Q8_0
  • mixtral-8x7b.Q2
  • Mistral-7b-instruct-v0.2.fp_16
  • Meta-Llama-3-8B-fp16-gguf
  • Meta-Llama-3-8B.Q8_0.gguf

Furthermore, we extended our experiment by evaluating the performance of our deterministic kernels. After a round of basic optimizations (handling batches over multiple threads), we observed that performance was primarily affected during the prompt processing stage, which dropped from 302 tps with the original kernels to 43 tps with ours, roughly a 7x slowdown. Text generation remained consistent at 44–45 tps. With further optimization, this can likely be improved.

These measurements were taken using llama.cpp inference code running on an RTX 4080.

Original llama.cpp:

llama_print_timings:        load time =    1311.27 ms
llama_print_timings:      sample time =       5.33 ms /   300 runs   (    0.02 ms per token, 56285.18 tokens per second)
llama_print_timings: prompt eval time =      23.14 ms /     7 tokens (    3.31 ms per token,   302.48 tokens per second)
llama_print_timings:        eval time =    6627.57 ms /   299 runs   (   22.17 ms per token,    45.11 tokens per second)
llama_print_timings:       total time =    6702.29 ms /   306 tokens

llama.cpp with our deterministic kernels:

llama_print_timings:        load time =    1322.69 ms
llama_print_timings:      sample time =       5.18 ms /   300 runs   (    0.02 ms per token, 57881.54 tokens per second)
llama_print_timings: prompt eval time =     161.55 ms /     7 tokens (   23.08 ms per token,    43.33 tokens per second)
llama_print_timings:        eval time =    6739.35 ms /   299 runs   (   22.54 ms per token,    44.37 tokens per second)
llama_print_timings:       total time =    6963.47 ms /   306 tokens
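The slowdown can be read directly off these logs. The sketch below is a small hypothetical helper, not part of llama.cpp, that extracts the tokens-per-second figures from `llama_print_timings` lines and compares the two runs. It assumes the line format shown above.

```python
import re

# Captures the stage name and the trailing "tokens per second" figure.
TPS_PATTERN = re.compile(
    r"llama_print_timings:\s*(.+?) time\s*=.*,\s*([\d.]+) tokens per second"
)

def tokens_per_second(log_text):
    """Map each timed stage (e.g. 'prompt eval') to its tokens per second."""
    rates = {}
    for line in log_text.splitlines():
        match = TPS_PATTERN.search(line)
        if match:
            rates[match.group(1)] = float(match.group(2))
    return rates

original_log = """\
llama_print_timings: prompt eval time =      23.14 ms /     7 tokens (    3.31 ms per token,   302.48 tokens per second)
llama_print_timings:        eval time =    6627.57 ms /   299 runs   (   22.17 ms per token,    45.11 tokens per second)
"""
deterministic_log = """\
llama_print_timings: prompt eval time =     161.55 ms /     7 tokens (   23.08 ms per token,    43.33 tokens per second)
llama_print_timings:        eval time =    6739.35 ms /   299 runs   (   22.54 ms per token,    44.37 tokens per second)
"""

original = tokens_per_second(original_log)
deterministic = tokens_per_second(deterministic_log)

prompt_slowdown = original["prompt eval"] / deterministic["prompt eval"]
generation_slowdown = original["eval"] / deterministic["eval"]

print(f"prompt processing: {prompt_slowdown:.2f}x slower")
print(f"text generation:   {generation_slowdown:.2f}x slower")
```

With the figures above, prompt processing comes out at roughly 7x slower, while text generation is within about 2% of the original.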

Conclusion

  • Hardware Matters: Differences in GPU architectures can affect reproducibility, even when running the same code.
  • Quantization Isn’t a Silver Bullet: While quantization reduces model size and computation, it doesn’t inherently solve reproducibility issues.
  • Deep Control Is Necessary: Sometimes, ensuring reproducibility requires going deeper into the stack, even down to kernel implementations.

Reproducibility in deep learning is a multifaceted challenge that requires attention to both software and hardware details. By understanding the root causes and being willing to delve into low-level implementations, we can overcome these challenges.

Our journey underscores the importance of meticulous engineering and provides a roadmap for others facing similar issues. We hope that our experience helps the community advance towards more reliable and consistent deep learning practices.

Next Steps

We plan to expand our research in the following directions:

  • Collaborate with Framework Developers: Work with PyTorch and TensorFlow teams to integrate deterministic kernels.
  • Explore Reproducibility across Hardware Vendors: Focus on achieving reproducibility across a range of hardware platforms, including Nvidia, AMD, CPUs, and specialized hardware.
  • Reproducibility in Distributed Systems: Extend our work to multi-node setups, addressing the added challenges of reproducibility in distributed environments.
  • Apply Findings to ZKP Frameworks: Collaborate with developers in the ZKP space to ensure our solutions meet the stringent requirements of zero-knowledge proofs.

Please feel free to reach out with any questions or if you’re interested in collaborating on improving reproducibility in deep learning and zero-knowledge proofs.

Follow Ingonyama

Twitter: https://twitter.com/Ingo_zk

YouTube: https://www.youtube.com/@ingo_zk

GitHub: https://github.com/ingonyama-zk

LinkedIn: https://www.linkedin.com/company/ingonyama

Join us: https://www.ingonyama.com/career
